Source code for ibl_alignment_gui.loaders.plot_loader

import logging
from dataclasses import dataclass
from functools import wraps
from typing import Any

import numpy as np
from matplotlib import cm, colors

from brainbox.task import passive
from ibl_alignment_gui.loaders.geometry_loader import (
    ChannelGeometry,
    arrange_channels_into_banks,
    average_chns_at_same_depths,
    pad_data_to_full_chn_map,
)
from iblutil.numerical import bincount2D
from iblutil.util import Bunch

logger = logging.getLogger(__name__)

np.seterr(divide='ignore', invalid='ignore')


[docs] def skip_missing(required_keys): """Skip method execution if required data keys are missing or false.""" def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): for key in required_keys: val = self.data[key]['exists'] if not val: return {} return func(self, *args, **kwargs) return wrapper return decorator
[docs] @dataclass class ScatterData: """ Data structure for 2D scatter plots. Attributes ---------- x : np.ndarray x-coordinates of points. y : np.ndarray y-coordinates of points. levels : list or np.ndarray Levels for colormap scaling. These can be updated by the user default_levels : list or np.ndarray Default levels for colormap scaling. colours : np.ndarray Hex colour or data values for each point. pen : string or None Colour for the outline marker of each point size : np.ndarray Size of each point. symbol : str or np.ndarray Marker symbol(s) for each point. xrange : np.ndarray Range of the x-axis. xaxis : str Label for the x-axis. title : str Plot title. cmap : str Colormap name for coloring points. cluster : bool Whether data is cluster data. """ x: np.ndarray y: np.ndarray levels: list | np.ndarray default_levels: list | np.ndarray colours: np.ndarray pen: str | None size: np.ndarray symbol: str | np.ndarray xrange: np.ndarray xaxis: str title: str cmap: str cluster: bool
[docs] @dataclass class ImageData: """ Data structure for 2D image plots. Attributes ---------- img : np.ndarray 2D array representing image values. scale : np.ndarray Scaling factors for axes (x and y). levels : list or np.ndarray Levels for colormap scaling. These can be updated by the user default_levels : list or np.ndarray Default levels for colormap scaling. offset : np.ndarray Offset for axes (x and y). xrange : np.ndarray Range of the x-axis. xaxis : str Label for the x-axis. cmap : str Colormap name. title : str Plot title. """ img: np.ndarray scale: np.ndarray levels: np.ndarray default_levels: list | np.ndarray offset: np.ndarray xrange: np.ndarray xaxis: str cmap: str title: str
[docs] @dataclass class LineData: """ Data structure for line plots. Attributes ---------- x : np.ndarray x-coordinates of the line. y : np.ndarray y-coordinates of the line. levels : list or np.ndarray Levels for colormap scaling. These can be updated by the user default_levels : list or np.ndarray Default levels for colormap scaling. xrange : np.ndarray Range of the x-axis. xaxis : str Label for the x-axis. vlines : list or None Positions of vertical lines to be drawn. mask: np.ndarray or None A boolean array indicating which poitns in the data to highlight with scatter points. mask_colour: str or None The colour to use for the mask points. mask_style: str or None The style to use for the mask points. """ x: np.ndarray y: np.ndarray levels: np.ndarray default_levels: list | np.ndarray xrange: np.ndarray xaxis: str vlines: list | None = None mask: np.ndarray | None = None mask_colour: str | None = None mask_style: str | None = None
[docs] @dataclass class ProbeData: """ Data structure for probe plots. Attributes ---------- # TODO fix docstring img : np.ndarray 2D array containing data arranged according to probe banks. scale : list or np.ndarray Scaling factor along x and y axes. levels : list or np.ndarray Levels for colormap scaling. These can be updated by the user. default_levels : list or np.ndarray Default levels for colormap scaling. offset : list or np.ndarray Offset along x and y axes. xrange : np.ndarray Range of the x-axis. cmap : str Colormap name. title : str Plot title. data : np.ndarray or None An array of the data along the depth of probe (for 3D view) boundaries : np.ndarray or None Array of boundaries for banks or regions. """ img: np.ndarray scale: np.ndarray levels: list | np.ndarray default_levels: list | np.ndarray offset: np.ndarray xrange: np.ndarray cmap: str title: str data: np.ndarray | None = None boundaries: np.ndarray | None = None
FILTER_MATCH = { 'IBL good': ('label', 1), 'KS good': ('ks2_label', 'good'), 'KS mua': ('ks2_label', 'mua'), } TBIN = 0.05 DBIN = 5 BNK_SIZE = 10
[docs] def compute_spike_average( spikes: Bunch[str, Any], clusters: Bunch[str, Any] ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Compute average spike amplitudes, depths, and firing rates for each cluster. Parameters ---------- spikes : Bunch Spike data containing 'amps', 'depths', 'times', 'clusters'. clusters : Bunch Cluster data containing 'channels' and 'metrics'. Returns ------- clust_idx : np.ndarray Array of cluster indices. avg_amps : np.ndarray Average spike amplitude per cluster (uV). avg_depths : np.ndarray Average depth per cluster. avg_fr : np.ndarray Average firing rate per cluster (spikes/sec). Notes ----- - Clusters with no spikes are returned as NaN. """ # Remove exists key for pandas operation exists = spikes.pop('exists') spike_df = spikes.to_df().groupby('clusters') avgs = spike_df.agg(['mean', 'count']) # Add back in for use elsewhere spikes['exists'] = exists # Some clusters don't have any spikes so we need to reindex into the original clusters data idx = avgs.index.values clust_idx = np.arange(clusters['channels'].size) avg_amps = np.full(clust_idx.size, np.nan) avg_amps[idx] = avgs['amps']['mean'].values * 1e6 avg_fr = np.full(clust_idx.size, np.nan) avg_fr[idx] = avgs['depths']['count'].values / spikes['times'].max() avg_depths = np.full(clust_idx.size, np.nan) avg_depths[idx] = avgs['depths']['mean'].values return clust_idx, avg_amps, avg_depths, avg_fr
[docs] def compute_bincount( spike_times: np.ndarray, spike_depths: np.ndarray, spike_amps: np.ndarray, xbin: float = TBIN, ybin: float = DBIN, **kwargs, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Compute 2D binned spike count and amplitude over time and depth. Parameters ---------- spike_times : np.ndarray Spike times. spike_depths : np.ndarray Depths of spikes. spike_amps : np.ndarray Amplitudes of spikes. xbin : float Bin width along the x-axis (time). ybin : float Bin width along the y-axis (depth). **kwargs : Additional arguments for `bincount2D`. Returns ------- count : np.ndarray 2D binned spike counts. amp : np.ndarray 2D binned spike amplitudes. times : np.ndarray Bin edges for x-axis (time). depths : np.ndarray Bin edges for y-axis (depth). """ count, times, depths = bincount2D(spike_times, spike_depths, xbin=xbin, ybin=ybin, **kwargs) amp, times, depths = bincount2D( spike_times, spike_depths, xbin=xbin, ybin=ybin, weights=spike_amps, **kwargs ) return count, amp, times, depths
[docs] def group_bincount(arr: np.ndarray, group_size: int, axis: int = 1) -> np.ndarray: """ Average over chunks of `group_size` along the given axis. If leftover elements exist, sum them and append as the final group. Parameters ---------- arr : np.ndarray 2D array to process. group_size : int Number of elements per group to average. axis : int Axis to operate on: 0 (rows) or 1 (columns). Default is 1. Returns ------- np.ndarray Array with grouped means and a final summed group if leftovers exist. """ if arr.ndim != 2: raise ValueError('Input array must be 2D.') if axis not in (0, 1): raise ValueError('Axis must be 0 or 1.') # Transpose if operating on axis 0 to reuse logic if axis == 0: arr = arr.T num_elements = arr.shape[1] num_full = num_elements // group_size full_cols = num_full * group_size arr_full = arr[:, :full_cols] arr_extra = arr[:, full_cols:] # Compute mean over full groups arr_grouped = arr_full.reshape(arr.shape[0], num_full, group_size) arr_avg = arr_grouped.sum(axis=2) # Sum the leftover group (if any) if arr_extra.shape[1] > 0: arr_sum = arr_extra.sum(axis=1, keepdims=True) result = np.concatenate([arr_avg, arr_sum], axis=1) else: result = arr_avg return result.T if axis == 0 else result
[docs] class PlotLoader: """Class for handling plot data generation.""" def __init__(self): self.data: Bunch | None = None self.shank_sites: Bunch | None = None self.chn_min: float | None = None self.chn_max: float | None = None self.image_plots: Bunch | None = None self.probe_plots: Bunch | None = None self.line_plots: Bunch | None = None self.scatter_plots: Bunch | None = None self.feature_plots: Bunch | None = None # -------------------------------------------------------------------------------------------- # Main entry point to get all plots # --------------------------------------------------------------------------------------------
[docs] def get_data(self, data: Bunch[str, Any], shank_sites: Bunch[str, Any]): """ Get all plot data. Parameters ---------- data: Bunch A bunch containing all the spikes and ephys data required to generate plots shank_sites: Bunch A bunch containing electrode geometry information for given shank """ self.data = data self.shank_sites = shank_sites self.chn_min = self.shank_sites['sites_min'] self.chn_max = self.shank_sites['sites_max'] self.filter_units('All') self.compute_avg_cluster_activity() self.compute_rasters()
[docs] def get_plots(self): """ Get all plot data for the different plot types. Notes ----- This method sets the following attributes: self.image_plots : Bunch All plots of type image self.scatter_plots : Bunch All plots of type scatter self.line_plots : Bunch All plots of type line self.probe_plots : Bunch All plots of type probe """ self.image_plots = self._get_plots('image') self.scatter_plots = self._get_plots('scatter') self.line_plots = self._get_plots('line') self.probe_plots = self._get_plots('probe') self.feature_plots = self._get_plots('feature')
def _get_plots(self, plot_prefix: str) -> Bunch[str, Any]: """ Find and call all methods that begin with given `plot_prefix`. Parameters ---------- plot_prefix : str Prefix for plot methods (e.g., 'scatter', 'image'). Returns ------- Bunch A bunch object containing the plot data for all methods with the specified prefix. """ results = Bunch() for attr_name in dir(self): if attr_name.startswith(plot_prefix): method = getattr(self, attr_name) if callable(method): results.update(method()) return results # -------------------------------------------------------------------------------------------- # Properties # -------------------------------------------------------------------------------------------- @property def spike_amps(self) -> np.ndarray: """Get spike amplitudes for the selected spikes and non-NaN depths and amplitudes.""" return self.data['spikes']['amps'][self.spike_idx][self.kp_idx] @property def spike_depths(self) -> np.ndarray: """Get spike depths for the selected spikes and non-NaN depths and amplitudes.""" return self.data['spikes']['depths'][self.spike_idx][self.kp_idx] @property def spike_clusters(self) -> np.ndarray: """Get spike clusters for the selected spikes and non-NaN depths and amplitudes.""" return self.data['spikes']['clusters'][self.spike_idx][self.kp_idx] @property def spike_times(self) -> np.ndarray: """Get spike times for the selected spikes and non-NaN depths and amplitudes.""" return self.data['spikes']['times'][self.spike_idx][self.kp_idx] # -------------------------------------------------------------------------------------------- # Data handling # --------------------------------------------------------------------------------------------
[docs] @skip_missing(['spikes']) def compute_avg_cluster_activity(self) -> None: """ Compute average amplitude, depth and firing rate for each cluster. Notes ----- This method sets the following attributes: self.clust_id : np.ndarray Cluster identifiers. self.avg_amp : np.ndarray Average spike amplitude per cluster. self.avg_depth : np.ndarray Average spike depth per cluster. self.avg_fr : np.ndarray Average firing rate per cluster. """ self.clust_id, self.avg_amp, self.avg_depth, self.avg_fr = compute_spike_average( self.data['spikes'], self.data['clusters'] )
[docs] @skip_missing(['spikes']) def compute_rasters(self) -> None: """ Compute binned firing rate, amplitude, spike times, and depths. Notes ----- This method sets the following attributes: self.chn_min_bc : float Minimum depth boundary including spike depths. self.chn_max_bc : float Maximum depth boundary including spike depths. self.fr : np.ndarray Binned firing rate array. self.amp : np.ndarray Binned spike amplitude array. self.times : np.ndarray Binned spike time array. self.depths : np.ndarray Depth values corresponding to bins. """ self.chn_min_bc = np.min(np.r_[self.chn_min, self.spike_depths]) self.chn_max_bc = np.max(np.r_[self.chn_max, self.spike_depths]) self.fr, self.amp, self.times, self.depths = compute_bincount( self.spike_times, self.spike_depths, self.spike_amps, ylim=[self.chn_min_bc, self.chn_max_bc], )
[docs] @skip_missing(['spikes']) def filter_units(self, filter_type) -> None: """ Filter spikes according to cluster metrics. Parameters ---------- filter_type: str The filter criterion. Options are 'All', 'IBL good', 'KS good', 'KS mua'. Notes ----- This method sets the following attributes: self.cluster_idx : np.ndarray The index of clusters that match the filter criteria self.spike_idx : np.ndarray The index of spikes contained in the filtered clusters (cluster_idx) self.kp_idx : np.ndarray The index of spikes that do not have NaN values for depth and amplitude """ try: if filter_type == 'All': self.cluster_idx = np.arange(self.data['clusters'].channels.size) self.spike_idx = np.arange(self.data['spikes']['clusters'].size) else: column, condition = FILTER_MATCH[filter_type] self.cluster_idx = np.where(self.data['clusters'].metrics[column] == condition)[0] self.spike_idx = np.where( np.isin(self.data['spikes']['clusters'], self.cluster_idx) )[0] self.kp_idx = np.where( ~np.isnan(self.data['spikes']['depths'][self.spike_idx]) & ~np.isnan(self.data['spikes']['amps'][self.spike_idx]) )[0] except Exception: logger.warning(f'{filter_type} metrics not found will return all units instead') self.filter_units('All')
# -------------------------------------------------------------------------------------------- # Scatter plots # --------------------------------------------------------------------------------------------
[docs] @skip_missing(['spikes']) def scatter_firing_rate(self) -> dict[str, Any]: """ Generate data for a scatter plot of spike depths vs spike times, coloured by amplitude. Returns ------- Dict A dict containing a ScatterData object with key 'Amplitude'. Notes ----- - Spikes data is subsampled for performance. - Amplitudes are split into a_bin bins and colours set accordingly. - Saturated amplitudes, those above the 90th percentile, are coloured dark purple. """ a_bin = 10 subsample = 500 # Subsample data times = self.spike_times[::subsample] depths = self.spike_depths[::subsample] amps = self.spike_amps[::subsample] # Amplitude bins (ignore top 10% outliers) amp_range = np.quantile(amps, [0, 0.9]) amp_bins = np.linspace(amp_range[0], amp_range[1], a_bin) # Map amplitudes to bin indices bin_idx = np.digitize(amps, amp_bins, right=True) # Build colormap colour_bin = np.linspace(0.0, 1.0, a_bin + 1) colormap = cm.get_cmap('BuPu')(colour_bin)[..., :3] # Initialize colours and sizes spikes_colours = np.array(['#000000'] * amps.size) spikes_size = np.zeros(amps.size) # Assign colour and sizes according to bin index valid = bin_idx < a_bin spikes_colours[valid] = [colors.to_hex(c) for c in colormap[bin_idx[valid]]] spikes_size[valid] = bin_idx[valid] / (a_bin / 4) # For saturated amplitudes, set to dark purple and larger size saturated = bin_idx >= a_bin spikes_colours[saturated] = '#400080' spikes_size[saturated] = (a_bin - 1) / (a_bin / 4) xrange = np.array([np.min(times), np.max(times)]) scatter = ScatterData( x=times, y=depths, levels=amp_range * 1e6, default_levels=amp_range * 1e6, colours=spikes_colours, pen=None, size=spikes_size, symbol=np.array('o'), xrange=xrange, xaxis='Time (s)', title='Amplitude (uV)', cmap='BuPu', cluster=False, ) return {'Amplitude': scatter}
[docs] @skip_missing(['spikes']) def scatter_amp_depth_fr(self) -> dict[str, Any]: """ Generate data for a scatter plot of cluster depth vs. cluster amplitude. Scatter points are coloured by cluster firing rate. Returns ------- Dict A dict containing a ScatterData object with key 'Cluster Amp vs Depth vs FR'. """ levels = np.nanquantile(self.avg_fr[self.cluster_idx], [0, 1]) scatter = ScatterData( x=self.avg_amp[self.cluster_idx], y=self.avg_depth[self.cluster_idx], levels=levels, default_levels=np.copy(levels), colours=self.avg_fr[self.cluster_idx], pen='k', size=np.array(8), symbol=np.array('o'), xrange=np.array( [ 0.9 * np.nanmin(self.avg_amp[self.cluster_idx]), 1.1 * np.nanmax(self.avg_amp[self.cluster_idx]), ] ), xaxis='Amplitude (uV)', title='Firing Rate (Sp/s)', cmap='hot', cluster=True, ) return {'Cluster Amp vs Depth vs FR': scatter}
[docs] @skip_missing(['spikes']) def scatter_amp_depth_duration(self) -> dict[str, Any]: """ Generate data for a scatter plot of cluster depth vs. cluster amplitude. Scatter points are coloured by cluster peak to trough duration. Returns ------- Dict A dict containing a ScatterData object with key 'Cluster Amp vs Depth vs Duration'. """ levels = np.array([-1.5, 1.5]) scatter = ScatterData( x=self.avg_amp[self.cluster_idx], y=self.avg_depth[self.cluster_idx], levels=levels, default_levels=np.copy(levels), colours=self.data['clusters']['peakToTrough'][self.cluster_idx], pen='k', size=np.array(8), symbol=np.array('o'), xrange=np.array( [ 0.9 * np.nanmin(self.avg_amp[self.cluster_idx]), 1.1 * np.nanmax(self.avg_amp[self.cluster_idx]), ] ), xaxis='Amplitude (uV)', title='Peak to Trough duration (ms)', cmap='RdYlGn', cluster=True, ) return {'Cluster Amp vs Depth vs Duration': scatter}
[docs] @skip_missing(['spikes']) def scatter_fr_depth_amp(self) -> dict[str, Any]: """ Generate data for a scatter plot of cluster depth vs. cluster firing rate. Scatter points are coloured by cluster amplitude. Returns ------- Dict A dict containing a ScatterData object with key 'Cluster FR vs Depth vs Amp'. """ levels = np.nanquantile(self.avg_amp[self.cluster_idx], [0, 1]) scatter = ScatterData( x=self.avg_fr[self.cluster_idx], y=self.avg_depth[self.cluster_idx], levels=levels, default_levels=np.copy(levels), colours=self.avg_amp[self.cluster_idx], pen='k', size=np.array(8), symbol=np.array('o'), xrange=np.array( [ 0.9 * np.nanmin(self.avg_fr[self.cluster_idx]), 1.1 * np.nanmax(self.avg_fr[self.cluster_idx]), ] ), xaxis='Firing Rate (Sp/s)', title='Amplitude (uV)', cmap='magma', cluster=True, ) return {'Cluster FR vs Depth vs Amp': scatter}
# -------------------------------------------------------------------------------------------- # Image plots # --------------------------------------------------------------------------------------------
[docs] @skip_missing(['spikes']) def image_firing_rate(self) -> dict[str, Any]: """ Generate data for an image plot of binned firing rates across time. Returns ------- Dict A dict containing a ImageData object with key 'Firing Rate'. """ xscale = (self.times[-1] - self.times[0]) / self.fr.shape[1] yscale = (self.depths[-1] - self.depths[0]) / self.fr.shape[0] levels = np.quantile(np.mean(self.fr.T, axis=0), [0, 1]) img = ImageData( img=self.fr.T, scale=np.array([xscale, yscale]), levels=levels, default_levels=np.copy(levels), offset=np.array([0, self.chn_min]), xrange=np.array([self.times[0], self.times[-1]]), xaxis='Time (s)', cmap='binary', title='Firing Rate', ) return {'Firing Rate': img}
[docs] @skip_missing(['spikes']) def image_correlation(self) -> dict[str, Any]: """ Generate data for an image plot of the correlation of binned firing rates across depth. Returns ------- Dict A dict containing a ImageData object with key 'Correlation'. """ # Resample to 40um depth bins for correlation calculation dbin = 40 factor = int(dbin / DBIN) bincount = group_bincount(self.fr, factor, axis=0) depths = self.depths[::factor] corr = np.corrcoef(bincount) corr[np.isnan(corr)] = 0 scale = (np.max(depths) - np.min(depths)) / corr.shape[0] levels = np.array([np.min(corr), np.max(corr)]) img = ImageData( img=corr, scale=np.array([scale, scale]), levels=levels, default_levels=np.copy(levels), offset=np.array([self.chn_min, self.chn_min]), xrange=np.array([self.chn_min, self.chn_max]), cmap='viridis', title='Correlation', xaxis='Distance from probe tip (um)', ) return {'Correlation': img}
[docs] @skip_missing(['rms_AP']) def image_rms_ap(self) -> dict[str, Any]: """ Generate data for an image plot of the RMS of the AP band across time. Returns ------- Dict A dict containing a ImageData object with key 'rms_AP'. """ return self._image_rms('AP')
[docs] @skip_missing(['rms_LF']) def image_rms_lf(self) -> dict[str, Any]: """ Generate data for an image plot of the RMS of the LFP band across time. Returns ------- Dict A bunch containing a ImageData object with key 'rms_LF'. """ return self._image_rms('LF')
def _image_rms(self, band: str) -> dict[str, Any]: """ Generate data for an image plot of the RMS for the specified frequency band (AP or LF). Parameters ---------- band: str The frequency band to process (AP or LF). Returns ------- Dict A dict containing a ImageData object with key 'rms_{band}'. Notes ----- - Channels with the same depth are averaged together - The median across depths is subtracted per time point to remove striping, but the global median is added back for interpretability. - If the probe has non-contiguous channels, the output is padded with NaNs to align with the full channel map. """ # Identify channels at the same depth img = ( average_chns_at_same_depths(self.shank_sites, self.data[f'rms_{band}']['rms']) * 1e6 ) # convert to µV # Median subtract across depths (remove horizontal bands) depth_medians = np.median(img, axis=1, keepdims=True) global_median = np.mean(depth_medians) img = img - depth_medians + global_median # Reconstruct full channel map (handles gaps in channel geometry) img_full = pad_data_to_full_chn_map(self.shank_sites, img) # Scaling for plotting timestamps = self.data[f'rms_{band}']['timestamps'] xscale = (timestamps[-1] - timestamps[0]) / img_full.shape[0] yscale = (self.chn_max - self.chn_min) / img_full.shape[1] levels = np.quantile(img, [0.1, 0.9]) cmap = 'plasma' if band == 'AP' else 'inferno' img = ImageData( img=img_full, scale=np.array([xscale, yscale]), levels=levels, default_levels=np.copy(levels), offset=np.array([0, self.chn_min]), cmap=cmap, xrange=np.array([timestamps[0], timestamps[-1]]), xaxis=self.data[f'rms_{band}']['xaxis'], title=f'{band} RMS (uV)', ) return {f'rms {band}': img}
[docs] @skip_missing(['psd_LF']) def image_lfp_spectrum(self) -> dict[str, Any]: """ Generate data for an image plot of the LFP power spectrum across frequency. Returns ------- Dict A dict containing a ImageData object with key 'LF spectrum'. Notes ----- - Channels with the same depth are averaged together - The power spectrum is limited to the range 0-300 Hz - The power is converted to dB scale """ # Find frequency range freq_range = [0, 300] freq_idx = np.where( (self.data['psd_LF']['freqs'] >= freq_range[0]) & (self.data['psd_LF']['freqs'] < freq_range[1]) )[0] # Extract PSD data for the selected frequency range and selected channels lfp_power = self.data['psd_LF']['power'][freq_idx, :] lfp_power = 10 * np.log10(lfp_power) # Average data across channels at the same depth img = average_chns_at_same_depths(self.shank_sites, lfp_power) # Reconstruct full channel map (handles gaps in channel geometry) img_full = pad_data_to_full_chn_map(self.shank_sites, img) # Scaling for plotting xscale = (freq_range[-1] - freq_range[0]) / img_full.shape[0] yscale = (self.chn_max - self.chn_min) / img_full.shape[1] levels = np.quantile(img, [0.1, 0.9]) img = ImageData( img=img_full, scale=np.array([xscale, yscale]), levels=levels, default_levels=np.copy(levels), offset=np.array([0, self.chn_min]), cmap='viridis', xrange=np.array([freq_range[0], freq_range[-1]]), xaxis='Frequency (Hz)', title='PSD (dB)', ) return {'LF spectrum': img}
[docs] @skip_missing(['spikes']) def image_passive_events(self) -> dict[str, Any]: """ Generate data for image plots of the passive event aligned PSTHs. Returns ------- Dict A dict containing multiple ImageData objects with keys according to stimulus type. Notes ----- - Will only return data for passive events that are present in the data """ # Find the list of passive events that are present in the data if not self.data['pass_stim']['exists'] and not self.data['gabor']['exists']: return dict() elif not self.data['pass_stim']['exists'] and self.data['gabor']['exists']: stim_types = ['leftGabor', 'rightGabor'] stims = {stim_type: self.data['gabor'][stim_type] for stim_type in stim_types} elif self.data['pass_stim']['exists'] and not self.data['gabor']['exists']: stim_types = ['valveOn', 'toneOn', 'noiseOn'] stims = {stim_type: self.data['pass_stim'][stim_type] for stim_type in stim_types} else: stim_types = ['valveOn', 'toneOn', 'noiseOn', 'leftGabor', 'rightGabor'] stims = {stim_type: self.data['pass_stim'][stim_type] for stim_type in stim_types[0:3]} stims.update( {stim_type: self.data['gabor'][stim_type] for stim_type in stim_types[3:]} ) # Compute normalised event aligned psths base_stim = 1 pre_stim = 0.4 post_stim = 1 stim_events = passive.get_stim_aligned_activity( stims, self.spike_times, self.spike_depths, pre_stim=pre_stim, post_stim=post_stim, base_stim=base_stim, y_lim=[self.chn_min_bc, self.chn_max_bc], ) # Loop over each stimulus type and create ImageData objects passive_imgs = dict() for stim_type, aligned_img in stim_events.items(): xscale = (post_stim + pre_stim) / aligned_img.shape[1] yscale = (self.chn_max - self.chn_min) / aligned_img.shape[0] levels = np.array([-10, 10]) img = ImageData( img=aligned_img.T, scale=np.array([xscale, yscale]), levels=levels, default_levels=np.copy(levels), offset=np.array([-1 * pre_stim, self.chn_min]), cmap='bwr', xrange=np.array([-1 * pre_stim, post_stim]), xaxis='Time from Stim Onset (s)', title='Firing rate (z score)', ) passive_imgs.update({stim_type: img}) return passive_imgs
[docs] @skip_missing(['raw_snippets']) def image_raw_data(self) -> dict[str, Any]: """ Generate data for image plots of raw ephys data snippets. Returns ------- Dict A dict containing multiple ImageData objects with keys according to the time of the snippet during the recording. """ raw_imgs = dict() for i, (t, raw_img) in enumerate(self.data['raw_snippets']['images'].items()): x_range = np.array([0, raw_img.shape[0] - 1]) / self.data['raw_snippets']['fs'] * 1e3 xscale = (x_range[1] - x_range[0]) / raw_img.shape[0] yscale = (self.chn_max - self.chn_min) / raw_img.shape[1] levels = 10 ** (-90 / 20) * 4 * np.array([-1, 1]) img = ImageData( img=raw_img, scale=np.array([xscale, yscale]), levels=levels, default_levels=np.copy(levels), offset=np.array([0, self.chn_min]), cmap='bone', xrange=x_range, xaxis='Time (ms)', title=f'Power (uV) T={int(t)} s', ) raw_imgs[f'Raw ap snippet {i}'] = img return raw_imgs
# -------------------------------------------------------------------------------------------- # Line plots # --------------------------------------------------------------------------------------------
[docs] @skip_missing(['spikes']) def line_firing_rate(self) -> dict[str, Any]: """ Generate data for a line plot of depth vs firing rate averaged across time. Returns ------- Dict A dict containing a LineData object with key 'Firing Rate'. """ # Resample to 10um depth bins for smoother depth profile dbin = 10 factor = int(dbin / DBIN) bincount = group_bincount(self.fr, factor, axis=0) depths = self.depths[::factor] mean_fr = np.mean(bincount, axis=1) line = LineData( x=mean_fr, y=depths, xrange=np.array([0, np.max(mean_fr)]), levels=np.array([0, np.max(mean_fr)]), default_levels=np.array([0, np.max(mean_fr)]), xaxis='Firing Rate (Sp/s)', ) return {'Firing Rate': line}
[docs] @skip_missing(['spikes']) def line_amplitude(self) -> dict[str, Any]: """ Generate data for a line plot of depth vs amplitude averaged across time. Returns ------- Dict A dict containing a LineData object with key 'Amplitude'. """ # Resample to 10um depth bins for smoother depth profile dbin = 10 factor = int(dbin / DBIN) bincount = group_bincount(self.amp, factor, axis=0) depths = self.depths[::factor] mean_amp = np.mean(bincount, axis=1) * 1e6 line = LineData( x=mean_amp, y=depths, xrange=np.array([0, np.max(mean_amp)]), levels=np.array([0, np.max(mean_amp)]), default_levels=np.array([0, np.max(mean_amp)]), xaxis='Amplitude (uV)', ) return {'Amplitude': line}
[docs] @skip_missing(['raw_snippets']) def line_dead_channels(self) -> dict[str, Any]: """ Generate data for a line plot of dead channels across depth. Returns ------- Dict A dict containing a LineData object with key 'Dead Channels'. """ data = self.data['raw_snippets']['dead_channels'] min_level = np.min([np.min(data['lines']) * 1.1, np.nanmin(data['values'])]) max_level = np.max([np.max(data['lines']) * 1.1, np.nanmax(data['values'])]) levels = np.array([min_level, max_level]) line = LineData( x=data['values'], y=self.shank_sites['sites_y'], xrange=levels, levels=np.copy(levels), default_levels=np.copy(levels), xaxis='High coherence', vlines=data['lines'], mask=data['points'], mask_colour='k', mask_style='star', ) return {'Dead Channels': line}
[docs] @skip_missing(['raw_snippets']) def line_noisy_channels_coherence(self) -> dict[str, Any]: """ Generate data for a line plot of noisy channels across depth. Noisy channels in this plot are identified based on high coherence. Returns ------- Dict A dict containing a LineData object with key 'Noisy Channels Coherence'. """ data = self.data['raw_snippets']['noisy_channels_coherence'] min_level = np.min([np.min(data['lines']) * 1.1, np.nanmin(data['values'])]) max_level = np.max([np.max(data['lines']) * 1.1, np.nanmax(data['values'])]) levels = np.array([min_level, max_level]) line = LineData( x=data['values'], y=self.shank_sites['sites_y'], xrange=levels, levels=np.copy(levels), default_levels=np.copy(levels), xaxis='High coherence', vlines=data['lines'], mask=data['points'], mask_colour='r', mask_style='star', ) return {'Noisy Channels Coherence': line}
[docs] @skip_missing(['raw_snippets']) def line_noisy_channels_psd(self) -> dict[str, Any]: """ Generate data for a line plot of noisy channels across depth. Noisy channels in this plot are identified based on high PSD. Returns ------- Dict A dict containing a LineData object with key 'Noisy Channels PSD'. """ data = self.data['raw_snippets']['noisy_channels_psd'] min_level = np.min([np.min(data['lines']) * 1.1, np.nanmin(data['values'])]) max_level = np.max([np.max(data['lines']) * 1.1, np.nanmax(data['values'])]) levels = np.array([min_level, max_level]) line = LineData( x=data['values'], y=self.shank_sites['sites_y'], xrange=levels, levels=np.copy(levels), default_levels=np.copy(levels), xaxis='PSD', vlines=data['lines'], mask=data['points'], mask_colour='r', mask_style='star', ) return {'Noisy Channels PSD': line}
[docs] @skip_missing(['raw_snippets']) def line_outside_channels(self) -> dict[str, Any]: """ Generate data for a line plot of outide channels across depth. Returns ------- Dict A dict containing a LineData object with key 'Outside Channels'. """ data = self.data['raw_snippets']['outside_channels'] min_level = np.min([np.min(data['lines']) * 1.1, np.nanmin(data['values'])]) max_level = np.max([np.max(data['lines']) * 1.1, np.nanmax(data['values'])]) levels = np.array([min_level, max_level]) line = LineData( x=data['values'], y=self.shank_sites['sites_y'], xrange=levels, levels=np.copy(levels), default_levels=np.copy(levels), xaxis='Low coherence', vlines=data['lines'], mask=data['points'], mask_colour='y', mask_style='star', ) return {'Outside Channels': line}
# -------------------------------------------------------------------------------------------- # Probe plots # --------------------------------------------------------------------------------------------
[docs] @skip_missing(['rms_AP']) def probe_rms_ap(self) -> dict[str, Any]: """ Generate data for a probe plot of the RMS of the AP band averaged across time. Returns ------- Dict A dict containing a ProbeData object with key 'rms_AP'. """ return self._probe_rms('AP')
[docs] @skip_missing(['rms_LF']) def probe_rms_lf(self) -> dict[str, Any]: """ Generate data for a probe plot of the RMS of the LFP band averaged across time. Returns ------- Dict A dict containing a ProbeData object with key 'rms_LF'. """ return self._probe_rms('LF')
def _probe_rms(self, band: str) -> dict[str, Any]: """ Generate data for a probe plot of the RMS for the specified frequency band (AP or LF). Parameters ---------- band: str The frequency band to process (AP or LF). Returns ------- Dict A dict containing a ProbeData object with key 'rms_{band}'. """ # Average data across time rms_avg = np.mean(self.data[f'rms_{band}']['rms'], axis=0) * 1e6 levels = np.quantile(rms_avg, [0.1, 0.9]) # Split the data into banks of channels according to the probe geometry probe_img, probe_scale, probe_offset = arrange_channels_into_banks( self.shank_sites, rms_avg, bnk_width=BNK_SIZE ) cmap = 'plasma' if band == 'AP' else 'inferno' probe = ProbeData( img=probe_img, scale=probe_scale, offset=probe_offset, levels=levels, default_levels=np.copy(levels), cmap=cmap, xrange=np.array([0 * BNK_SIZE, (self.shank_sites['n_banks']) * BNK_SIZE]), title=band + ' RMS (uV)', data=rms_avg, ) return {f'rms {band}': probe}
[docs] @skip_missing(['psd_LF']) def probe_lfp_spectrum(self) -> dict[str, Any]: """ Generate data for probe plots of the LFP power averaged across different frequency bands. Returns ------- Dict A dict containing multiple ProbeData objects with keys according to frequency bands. """ # Define frequency bands freq_bands = np.vstack(([0, 4], [4, 10], [10, 30], [30, 80], [80, 200])) data_probe = dict() for freq in freq_bands: freq_idx = np.where( (self.data['psd_LF']['freqs'] >= freq[0]) & (self.data['psd_LF']['freqs'] < freq[1]) )[0] lfp_power = np.mean(self.data['psd_LF']['power'][freq_idx], axis=0) lfp_power = 10 * np.log10(lfp_power) probe_img, probe_scale, probe_offset = arrange_channels_into_banks( self.shank_sites, lfp_power, bnk_width=BNK_SIZE ) levels = np.quantile(lfp_power, [0.1, 0.9]) probe = ProbeData( img=probe_img, scale=probe_scale, offset=probe_offset, levels=levels, default_levels=np.copy(levels), cmap='viridis', xrange=np.array([0 * BNK_SIZE, (self.shank_sites['n_banks']) * BNK_SIZE]), title=f'{freq[0]}-{freq[1]} Hz (dB)', data=lfp_power, ) data_probe.update({f'{freq[0]} - {freq[1]} Hz': probe}) return data_probe
[docs] @skip_missing(['spikes', 'rf_map']) def probe_rfmap(self) -> dict[str, Any]: """ Generate data for probe plots of the Receptive Field map (on and off) across depth. Returns ------- Dict A dict containing ProbeData objects with for keys 'RF Map - on' and 'RF Map - off'. Notes ----- - Although this is a probe plot the data is not split into banks as for the case of other probe plots. """ # Extract stimulus times and positions rf_map_times, rf_map_pos, rf_stim_frames = passive.get_on_off_times_and_positions( self.data['rf_map'] ) # Compute receptive field map over depth rf_map, _ = passive.get_rf_map_over_depth( rf_map_times, rf_map_pos, rf_stim_frames, self.spike_times, self.spike_depths, d_bin=160, y_lim=[self.chn_min_bc, self.chn_max_bc], ) # Apply SVD decomposition to obtain ON and OFF maps rfs_svd = passive.get_svd_map(rf_map) img = {} img['on'] = np.vstack(rfs_svd['on']) img['off'] = np.vstack(rfs_svd['off']) # Scaling yscale = (self.chn_max - self.chn_min) / img['on'].shape[0] xscale = 1 depths = np.linspace(self.chn_min, self.chn_max, len(rfs_svd['on']) + 1) levels = np.quantile(np.c_[img['on'], img['off']], [0, 1]) data_img = dict() sub_type = ['on', 'off'] for sub in sub_type: sub_data = { f'RF Map - {sub}': ProbeData( img=img[sub].T, scale=np.array([xscale, yscale]), levels=levels, default_levels=np.copy(levels), offset=np.array([0, self.chn_min]), cmap='viridis', xrange=np.array([0, 15]), title='rfmap (dB)', boundaries=depths, data=None, ) } data_img.update(sub_data) return data_img
# -------------------------------------------------------------------------------------------- # Feature plots # --------------------------------------------------------------------------------------------
[docs] @skip_missing(['features']) def feature_ephys_atlas(self): """ Generate data for ephys atlas feature plots. Returns ------- Dict A dict containing multiple ProbeData objects with keys according to features. """ ignore_cols = [ 'pid', 'axial_um', 'lateral_um', 'x', 'y', 'z', 'acronym', 'atlas_id', 'x_target', 'y_target', 'z_target', 'outside', 'Allen_id', 'Cosmos_id', 'Beryl_id', ] feature_data = self.data['features']['df'] chn_coords = Bunch() chn_coords['localCoordinates'] = np.c_[ feature_data['lateral_um'].values, feature_data['axial_um'].values ] chn_coords['rawInd'] = np.arange(chn_coords['localCoordinates'].shape[0]) chn_geom = ChannelGeometry(chn_coords) chn_geom.split_sites_per_shank() sites = chn_geom._get_sites_for_shank(0) features = [k for k in feature_data if k not in ignore_cols] features.sort() data = Bunch() for i, feature in enumerate(features): vals = feature_data[feature].values min_val = np.nanmin(vals) max_val = np.nanmax(vals) feature_norm = (vals - min_val) / (max_val - min_val) img, scale, offset = arrange_channels_into_banks(sites, feature_norm) offset[0] += i * (10 * sites['n_banks']) feat = ProbeData( img=img, scale=scale, offset=offset, levels=np.array([0, 1]), default_levels=np.array([0, 1]), cmap='viridis', xrange=np.array([0, 10 * sites['n_banks']]), title=feature, data=None, ) data[feature] = feat return {'Ephys Atlas': data}