Source code for brainbox.ephys_plots

import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
from brainbox.plot_base import (ImagePlot, ScatterPlot, ProbePlot, LinePlot, plot_line,
                                plot_image, plot_probe, plot_scatter, arrange_channels2banks)
from brainbox.processing import compute_cluster_average
from iblutil.numerical import bincount2D
from iblatlas.regions import BrainRegions


def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords=None, chn_inds=None, freq_range=(0, 300),
                            avg_across_depth=False, clim=None, cmap='viridis', display=False, title=None,
                            **kwargs):
    """
    Prepare data for 2D image plot of LFP power spectrum along depth of probe

    :param lfp_power: power spectrum indexed as lfp_power[frequency, channel]
    :param lfp_freq: frequencies corresponding to axis 0 of lfp_power
    :param chn_coords: channel coordinates; column 1 is used as the y value. If None the channel index is used.
        NOTE(review): assumed to have one row per entry of chn_inds -- confirm against callers
    :param chn_inds: indices along axis 1 of lfp_power of the channels to plot, defaults to all channels
    :param freq_range: (low, high) frequency range to keep, low inclusive and high exclusive
    :param avg_across_depth: Whether to average across channels at same depth
    :param clim: colour limits (min, max); defaults to the 10th and 90th percentiles of the finite dB values
    :param cmap: colormap name
    :param display: generate figure
    :param title: plot title, defaults to 'LFP Power Spectrum'
    :param kwargs: additional keyword arguments passed to plot_image
    :return: ImagePlot object, if display=True returns (ImagePlot.convert2dict(), fig, ax) instead
    """
    ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
    title = title or 'LFP Power Spectrum'
    y = np.arange(lfp_power.shape[1]) if chn_coords is None else chn_coords[:, 1]
    chn_inds = np.arange(lfp_power.shape[1]) if chn_inds is None else chn_inds

    freq_idx = np.where((lfp_freq >= freq_range[0]) & (lfp_freq < freq_range[1]))[0]
    freqs = lfp_freq[freq_idx]
    lfp = np.take(lfp_power[freq_idx], chn_inds, axis=1)
    # Zero power gives -inf in dB: silence the warning and mask those values as NaN
    with np.errstate(divide='ignore'):
        lfp_db = 10 * np.log10(lfp)
    lfp_db[np.isinf(lfp_db)] = np.nan

    # Average across channels that are at the same depth
    if avg_across_depth:
        chn_depth, chn_idx, chn_count = np.unique(y, return_index=True, return_counts=True)
        # For depths shared by exactly two channels the partner is assumed to be the next channel.
        # Other depths average the first channel with itself, i.e. keep it unchanged.
        chn_idx_eq = np.copy(chn_idx)
        chn_idx_eq[chn_count == 2] += 1
        lfp_db = (lfp_db[:, chn_idx] + lfp_db[:, chn_idx_eq]) / 2
        y = chn_depth

    data = ImagePlot(lfp_db, x=freqs, y=y, cmap=cmap)
    data.set_labels(title=title, xlabel='Frequency (Hz)', ylabel=ylabel, clabel='LFP Power (dB)')
    if clim is None:
        # nanquantile: the masked -inf values would otherwise make the limits NaN
        clim = np.nanquantile(lfp_db, [0.1, 0.9])
    data.set_clim(clim=clim)
    if display:
        ax, fig = plot_image(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax
    return data
def image_rms_plot(rms_amps, rms_times, chn_coords=None, chn_inds=None, avg_across_depth=False,
                   median_subtract=True, clim=None, cmap='plasma', band='AP', display=False, title=None, **kwargs):
    """
    Prepare data for 2D image plot of RMS data along depth of probe

    :param rms_amps: RMS values indexed as rms_amps[time, channel]
    :param rms_times: times corresponding to axis 0 of rms_amps
    :param chn_coords: channel coordinates; column 1 is used as the y value. If None the channel index is used.
        NOTE(review): assumed to have one row per entry of chn_inds -- confirm against callers
    :param chn_inds: indices along axis 1 of rms_amps of the channels to plot, defaults to all channels
    :param avg_across_depth: Whether to average across channels at same depth
    :param median_subtract: Whether to apply median subtraction correction. The median across channels is
        removed from each time sample and the mean of those medians is added back.
    :param clim: colour limits (min, max), defaults to the 10th and 90th percentiles of the data
    :param cmap: colormap name
    :param band: Frequency band of rms data, can be either 'LF' or 'AP'
    :param display: generate figure
    :param title: plot title, defaults to '<band> RMS'
    :param kwargs: additional keyword arguments passed to plot_image
    :return: ImagePlot object, if display=True returns (ImagePlot.convert2dict(), fig, ax) instead
    """
    ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
    title = title or f'{band} RMS'
    chn_inds = np.arange(rms_amps.shape[1]) if chn_inds is None else chn_inds
    y = np.arange(rms_amps.shape[1]) if chn_coords is None else chn_coords[:, 1]

    rms = 10 * np.log10(rms_amps[:, chn_inds])
    x = rms_times

    if avg_across_depth:
        chn_depth, chn_idx, chn_count = np.unique(y, return_index=True, return_counts=True)
        # For depths shared by exactly two channels the partner is assumed to be the next channel.
        # Other depths average the first channel with itself, i.e. keep it unchanged.
        chn_idx_eq = np.copy(chn_idx)
        chn_idx_eq[chn_count == 2] += 1
        rms = (rms[:, chn_idx] + rms[:, chn_idx_eq]) / 2
        y = chn_depth

    if median_subtract:
        # Per-time-sample median across channels, vectorised (was apply_along_axis)
        row_median = np.median(rms, axis=1, keepdims=True)
        rms = rms - row_median + np.mean(row_median)

    data = ImagePlot(rms, x=x, y=y, cmap=cmap)
    data.set_labels(title=title, xlabel='Time (s)', ylabel=ylabel, clabel=f'{band} RMS (dB)')
    if clim is None:
        # explicit None check: `clim or ...` raises for numpy array limits
        clim = np.nanquantile(rms, [0.1, 0.9])
    data.set_clim(clim=clim)
    if display:
        ax, fig = plot_image(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax
    return data
def scatter_raster_plot(spike_amps, spike_depths, spike_times, n_amp_bins=10, cmap='BuPu', subsample_factor=100,
                        display=False, title=None, **kwargs):
    """
    Prepare data for 2D raster plot of spikes with colour and size indicative of spike amplitude

    :param spike_amps: amplitude of each spike (scaled by 1e6 for the uV colour scale, so presumably in V)
    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param n_amp_bins: no. of colour and size bins into which to split amplitude data
    :param cmap: colormap used for the per-spike colours and for the colour scale
    :param subsample_factor: factor by which to subsample data when too many points for efficient display;
        every subsample_factor-th spike is kept
    :param display: generate figure
    :param title: plot title, defaults to 'Spike times vs Spike depths'
    :param kwargs: additional keyword arguments passed to plot_scatter
    :return: ScatterPlot object, if display=True returns (ScatterPlot.convert2dict(), fig, ax) instead
    """
    title = title or 'Spike times vs Spike depths'
    amp_range = np.quantile(spike_amps, [0, 0.9])
    amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
    color_bin = np.linspace(0.0, 1.0, n_amp_bins + 1)
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap remains available
    colors = plt.get_cmap(cmap)(color_bin)[:, :3]

    # Subsample all three arrays identically. [::step] keeps the last spike, which [0:-1:step] always dropped.
    amps = spike_amps[::subsample_factor]
    times = spike_times[::subsample_factor]
    depths = spike_depths[::subsample_factor]

    # Spikes at or below the lowest bin edge keep colour (0, 0, 0) and size 0
    spike_colors = np.zeros((amps.size, 3))
    spike_size = np.zeros(amps.size)
    for iA in range(amp_bins.size):
        if iA == (amp_bins.size - 1):
            # Make saturated spikes (above the 90th percentile) the last colour of the colormap
            idx = amps > amp_bins[iA]
            spike_colors[idx] = colors[-1]
        else:
            idx = (amps > amp_bins[iA]) & (amps <= amp_bins[iA + 1])
            spike_colors[idx] = colors[iA]
        spike_size[idx] = iA / (n_amp_bins / 8)

    data = ScatterPlot(x=times, y=depths, c=amps * 1e6, cmap=cmap)
    data.set_ylim((0, 3840))
    data.set_color(color=spike_colors)
    data.set_clim(clim=amp_range * 1e6)
    data.set_marker_size(marker_size=spike_size)
    data.set_labels(title=title, xlabel='Time (s)', ylabel='Distance from probe tip (um)',
                    clabel='Spike amplitude (uV)')
    if display:
        ax, fig = plot_scatter(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax
    return data
def image_fr_plot(spike_depths, spike_times, chn_coords, t_bin=0.05, d_bin=5, cmap='binary', display=False,
                  title=None, **kwargs):
    """
    Prepare data for a 2D image of firing rate across time and depth of the recording

    Spike counts are binned in time and depth (iblutil.numerical.bincount2D) and divided by the time bin
    width to give a rate.

    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates; the maximum of column 1 sets the upper depth limit
    :param t_bin: width of the time bins
    :param d_bin: width of the depth bins
    :param cmap: colormap name
    :param display: generate figure
    :param title: plot title, defaults to 'Firing Rate'
    :param kwargs: additional keyword arguments passed to plot_image
    :return: ImagePlot object, if display=True returns (ImagePlot.convert2dict(), fig, ax) instead
    """
    plot_title = title if title else 'Firing Rate'
    depth_limits = [0, np.max(chn_coords[:, 1])]
    counts, time_scale, depth_scale = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=depth_limits)

    # Transpose so that axis 0 is time and axis 1 is depth, matching x=time and y=depth below
    rates = counts.T / t_bin
    # Colour limits span the range of the time-averaged rate at each depth
    depth_profile = np.mean(rates, axis=0)

    image = ImagePlot(rates, x=time_scale, y=depth_scale, cmap=cmap)
    image.set_labels(title=plot_title, xlabel='Time (s)', ylabel='Distance from probe tip (um)',
                     clabel='Firing Rate (Hz)')
    image.set_clim(clim=(np.min(depth_profile), np.max(depth_profile)))

    if not display:
        return image
    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax
def image_crosscorr_plot(spike_depths, spike_times, chn_coords, t_bin=0.05, d_bin=40, cmap='viridis',
                         display=False, title=None, **kwargs):
    """
    Prepare data for a 2D image of the correlation of spiking activity between depths

    Spikes are binned in time and depth (iblutil.numerical.bincount2D). The correlation coefficient is
    computed between the binned counts of every pair of depth bins.

    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates; the maximum of column 1 sets the upper depth limit
    :param t_bin: width of the time bins
    :param d_bin: width of the depth bins
    :param cmap: colormap name
    :param display: generate figure
    :param title: plot title, defaults to 'Correlation'
    :param kwargs: additional keyword arguments passed to plot_image
    :return: ImagePlot object, if display=True returns (ImagePlot.convert2dict(), fig, ax) instead
    """
    plot_title = title if title else 'Correlation'
    depth_limits = [0, np.max(chn_coords[:, 1])]
    counts, _, depth_scale = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=depth_limits)

    correlation = np.corrcoef(counts)
    # Rows with zero variance produce NaN coefficients; show them as no correlation
    correlation[np.isnan(correlation)] = 0

    image = ImagePlot(correlation, x=depth_scale, y=depth_scale, cmap=cmap)
    image.set_labels(title=plot_title, xlabel='Distance from probe tip (um)',
                     ylabel='Distance from probe tip (um)', clabel='Correlation')

    if not display:
        return image
    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax
def scatter_amp_depth_fr_plot(spike_amps, spike_clusters, spike_depths, spike_times, cmap='hot', display=False,
                              title=None, **kwargs):
    """
    Prepare data for a 2D scatter of cluster amplitude (x) against cluster depth (y), coloured by firing rate

    :param spike_amps: amplitude of each spike (scaled by 1e6 for the uV axis)
    :param spike_clusters: cluster id of each spike
    :param spike_depths: depth of each spike
    :param spike_times: time of each spike; the maximum is used as the recording duration
    :param cmap: colormap name
    :param display: generate figure
    :param title: plot title, defaults to 'Cluster depth vs amp vs firing rate'
    :param kwargs: additional keyword arguments passed to plot_scatter
    :return: ScatterPlot object, if display=True returns (ScatterPlot.convert2dict(), fig, ax) instead
    """
    plot_title = title if title else 'Cluster depth vs amp vs firing rate'

    # TODO use pandas here instead, much quicker
    _, depth_per_cluster, spikes_per_cluster = compute_cluster_average(spike_clusters, spike_depths)
    _, amp_per_cluster, _ = compute_cluster_average(spike_clusters, spike_amps)
    amp_per_cluster = amp_per_cluster * 1e6
    rate_per_cluster = spikes_per_cluster / np.max(spike_times)

    scatter = ScatterPlot(x=amp_per_cluster, y=depth_per_cluster, c=rate_per_cluster, cmap=cmap)
    # 10 % margin either side of the amplitude range
    scatter.set_xlim((0.9 * np.min(amp_per_cluster), 1.1 * np.max(amp_per_cluster)))
    scatter.set_labels(title=plot_title, xlabel='Cluster Amplitude (uV)', ylabel='Distance from probe tip (um)',
                       clabel='Firing rate (Hz)')

    if not display:
        return scatter
    ax, fig = plot_scatter(scatter.convert2dict(), **kwargs)
    return scatter.convert2dict(), fig, ax
def probe_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 4), display=False, pad=True,
                            x_offset=1, **kwargs):
    """
    Prepare data for 2D probe plot of LFP power spectrum along depth of probe

    The power in dB is averaged over the frequencies in freq_range for each channel. The channels are then
    arranged into banks with arrange_channels2banks.

    :param lfp_power: power spectrum indexed as lfp_power[frequency, channel]
    :param lfp_freq: frequencies corresponding to axis 0 of lfp_power
    :param chn_coords: channel coordinates passed to arrange_channels2banks
    :param chn_inds: indices along axis 1 of lfp_power of the channels to plot
    :param freq_range: (low, high) frequency range to average over, low inclusive and high exclusive
    :param display: generate figure
    :param pad: whether to add nans around the individual image plots. For matplotlib use pad=True, for
        pyqtgraph use pad=False
    :param x_offset: Distance between the channel banks in x direction
    :param kwargs: additional keyword arguments passed to plot_probe
    :return: ProbePlot object, if display=True returns (ProbePlot.convert2dict(), fig, ax) instead
    """
    freq_idx = np.where((lfp_freq >= freq_range[0]) & (lfp_freq < freq_range[1]))[0]
    lfp = np.take(lfp_power[freq_idx], chn_inds, axis=1)
    # Zero power gives -inf in dB: silence the warning and mask those values as NaN
    with np.errstate(divide='ignore'):
        lfp_db = 10 * np.log10(lfp)
    lfp_db[np.isinf(lfp_db)] = np.nan
    lfp_db = np.mean(lfp_db, axis=0)

    data_bank, x_bank, y_bank = arrange_channels2banks(lfp_db, chn_coords, depth=None, pad=pad, x_offset=x_offset)
    data = ProbePlot(data_bank, x=x_bank, y=y_bank)
    # Label reflects the requested range (was hard-coded to 0-4 Hz); identical for the default
    data.set_labels(ylabel='Distance from probe tip (um)', clabel=f'PSD {freq_range[0]}-{freq_range[1]} Hz (dB)')
    # np.ravel alone: np.squeeze turned one-element banks into 0-d arrays that np.concatenate rejects
    clim = np.nanquantile(np.concatenate([np.ravel(d) for d in data_bank]), [0.1, 0.9])
    data.set_clim(clim)
    if display:
        ax, fig = plot_probe(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax
    return data
def probe_rms_plot(rms_amps, chn_coords, chn_inds, cmap='plasma', band='AP', display=False, pad=True, x_offset=1,
                   **kwargs):
    """
    Prepare data for 2D probe plot of RMS along depth of probe

    :param rms_amps: RMS values indexed as rms_amps[time, channel]. They are averaged over time and scaled by
        1e6 for the uV label, so presumably in V.
    :param chn_coords: channel coordinates passed to arrange_channels2banks
    :param chn_inds: indices along axis 1 of rms_amps of the channels to plot
    :param cmap: colormap name
    :param band: frequency band name used in the colour label, e.g. 'AP' or 'LF'
    :param display: generate figure
    :param pad: whether to add nans around the individual image plots. For matplotlib use pad=True, for
        pyqtgraph use pad=False
    :param x_offset: Distance between the channel banks in x direction
    :param kwargs: additional keyword arguments passed to plot_probe
    :return: ProbePlot object, if display=True returns (ProbePlot.convert2dict(), fig, ax) instead
    """
    rms = np.mean(rms_amps, axis=0)[chn_inds] * 1e6
    data_bank, x_bank, y_bank = arrange_channels2banks(rms, chn_coords, depth=None, pad=pad, x_offset=x_offset)
    data = ProbePlot(data_bank, x=x_bank, y=y_bank, cmap=cmap)
    data.set_labels(ylabel='Distance from probe tip (um)', clabel=f'{band} RMS (uV)')
    # np.ravel alone: np.squeeze turned one-element banks into 0-d arrays that np.concatenate rejects
    clim = np.nanquantile(np.concatenate([np.ravel(d) for d in data_bank]), [0.1, 0.9])
    data.set_clim(clim)
    if display:
        ax, fig = plot_probe(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax
    return data
def line_fr_plot(spike_depths, spike_times, chn_coords, d_bin=10, display=False, title=None, **kwargs):
    """
    Prepare data for 1D line plot of average firing rate across depth

    Spikes are binned in depth with a time bin as wide as the last spike time
    (iblutil.numerical.bincount2D). The counts of the first time bin are divided by that width.

    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates; the maximum of column 1 sets the upper depth limit
    :param d_bin: width of the depth bins
    :param display: generate figure
    :param title: plot title, defaults to 'Avg Firing Rate'
    :param kwargs: additional keyword arguments passed to plot_line
    :return: LinePlot object, if display=True returns (LinePlot.convert2dict(), fig, ax) instead
    """
    plot_title = title if title else 'Avg Firing Rate'
    duration = np.max(spike_times)
    depth_limits = [0, np.max(chn_coords[:, 1])]
    counts, _, depth_scale = bincount2D(spike_times, spike_depths, duration, d_bin, ylim=depth_limits)
    rate_per_depth = counts[:, 0] / duration

    line = LinePlot(x=rate_per_depth, y=depth_scale)
    line.set_xlim((0, np.max(rate_per_depth)))
    line.set_labels(title=plot_title, xlabel='Firing Rate (Hz)', ylabel='Distance from probe tip (um)')

    if not display:
        return line
    ax, fig = plot_line(line.convert2dict(), **kwargs)
    return line.convert2dict(), fig, ax
def line_amp_plot(spike_amps, spike_depths, spike_times, chn_coords, d_bin=10, display=False, title=None,
                  min_spikes=50, **kwargs):
    """
    Prepare data for 1D line plot of average spike amplitude across depth

    :param spike_amps: amplitude of each spike (scaled by 1e6 for the uV axis)
    :param spike_depths: depth of each spike
    :param spike_times: time of each spike; the maximum is used as the single time bin width
    :param chn_coords: channel coordinates; the maximum of column 1 sets the upper depth limit
    :param d_bin: depth bin to average across (see also iblutil.numerical.bincount2D)
    :param display: generate figure
    :param title: plot title, defaults to 'Avg Amplitude'
    :param min_spikes: depth bins with fewer spikes than this are set to 0 amplitude (default 50)
    :param kwargs: additional keyword arguments passed to plot_line
    :return: LinePlot object, if display=True returns (LinePlot.convert2dict(), fig, ax) instead
    """
    title = title or 'Avg Amplitude'
    t_bin = np.max(spike_times)
    ylim = [0, np.max(chn_coords[:, 1])]
    n, _, _ = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=ylim)
    amp, x, y = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=ylim, weights=spike_amps)

    counts = n[:, 0]
    # Empty depth bins give 0/0; silence the warning, they are zeroed just below
    with np.errstate(divide='ignore', invalid='ignore'):
        mean_amp = np.divide(amp[:, 0], counts) * 1e6
    mean_amp[np.isnan(mean_amp)] = 0
    # Too few spikes for a reliable average
    mean_amp[counts < min_spikes] = 0

    data = LinePlot(x=mean_amp, y=y)
    data.set_xlim((0, np.max(mean_amp)))
    data.set_labels(title=title, xlabel='Amplitude (uV)', ylabel='Distance from probe tip (um)')
    if display:
        ax, fig = plot_line(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax
    return data
def plot_brain_regions(channel_ids, channel_depths=None, brain_regions=None, display=True, ax=None, title=None,
                       label='left', **kwargs):
    """
    Plot brain regions along probe, if channel depths is provided will plot along depth otherwise along channel idx

    Consecutive channels with the same atlas id are merged into one region, drawn as a coloured bar with the
    region acronym as the y tick label.

    :param channel_ids: atlas ids for each channel
    :param channel_depths: depth along probe for each channel
    :param brain_regions: BrainRegions object, a new BrainRegions() is created when None
    :param display: whether to output plot
    :param ax: axis to plot on, a new figure is created when None
    :param title: title for plot
    :param label: side on which the region labels are drawn, 'left' or 'right'
    :param kwargs: additional keyword arguments for bar plot; they override the defaults edgecolor='w', width=1
    :return: (fig, ax) if display, otherwise (regions, region_labels, region_colours) where regions holds the
        [start, end] of each region, region_labels holds [centre, acronym] and region_colours holds rgb values
    """
    if channel_depths is not None:
        assert channel_ids.shape[0] == channel_depths.shape[0]
    else:
        channel_depths = np.arange(channel_ids.shape[0])

    br = BrainRegions() if brain_regions is None else brain_regions
    region_info = br.get(channel_ids)
    # Index of the last channel before each change of region id, framed by the first and last channel
    boundaries = np.where(np.diff(region_info.id) != 0)[0]
    boundaries = np.r_[0, boundaries, region_info.id.shape[0] - 1]
    regions = channel_depths[np.c_[boundaries[0:-1], boundaries[1:]]]
    region_centres = np.mean(regions, axis=1)
    region_labels = np.c_[region_centres, region_info.acronym[boundaries[1:]]]
    region_colours = region_info.rgb[boundaries[1:]]

    if not display:
        return regions, region_labels, region_colours

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # Defaults were previously built but never passed to ax.bar; user kwargs take precedence
    bar_kwargs = dict(edgecolor='w', width=1)
    bar_kwargs.update(**kwargs)
    for reg, col in zip(regions, region_colours):
        height = np.abs(reg[1] - reg[0])
        color = col / 255  # rgb 0-255 -> matplotlib 0-1
        ax.bar(x=0.5, height=height, color=color, bottom=reg[0], **bar_kwargs)

    # Tick positions from the numeric centres; region_labels mixes numbers and strings
    ax.set_yticks(region_centres.astype(int))
    ax.yaxis.set_tick_params(labelsize=8)
    ax.set_ylim(np.nanmin(channel_depths), np.nanmax(channel_depths))
    ax.get_xaxis().set_visible(False)
    ax.set_yticklabels(region_labels[:, 1])
    if label == 'right':
        ax.yaxis.tick_right()
        ax.spines['left'].set_visible(False)
    else:
        ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    if title:
        ax.set_title(title)

    return fig, ax
def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None,
             display=False, cmap='hot', ax=None):
    """
    Plot cumulative amplitude of spikes across depth

    For each depth bin, the rate of spikes with amplitude at or above each amplitude bin edge is computed.
    Rates are spike counts divided by the last spike time. Zero rates are shown as NaN.

    :param spike_amps: amplitude of each spike (scaled by 1e6 for the uV axis)
    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param n_amp_bins: number of amplitude bins to use
    :param d_bin: the value of the depth bins in um (default is 40 um)
    :param amp_range: amp range to use [amp_min, amp_max], if not given automatically computed from spike_amps.
        Amplitudes below amp_min are ignored; amplitudes at or above amp_max fall into the last bin.
    :param d_range: depth range to use, by default [0, 3840]
    :param display: whether or not to display plot
    :param cmap: colormap name
    :param ax: axis to plot on when display=True
    :return: ImagePlot object, if display=True returns (ImagePlot.convert2dict(), fig, ax) instead
    """
    # explicit None checks: `x or default` raises for numpy array ranges
    amp_range = np.quantile(spike_amps, (0, 0.9)) if amp_range is None else amp_range
    amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
    d_range = [0, 3840] if d_range is None else d_range
    depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
    t_bin = np.max(spike_times)

    def histc(x, bins):
        """Count x per bin: [bins[i], bins[i + 1]) -> i, x >= bins[-1] -> last bin, x < bins[0] dropped."""
        bin_idx = np.digitize(x, bins) - 1
        # digitize gives 0 (-> -1 here) below the first edge; previously these wrapped into the last bin
        bin_idx = bin_idx[bin_idx >= 0]
        return np.bincount(bin_idx, minlength=bins.size).astype(float)

    cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
    for d in range(len(depth_bins) - 1):
        spikes = (spike_depths > depth_bins[d]) & (spike_depths <= depth_bins[d + 1])
        h = histc(spike_amps[spikes], amp_bins) / t_bin
        # Reverse cumulative sum: rate of spikes with amplitude >= each bin edge
        cdfs[d, :] = np.cumsum(h[::-1])[::-1]
    cdfs[cdfs == 0] = np.nan

    data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
    data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)',
                    ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')
    if display:
        ax, fig = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]}, ax=ax)
        return data.convert2dict(), fig, ax
    return data
def image_raw_data(raw, fs, chn_coords=None, cmap='bone', title=None, display=False, gain=-90, **kwargs):
    """
    Prepare data for 2D image plot of a snippet of raw data

    :param raw: raw data indexed as raw[sample, channel]
    :param fs: sampling frequency in Hz, used to express the time axis in ms
    :param chn_coords: channel coordinates; column 1 is used as the y value. If None the channel index is used.
    :param cmap: colormap name
    :param title: plot title, defaults to 'Raw data'
    :param display: generate figure
    :param gain: display gain in dB; the colour limits are +/- 4 * 10 ** (gain / 20)
    :param kwargs: additional keyword arguments passed to plot_image
    :return: ImagePlot object, if display=True returns (ImagePlot.convert2dict(), fig, ax) instead
    """
    if chn_coords is None:
        ylabel = 'Channel index'
        y = np.arange(raw.shape[1])
    else:
        ylabel = 'Distance from probe tip (um)'
        y = chn_coords[:, 1]
    plot_title = title if title else 'Raw data'

    # Time extent of the snippet: first to last sample, in ms
    time_extent = np.array([0, raw.shape[0] - 1]) / fs * 1e3
    # Symmetric colour limits derived from the dB gain
    levels = 10 ** (gain / 20) * 4 * np.array([-1, 1])

    image = ImagePlot(raw, y=y, cmap=cmap)
    image.set_labels(title=plot_title, xlabel='Time (ms)', ylabel=ylabel, clabel='Power (uV)')
    image.set_clim(clim=levels)
    image.set_xlim(xlim=time_extent)
    image.set_ylim()

    if not display:
        return image
    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax