import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
from brainbox.plot_base import (ImagePlot, ScatterPlot, ProbePlot, LinePlot, plot_line,
plot_image, plot_probe, plot_scatter, arrange_channels2banks)
from brainbox.processing import compute_cluster_average
from iblutil.numerical import bincount2D
from iblatlas.regions import BrainRegions
[docs]
def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords=None, chn_inds=None, freq_range=(0, 300),
avg_across_depth=False, clim=None, cmap='viridis', display=False, title=None, **kwargs):
"""
Prepare data for 2D image plot of LFP power spectrum along depth of probe
:param lfp_power:
:param lfp_freq:
:param chn_depths:
:param chn_inds:
:param freq_range:
:param avg_across_depth: Whether to average across channels at same depth
:param cmap:
:param display: generate figure
:return: ImagePlot object, if display=True also returns matplotlib fig and ax objects
"""
ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
title = title or 'LFP Power Spectrum'
y = np.arange(lfp_power.shape[1]) if chn_coords is None else chn_coords[:, 1]
chn_inds = np.arange(lfp_power.shape[1]) if chn_inds is None else chn_inds
freq_idx = np.where((lfp_freq >= freq_range[0]) & (lfp_freq < freq_range[1]))[0]
freqs = lfp_freq[freq_idx]
lfp = np.take(lfp_power[freq_idx], chn_inds, axis=1)
lfp_db = 10 * np.log10(lfp)
lfp_db[np.isinf(lfp_db)] = np.nan
x = freqs
# Average across channels that are at the same depth
if avg_across_depth:
chn_depth, chn_idx, chn_count = np.unique(y, return_index=True,
return_counts=True)
chn_idx_eq = np.copy(chn_idx)
chn_idx_eq[np.where(chn_count == 2)] += 1
lfp_db = np.apply_along_axis(lambda a: np.mean([a[chn_idx], a[chn_idx_eq]], axis=0), 1,
lfp_db)
x = freqs
y = chn_depth
data = ImagePlot(lfp_db, x=x, y=y, cmap=cmap)
data.set_labels(title=title, xlabel='Frequency (Hz)',
ylabel=ylabel, clabel='LFP Power (dB)')
clim = clim or np.quantile(lfp_db, [0.1, 0.9])
data.set_clim(clim=clim)
if display:
ax, fig = plot_image(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data
[docs]
def image_rms_plot(rms_amps, rms_times, chn_coords=None, chn_inds=None, avg_across_depth=False,
median_subtract=True, clim=None, cmap='plasma', band='AP', display=False, title=None, **kwargs):
"""
Prepare data for 2D image plot of RMS data along depth of probe
:param rms_amps:
:param rms_times:
:param chn_coords:
:param chn_inds:
:param avg_across_depth: Whether to average across channels at same depth
:param median_subtract: Whether to apply median subtraction correction
:param cmap:
:param band: Frequency band of rms data, can be either 'LF' or 'AP'
:param display: generate figure
:return: ImagePlot object, if display=True also returns matplotlib fig and ax objects
"""
ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
title = title or f'{band} RMS'
chn_inds = np.arange(rms_amps.shape[1]) if chn_inds is None else chn_inds
y = np.arange(rms_amps.shape[1]) if chn_coords is None else chn_coords[:, 1]
rms = rms_amps[:, chn_inds]
rms = 10 * np.log10(rms)
x = rms_times
if avg_across_depth:
chn_depth, chn_idx, chn_count = np.unique(y, return_index=True, return_counts=True)
chn_idx_eq = np.copy(chn_idx)
chn_idx_eq[np.where(chn_count == 2)] += 1
rms = np.apply_along_axis(lambda a: np.mean([a[chn_idx], a[chn_idx_eq]], axis=0), 1, rms)
y = chn_depth
if median_subtract:
median = np.mean(np.apply_along_axis(lambda a: np.median(a), 1, rms))
rms = np.apply_along_axis(lambda a: a - np.median(a), 1, rms) + median
data = ImagePlot(rms, x=x, y=y, cmap=cmap)
data.set_labels(title=title, xlabel='Time (s)', ylabel=ylabel, clabel=f'{band} RMS (dB)')
clim = clim or np.quantile(rms, [0.1, 0.9])
data.set_clim(clim=clim)
if display:
ax, fig = plot_image(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data
[docs]
def scatter_raster_plot(spike_amps, spike_depths, spike_times, n_amp_bins=10, cmap='BuPu',
subsample_factor=100, display=False, title=None, **kwargs):
"""
Prepare data for 2D raster plot of spikes with colour and size indicative of spike amplitude
:param spike_amps:
:param spike_depths:
:param spike_times:
:param n_amp_bins: no. of colour and size bins into which to split amplitude data
:param cmap:
:param subsample_factor: factor by which to subsample data when too many points for efficient
display
:param display: generate figure
:return: ScatterPlot object, if display=True also returns matplotlib fig and ax objects
"""
title = title or 'Spike times vs Spike depths'
amp_range = np.quantile(spike_amps, [0, 0.9])
amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
color_bin = np.linspace(0.0, 1.0, n_amp_bins + 1)
colors = (cm.get_cmap(cmap)(color_bin)[np.newaxis, :, :3][0])
spike_amps = spike_amps[0:-1:subsample_factor]
spike_colors = np.zeros((spike_amps.size, 3))
spike_size = np.zeros(spike_amps.size)
for iA in range(amp_bins.size):
if iA == (amp_bins.size - 1):
idx = np.where(spike_amps > amp_bins[iA])[0]
# Make saturated spikes the darkest colour
spike_colors[idx] = colors[-1]
else:
idx = np.where((spike_amps > amp_bins[iA]) & (spike_amps <= amp_bins[iA + 1]))[0]
spike_colors[idx] = [*colors[iA]]
spike_size[idx] = iA / (n_amp_bins / 8)
data = ScatterPlot(x=spike_times[0:-1:subsample_factor], y=spike_depths[0:-1:subsample_factor],
c=spike_amps * 1e6, cmap='BuPu')
data.set_ylim((0, 3840))
data.set_color(color=spike_colors)
data.set_clim(clim=amp_range * 1e6)
data.set_marker_size(marker_size=spike_size)
data.set_labels(title=title, xlabel='Time (s)',
ylabel='Distance from probe tip (um)', clabel='Spike amplitude (uV)')
if display:
ax, fig = plot_scatter(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data
[docs]
def image_fr_plot(spike_depths, spike_times, chn_coords, t_bin=0.05, d_bin=5, cmap='binary',
display=False, title=None, **kwargs):
"""
Prepare data 2D raster plot of firing rate across recording
:param spike_depths:
:param spike_times:
:param chn_coords:
:param t_bin: time bin to average across (see also brainbox.processing.bincount2D)
:param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
:param cmap:
:param display: generate figure
:return: ImagePlot object, if display=True also returns matplotlib fig and ax objects
"""
title = title or 'Firing Rate'
n, x, y = bincount2D(spike_times, spike_depths, t_bin, d_bin,
ylim=[0, np.max(chn_coords[:, 1])])
fr = n.T / t_bin
data = ImagePlot(fr, x=x, y=y, cmap=cmap)
data.set_labels(title=title, xlabel='Time (s)',
ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')
data.set_clim(clim=(np.min(np.mean(fr, axis=0)), np.max(np.mean(fr, axis=0))))
if display:
ax, fig = plot_image(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data
[docs]
def image_crosscorr_plot(spike_depths, spike_times, chn_coords, t_bin=0.05, d_bin=40,
cmap='viridis', display=False, title=None, **kwargs):
"""
Prepare data for 2D cross correlation plot of data across depth
:param spike_depths:
:param spike_times:
:param chn_coords:
:param t_bin: t_bin: time bin to average across (see also brainbox.processing.bincount2D)
:param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
:param cmap:
:param display: generate figure
:return: ImagePlot object, if display=True also returns matploltlib fig and ax objects
"""
title = title or 'Correlation'
n, x, y = bincount2D(spike_times, spike_depths, t_bin, d_bin,
ylim=[0, np.max(chn_coords[:, 1])])
corr = np.corrcoef(n)
corr[np.isnan(corr)] = 0
data = ImagePlot(corr, x=y, y=y, cmap=cmap)
data.set_labels(title=title, xlabel='Distance from probe tip (um)',
ylabel='Distance from probe tip (um)', clabel='Correlation')
if display:
ax, fig = plot_image(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data
[docs]
def scatter_amp_depth_fr_plot(spike_amps, spike_clusters, spike_depths, spike_times, cmap='hot',
display=False, title=None, **kwargs):
"""
Prepare data for 2D scatter plot of cluster depth vs cluster amp with colour indicating cluster
firing rate
:param spike_amps:
:param spike_clusters:
:param spike_depths:
:param spike_times:
:param cmap:
:param display: generate figure
:return: ScatterPlot object, if display=True also returns matplotlib fig and ax objects
"""
title = title or 'Cluster depth vs amp vs firing rate'
# TODO use pandas here instead, much quicker
cluster, cluster_depth, n_cluster = compute_cluster_average(spike_clusters, spike_depths)
_, cluster_amp, _ = compute_cluster_average(spike_clusters, spike_amps)
cluster_amp = cluster_amp * 1e6
cluster_fr = n_cluster / np.max(spike_times)
data = ScatterPlot(x=cluster_amp, y=cluster_depth, c=cluster_fr, cmap=cmap)
data.set_xlim((0.9 * np.min(cluster_amp), 1.1 * np.max(cluster_amp)))
data.set_labels(title=title, xlabel='Cluster Amplitude (uV)', ylabel='Distance from probe tip (um)',
clabel='Firing rate (Hz)')
if display:
ax, fig = plot_scatter(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data
[docs]
def probe_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 4),
display=False, pad=True, x_offset=1, **kwargs):
"""
Prepare data for 2D probe plot of LFP power spectrum along depth of probe
:param lfp_power:
:param lfp_freq:
:param chn_coords:
:param chn_inds:
:param freq_range:
:param display:
:param pad: whether to add nans around the individual image plots. For matplotlib use pad=True,
for pyqtgraph use pad=False
:param x_offset: Distance between the channel banks in x direction
:return: ProbePlot object, if display=True also returns matplotlib fig and ax objects
"""
freq_idx = np.where((lfp_freq >= freq_range[0]) & (lfp_freq < freq_range[1]))[0]
lfp = np.take(lfp_power[freq_idx], chn_inds, axis=1)
lfp_db = 10 * np.log10(lfp)
lfp_db[np.isinf(lfp_db)] = np.nan
lfp_db = np.mean(lfp_db, axis=0)
data_bank, x_bank, y_bank = arrange_channels2banks(lfp_db, chn_coords, depth=None,
pad=pad, x_offset=x_offset)
data = ProbePlot(data_bank, x=x_bank, y=y_bank)
data.set_labels(ylabel='Distance from probe tip (um)', clabel='PSD 0-4 Hz (dB)')
clim = np.nanquantile(np.concatenate([np.squeeze(np.ravel(d)) for d in data_bank]).ravel(),
[0.1, 0.9])
data.set_clim(clim)
if display:
ax, fig = plot_probe(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data
[docs]
def probe_rms_plot(rms_amps, chn_coords, chn_inds, cmap='plasma', band='AP',
display=False, pad=True, x_offset=1, **kwargs):
"""
Prepare data for 2D probe plot of RMS along depth of probe
:param rms_amps:
:param chn_coords:
:param chn_inds:
:param cmap:
:param band:
:param display:
:param pad: whether to add nans around the individual image plots. For matplotlib use pad=True,
for pyqtgraph use pad=False
:param x_offset: Distance between the channel banks in x direction
:return: ProbePlot object, if display=True also returns matplotlib fig and ax objects
"""
rms = (np.mean(rms_amps, axis=0)[chn_inds]) * 1e6
data_bank, x_bank, y_bank = arrange_channels2banks(rms, chn_coords, depth=None,
pad=pad, x_offset=x_offset)
data = ProbePlot(data_bank, x=x_bank, y=y_bank, cmap=cmap)
data.set_labels(ylabel='Distance from probe tip (um)', clabel=f'{band} RMS (uV)')
clim = np.nanquantile(np.concatenate([np.squeeze(np.ravel(d)) for d in data_bank]).ravel(),
[0.1, 0.9])
data.set_clim(clim)
if display:
ax, fig = plot_probe(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data
[docs]
def line_fr_plot(spike_depths, spike_times, chn_coords, d_bin=10, display=False, title=None, **kwargs):
"""
Prepare data for 1D line plot of average firing rate across depth
:param spike_depths:
:param spike_times:
:param chn_coords:
:param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
:param display:
:return:
"""
title = title or 'Avg Firing Rate'
t_bin = np.max(spike_times)
n, x, y = bincount2D(spike_times, spike_depths, t_bin, d_bin,
ylim=[0, np.max(chn_coords[:, 1])])
mean_fr = n[:, 0] / t_bin
data = LinePlot(x=mean_fr, y=y)
data.set_xlim((0, np.max(mean_fr)))
data.set_labels(title=title, xlabel='Firing Rate (Hz)',
ylabel='Distance from probe tip (um)')
if display:
ax, fig = plot_line(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data
[docs]
def line_amp_plot(spike_amps, spike_depths, spike_times, chn_coords, d_bin=10, display=False, title=None, **kwargs):
"""
Prepare data for 1D line plot of average firing rate across depth
:param spike_amps:
:param spike_depths:
:param spike_times:
:param chn_coords:
:param d_bin: depth bin to average across (see also brainbox.processing.bincount2D)
:param display:
:return:
"""
title = title or 'Avg Amplitude'
t_bin = np.max(spike_times)
n, _, _ = bincount2D(spike_times, spike_depths, t_bin, d_bin,
ylim=[0, np.max(chn_coords[:, 1])])
amp, x, y = bincount2D(spike_times, spike_depths, t_bin, d_bin,
ylim=[0, np.max(chn_coords[:, 1])], weights=spike_amps)
mean_amp = np.divide(amp[:, 0], n[:, 0]) * 1e6
mean_amp[np.isnan(mean_amp)] = 0
remove_bins = np.where(n[:, 0] < 50)[0]
mean_amp[remove_bins] = 0
data = LinePlot(x=mean_amp, y=y)
data.set_xlim((0, np.max(mean_amp)))
data.set_labels(title=title, xlabel='Amplitude (uV)',
ylabel='Distance from probe tip (um)')
if display:
ax, fig = plot_line(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data
[docs]
def plot_brain_regions(channel_ids, channel_depths=None, brain_regions=None, display=True, ax=None,
title=None, label='left', **kwargs):
"""
Plot brain regions along probe, if channel depths is provided will plot along depth otherwise along channel idx
:param channel_ids: atlas ids for each channel
:param channel_depths: depth along probe for each channel
:param brain_regions: BrainRegions object
:param display: whether to output plot
:param ax: axis to plot on
:param title: title for plot
:param kwargs: additional keyword arguments for bar plot
:return:
"""
if channel_depths is not None:
assert channel_ids.shape[0] == channel_depths.shape[0]
else:
channel_depths = np.arange(channel_ids.shape[0])
br = brain_regions or BrainRegions()
region_info = br.get(channel_ids)
boundaries = np.where(np.diff(region_info.id) != 0)[0]
boundaries = np.r_[0, boundaries, region_info.id.shape[0] - 1]
regions = np.c_[boundaries[0:-1], boundaries[1:]]
if channel_depths is not None:
regions = channel_depths[regions]
region_labels = np.c_[np.mean(regions, axis=1), region_info.acronym[boundaries[1:]]]
region_colours = region_info.rgb[boundaries[1:]]
if display:
if ax is None:
fig, ax = plt.subplots()
else:
fig = ax.get_figure()
for reg, col in zip(regions, region_colours):
height = np.abs(reg[1] - reg[0])
bar_kwargs = dict(edgecolor='w', width=1)
bar_kwargs.update(**kwargs)
color = col / 255
ax.bar(x=0.5, height=height, color=color, bottom=reg[0], **kwargs)
if label == 'right':
ax.yaxis.tick_right()
ax.set_yticks(region_labels[:, 0].astype(int))
ax.yaxis.set_tick_params(labelsize=8)
ax.set_ylim(np.nanmin(channel_depths), np.nanmax(channel_depths))
ax.get_xaxis().set_visible(False)
ax.set_yticklabels(region_labels[:, 1])
if label == 'right':
ax.yaxis.tick_right()
ax.spines['left'].set_visible(False)
else:
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
if title:
ax.set_title(title)
return fig, ax
else:
return regions, region_labels, region_colours
[docs]
def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None,
display=False, cmap='hot', ax=None):
"""
Plot cumulative amplitude of spikes across depth
:param spike_amps:
:param spike_depths:
:param spike_times:
:param n_amp_bins: number of amplitude bins to use
:param d_bin: the value of the depth bins in um (default is 40 um)
:param amp_range: amp range to use [amp_min, amp_max], if not given automatically computed from spike_amps
:param d_range: depth range to use, by default [0, 3840]
:param display: whether or not to display plot
:param cmap:
:return:
"""
amp_range = amp_range or np.quantile(spike_amps, (0, 0.9))
amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
d_range = d_range or [0, 3840]
depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
t_bin = np.max(spike_times)
def histc(x, bins):
map_to_bins = np.digitize(x, bins) # Get indices of the bins to which each value in input array belongs.
res = np.zeros(bins.shape)
for el in map_to_bins:
res[el - 1] += 1 # Increment appropriate bin.
return res
cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
for d in range(len(depth_bins) - 1):
spikes = np.bitwise_and(spike_depths > depth_bins[d], spike_depths <= depth_bins[d + 1])
h = histc(spike_amps[spikes], amp_bins) / t_bin
hcsum = np.cumsum(h[::-1])
cdfs[d, :] = hcsum[::-1]
cdfs[cdfs == 0] = np.nan
data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)',
ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')
if display:
ax, fig = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]}, ax=ax)
return data.convert2dict(), fig, ax
return data
[docs]
def image_raw_data(raw, fs, chn_coords=None, cmap='bone', title=None, display=False, gain=-90, **kwargs):
def gain2level(gain):
return 10 ** (gain / 20) * 4 * np.array([-1, 1])
ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
title = title or 'Raw data'
y = np.arange(raw.shape[1]) if chn_coords is None else chn_coords[:, 1]
x = np.array([0, raw.shape[0] - 1]) / fs * 1e3
data = ImagePlot(raw, y=y, cmap=cmap)
data.set_labels(title=title, xlabel='Time (ms)',
ylabel=ylabel, clabel='Power (uV)')
clim = gain2level(gain)
data.set_clim(clim=clim)
data.set_xlim(xlim=x)
data.set_ylim()
if display:
ax, fig = plot_image(data.convert2dict(), **kwargs)
return data.convert2dict(), fig, ax
return data