Source code for ibl_alignment_gui.handlers.shank_handler

import numpy as np

from ibl_alignment_gui.handlers.alignment_handler import AlignmentHandler
from ibl_alignment_gui.loaders.plot_loader import (
    ImageData,
    LineData,
    ProbeData,
    ScatterData,
)
from iblutil.util import Bunch


class ShankHandler:
    """
    Model for handling data on a shank of a probe for a given recording configuration.

    Parameters
    ----------
    loaders: Bunch
        A Bunch object containing all the relevant loaders and uploaders to read and write data.
    shank_idx: int
        The index of the shank in the probe
    """

    def __init__(self, loaders: Bunch, shank_idx: int):
        self.shank_idx: int = shank_idx
        self.loaders: Bunch = loaders
        # Read any previously stored alignments and request the starting alignment
        # with index 0 from the alignment loader.
        self.loaders['align'].load_previous_alignments()
        self.loaders['align'].get_starting_alignment(0)
        # Switched to False by `load_data` when no alignment handler can be created
        self.align_exists: bool = True
        # Guards `load_data` so the data is only loaded once
        self.data_loaded: bool = False

    # -------------------------------------------------------------------------
    # Alignment loader - attributes and methods in loaders['align']
    # -------------------------------------------------------------------------
    def set_init_alignment(self) -> None:
        """Set the initial alignment based on previous features and tracks."""
        self.align_handle.set_init_feature_track(
            self.loaders['align'].feature_prev, self.loaders['align'].track_prev
        )

    @property
    def feature_prev(self) -> np.ndarray:
        """
        Return the previous feature from the alignment loader for the currently active shank.

        Returns
        -------
        np.ndarray
            Previous feature array.
        """
        return self.loaders['align'].feature_prev

    # -------------------------------------------------------------------------
    # Alignment handler - attributes and methods in align_handle
    # -------------------------------------------------------------------------
    def offset_hist_data(self, *args, **kwargs) -> None:
        """See :meth:`AlignmentHandler.offset_hist_data` for details."""
        # Forward keyword arguments too, consistent with `scale_hist_data`
        self.align_handle.offset_hist_data(*args, **kwargs)

    def scale_hist_data(self, *args, **kwargs) -> None:
        """See :meth:`AlignmentHandler.scale_hist_data` for details."""
        self.align_handle.scale_hist_data(*args, **kwargs)

    def get_scaled_histology(self) -> None:
        """
        See :meth:`AlignmentHandler.get_scaled_histology` for details.

        The three values returned by the alignment handler are stored on this
        instance as ``hist_data``, ``hist_data_ref`` and ``scale_data``.
        """
        self.hist_data, self.hist_data_ref, self.scale_data = (
            self.align_handle.get_scaled_histology()
        )

    def feature2track_lin(
        self, depths: np.ndarray, feature: np.ndarray, track: np.ndarray
    ) -> np.ndarray:
        """
        Estimate values of depth according to linear fit between feature and track reference lines.

        Parameters
        ----------
        depths: np.ndarray
            The depths to estimate the new depths for
        feature: np.ndarray
            The feature line positions
        track: np.ndarray
            The track line positions

        Returns
        -------
        np.ndarray
            The new estimated depths
        """
        return self.align_handle.ephysalign.feature2track_lin(depths, feature, track)

    def reset_features_and_tracks(self) -> None:
        """See :meth:`AlignmentHandler.reset_features_and_tracks` for details."""
        self.align_handle.reset_features_and_tracks()

    @property
    def track(self) -> np.ndarray:
        """See :meth:`AlignmentHandler.track` for details."""
        return self.align_handle.track

    @property
    def feature(self) -> np.ndarray:
        """See :meth:`AlignmentHandler.feature` for details."""
        return self.align_handle.feature

    @property
    def xyz_channels(self) -> np.ndarray:
        """See :meth:`AlignmentHandler.xyz_channels` for details."""
        return self.align_handle.xyz_channels

    @property
    def xyz_track(self) -> np.ndarray:
        """See :meth:`AlignmentHandler.xyz_track` for details."""
        return self.align_handle.xyz_track

    @property
    def track_lines(self) -> list[np.ndarray]:
        """See :meth:`AlignmentHandler.track_lines` for details."""
        return self.align_handle.track_lines

    # -------------------------------------------------------------------------
    # Plot properties - methods and attributes in loaders['plots']
    # -------------------------------------------------------------------------
    @property
    def chn_min(self) -> float:
        """
        Return the minimum y channel value for the currently active shank.

        Returns
        -------
        float:
            The minimum channel value, or 0 if the minimum is positive.
        """
        return np.min([0, self.loaders['plots'].chn_min])

    @property
    def chn_max(self) -> float:
        """
        Return the maximum y channel value for the currently active shank.

        Returns
        -------
        float:
            The maximum channel value
        """
        return self.loaders['plots'].chn_max

    @property
    def y_min(self) -> float:
        """
        Return the minimum y channel value (not clipped at 0, unlike :attr:`chn_min`).

        Returns
        -------
        float:
            The minimum channel value
        """
        return self.loaders['plots'].chn_min

    @property
    def y_max(self) -> float:
        """
        Return the maximum y channel value.

        Returns
        -------
        float:
            The maximum channel value
        """
        return self.loaders['plots'].chn_max

    @property
    def image_plots(self) -> Bunch[str, ImageData]:
        """
        Access the image plots for the currently active shank.

        Returns
        -------
        Bunch:
            A bunch of available image plots.
        """
        return self.loaders['plots'].image_plots

    @property
    def scatter_plots(self) -> Bunch[str, ScatterData]:
        """
        Access the scatter plots for the currently active shank.

        Returns
        -------
        Bunch:
            A bunch of available scatter plots.
        """
        return self.loaders['plots'].scatter_plots

    @property
    def line_plots(self) -> Bunch[str, LineData]:
        """
        Access the line plots for the currently active shank.

        Returns
        -------
        Bunch:
            A bunch of available line plots.
        """
        return self.loaders['plots'].line_plots

    @property
    def probe_plots(self) -> Bunch[str, ProbeData]:
        """
        Access the probe plots for the currently active shank.

        Returns
        -------
        Bunch:
            A bunch of available probe plots.
        """
        return self.loaders['plots'].probe_plots

    @property
    def slice_plots(self) -> Bunch[str, Bunch]:
        """
        Access the slice plots from the current shank's plot loader.

        Returns
        -------
        Bunch:
            A bunch of available slice plots.
        """
        return self.loaders['plots'].slice_plots

    @property
    def feature_plots(self) -> Bunch[str, Bunch]:
        """
        Access the feature plots from the current shank's plot loader.

        Returns
        -------
        Bunch:
            A bunch of available feature plots.
        """
        return self.loaders['plots'].feature_plots

    def reset_levels(self) -> None:
        """Reset the levels for all image, scatter, line and probe plots."""
        for plots in (self.image_plots, self.scatter_plots, self.line_plots, self.probe_plots):
            for data in plots.values():
                # Copy so later edits to `levels` cannot mutate the defaults
                data.levels = np.copy(data.default_levels)

    # -------------------------------------------------------------------------
    # Methods of current class
    # -------------------------------------------------------------------------
    @property
    def xyz_clusters(self) -> np.ndarray:
        """
        Return the xyz cluster locations estimated using the fit from the track and feature lines.

        The values in the current index in the circular buffer for the currently active
        shank are used.

        Returns
        -------
        np.ndarray
            xyz positions of clusters in 3D space
        """
        clust = self.raw_data['clusters']['channels'][self.loaders['plots'].cluster_idx]
        return self.xyz_channels[clust]

    def load_data(self) -> None:
        """
        Load the geometry, ephys and alignment data.

        Does nothing if the data has already been loaded. If channel coordinates or
        picks are unavailable, ``align_exists`` is set to False, no alignment handler
        is created and the slice plots are left empty.
        """
        if self.data_loaded:
            return

        # Load the geometry data
        self.loaders['geom'].get_geometry()
        shank_sites = self.loaders['geom'].get_sites_for_shank(self.shank_idx)
        self.chn_sites = self.loaders['geom'].get_sites_for_shank(self.shank_idx, sites='channels')

        # Load in the spike sorting and ephys data
        self.raw_data = self.loaders['data'].get_data(shank_sites)
        # Load in the raw data snippets
        self.raw_data['raw_snippets'] = self.loaders['ephys'].load_ap_snippets()
        # Load in the features data
        if self.loaders.get('features', None) is not None:
            self.raw_data['features'] = self.loaders['features'].load_features()
        else:
            self.raw_data['features'] = Bunch(exists=False)

        # Create the plot data using the raw data
        self.loaders['plots'].get_data(self.raw_data, shank_sites)

        # These are the locations of the channels and clusters from spikesorting on the probe.
        # Only slice when coordinates exist, so the None checks below are reachable
        # instead of failing with a TypeError on `None[:, 1]`.
        self.chn_coords = self.chn_sites['sites_coords']
        self.chn_depths = self.chn_coords[:, 1] if self.chn_coords is not None else None

        if self.raw_data['clusters']['exists']:
            self.cluster_chns = self.raw_data['clusters']['channels']
        elif self.chn_depths is not None:
            self.cluster_chns = np.arange(self.chn_depths.size)
        else:
            # Always define the attribute; `upload_data` reads it unconditionally
            self.cluster_chns = None

        if self.chn_coords is not None and self.loaders['align'].xyz_picks is not None:
            # Load the alignment handler
            self.align_handle = AlignmentHandler(
                self.loaders['align'].xyz_picks,
                self.chn_depths,
                self.loaders['upload'].brain_atlas,
            )
            self.set_init_alignment()
            # Load in the histology data
            self.loaders['plots'].slice_plots = self.loaders['hist'].get_slices(
                self.align_handle.xyz_samples
            )
        else:
            self.align_exists = False
            self.loaders['plots'].slice_plots = Bunch()

        self.data_loaded = True

    def load_plots(self) -> None:
        """Load all the plot data for the current shank."""
        self.loaders['plots'].get_plots()

    def filter_units(self, filter_type: str) -> None:
        """
        Filter the spikesorting data by selected unit type and recompute plot data.

        Parameters
        ----------
        filter_type: str
            The type of unit to filter by
        """
        self.loaders['plots'].filter_units(filter_type)
        self.loaders['plots'].compute_rasters()
        self.loaders['plots'].get_plots()

    def upload_data(self) -> str:
        """
        Upload the data, save the channels and the alignments.

        Requires :meth:`load_data` to have created an alignment handler
        (i.e. ``align_exists`` is True); otherwise ``self.align_handle`` is undefined.

        Returns
        -------
        str
            The value returned by ``loaders['upload'].upload_data``.
        """
        data = {
            'chn_coords': self.chn_coords,
            'xyz_channels': self.align_handle.xyz_channels,
            'feature': self.align_handle.feature.tolist(),
            'track': self.align_handle.track.tolist(),
            'alignments': self.loaders['align'].alignments,
            'cluster_chns': self.cluster_chns,
            'probe_collection': self.loaders['data'].probe_collection,
            'chn_depths': self.chn_depths,
            'xyz_picks': self.loaders['align'].xyz_picks,
        }
        return self.loaders['upload'].upload_data(data, shank_sites=self.chn_sites)