"""
Plots metrics that assess quality of single units. Some functions here generate plots for the
output of functions in the brainbox `single_units.py` module.
Run the following to set up the workspace for the docstring examples:
>>> from brainbox import processing
>>> import one.alf.io as alfio
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> import ibllib.ephys.spikes as e_spks
# (Note: if there is no 'alf' directory, create one from the 'ks2' output directory):
>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
# Load the alf spikes bunch and clusters bunch, and get a units bunch.
>>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
>>> clstrs_b = alfio.load_object(path_to_alf_out, 'clusters')
>>> units_b = processing.get_units_bunch(spks_b) # may take a few mins to compute
"""
import time
from warnings import warn
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# from matplotlib.ticker import StrMethodFormatter
from brainbox import singlecell
from brainbox.metrics import single_units
from brainbox.io.spikeglx import extract_waveforms
from iblutil.numerical import bincount2D
import spikeglx
[docs]
def feat_vars(units_b, units=None, feat_name='amps', dist='norm', test='ks', cmap_name='coolwarm',
              ax=None):
    '''
    Plots the coefficients of variation of a particular spike feature for all units as a bar plot,
    where each bar is color-coded by the mean spike depth of the respective unit and the bars are
    stacked along the y-axis in order of increasing depth.

    Parameters
    ----------
    units_b : bunch
        A units bunch containing fields with spike information (e.g. cluster IDs, times, features,
        etc.) for all units. Each field is keyed by unit ID as a string (e.g. '1').
    units : array-like (optional)
        A subset of all units for which to create the bar plot. (If `None`, all units are used)
        This function selects the subset without removing entries from `units_b`.
    feat_name : string (optional)
        The spike feature to plot.
    dist : string (optional)
        The type of hypothetical null distribution from which the empirical spike feature
        distributions are presumed to belong to.
    test : string (optional)
        The statistical test used to calculate the probability that the empirical spike feature
        distributions come from `dist`.
    cmap_name : string (optional)
        The name of the colormap associated with the plot.
    ax : axessubplot (optional)
        The axis handle to plot the bar plot on. (if `None`, a new figure and axis is created)

    Returns
    -------
    cv_vals : ndarray
        The coefficients of variation of `feat_name` for each unit (scaled by 1e6, i.e. to uV,
        when `feat_name` is 'amps').
    p_vals : ndarray
        The probabilites that the distribution for `feat_name` for each unit comes from a
        `dist` distribution based on the `test` statistical test.

    See Also
    --------
    metrics.unit_stability

    Examples
    --------
    1) Create a bar plot of the coefficients of variation of the spike amplitudes for all units.
    >>> cv_vals, p_vals = bb.plot.feat_vars(units_b)
    '''
    # Get the IDs of the units to plot (the keys of `units_b['depths']`). The subset is built as
    # a new list instead of popping keys from the caller's `units_b`.
    if units is None:
        unit_list = list(units_b['depths'].keys())
    else:
        unit_list = [unit for unit in units_b['depths'].keys() if int(unit) in units]
    # Calculate coefficients of variation for the units.
    p_vals_b, cv_b = single_units.unit_stability(
        units_b, units=units, feat_names=[feat_name], dist=dist, test=test)
    cv_vals = np.array(tuple(cv_b[feat_name].values()))
    cv_vals = cv_vals * 1e6 if feat_name == 'amps' else cv_vals  # convert to uV if amps
    p_vals = np.array(tuple(p_vals_b[feat_name].values()))
    # Leave units without spikes out of the plot. `cv_vals` was computed for ALL units, so a
    # boolean mask keeps it aligned with `good_units`.
    # NOTE(review): assumes `cv_b[feat_name]` is ordered like `unit_list` -- confirm.
    has_spikes = np.array([len(units_b['times'][unit]) > 0 for unit in unit_list], dtype=bool)
    good_units = [unit for unit, keep in zip(unit_list, has_spikes) if keep]
    good_cvs = cv_vals[has_spikes] if cv_vals.size == has_spikes.size else cv_vals
    # Mean spike depth of each good unit, and the bar order (smallest depth first).
    depths = np.asarray([np.mean(units_b['depths'][str(unit)]) for unit in good_units])
    order = np.argsort(depths)
    max_d = np.max(depths)
    # Colors from the depths normalized to [0, 1], in the same sorted order as the bars.
    # (`plt.get_cmap` is used because `plt.cm.get_cmap` was removed in matplotlib 3.9.)
    cmap = plt.get_cmap(cmap_name)
    rgba = np.asarray([cmap(depth) for depth in np.sort(depths / max_d)])
    # Depth-color-coded horizontal bar plot: one bar per unit, stacked along y by increasing
    # depth, with each bar labelled by its unit number.
    if ax is None:
        fig, ax = plt.subplots()
    y_pos = np.arange(len(good_units))
    ax.barh(y=y_pos, width=good_cvs[order], color=rgba)
    ax.set_yticks(y_pos)
    ax.set_yticklabels([str(int(good_units[i])) for i in order])
    fig = ax.figure
    cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax)
    # Place the colorbar ticks explicitly so they always match the 6 depth labels
    # (the colors were computed from depths normalized to [0, 1]).
    tick_pos = (0, 0.2, 0.4, 0.6, 0.8, 1.0)
    cbar.set_ticks(tick_pos)
    cbar.set_ticklabels([int(max_d * tick) for tick in tick_pos])
    ax.set_title('CV of {feat}'.format(feat=feat_name))
    ax.set_ylabel('Unit Number (sorted by depth)')
    ax.set_xlabel('CV')
    cbar.set_label('Depth', rotation=-90)
    return cv_vals, p_vals
[docs]
def _plot_pdf_cutoff(ax, pdf, cutoff_idx, fraction_missing):
    '''Plot `pdf` on `ax` with a red vertical line at bin `cutoff_idx` and a titled estimate.'''
    ax.plot(pdf)
    ax.vlines(cutoff_idx, 0, np.max(pdf), colors='r')
    ax.set_xlabel('Bin Number')
    ax.set_ylabel('Density')
    ax.set_title('PDF Symmetry Cutoff\n'
                 '(estimated {:.2f}% missing spikes)'.format(fraction_missing * 100))


def missed_spikes_est(feat, feat_name, spks_per_bin=20, sigma=5, min_num_bins=50, ax=None):
    '''
    Plots the pdf of an estimated symmetric spike feature distribution, with a vertical cutoff line
    that indicates the approximate fraction of spikes missing from the distribution, assuming the
    true distribution is symmetric.

    Parameters
    ----------
    feat : ndarray
        The spikes' feature values.
    feat_name : string
        The spike feature to plot.
    spks_per_bin : int (optional)
        The number of spikes per bin from which to compute the spike feature histogram.
    sigma : int (optional)
        The standard deviation for the gaussian kernel used to compute the pdf from the spike
        feature histogram.
    min_num_bins : int (optional)
        The minimum number of bins used to compute the spike feature histogram.
    ax : axessubplot or sequence of axessubplot (optional)
        The axis handle(s) to plot on. If `None`, a new figure with two axes is created. If
        exactly two axes are given, the histogram is plotted on the first and the pdf on the
        second; otherwise only the pdf is plotted, on the first (or only) axis.

    Returns
    -------
    fraction_missing : float
        The fraction of missing spikes (0-0.5). *Note: If more than 50% of spikes are missing, an
        accurate estimate isn't possible.

    See Also
    --------
    single_units.feature_cutoff

    Examples
    --------
    1) Plot cutoff line indicating the fraction of spikes missing from a unit based on the recorded
    unit's spike amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric.
    >>> feat = units_b['amps']['1']
    >>> fraction_missing = bb.plot.missed_spikes_est(feat, feat_name='amps')
    '''
    # Calculate the feature distribution histogram and fraction of spikes missing.
    fraction_missing, pdf, cutoff_idx = \
        single_units.missed_spikes_est(feat, spks_per_bin, sigma, min_num_bins)
    if ax is None:  # create two axes
        fig, ax = plt.subplots(nrows=1, ncols=2)
    # Accept a single Axes as well as a list/tuple/array of Axes.
    axes = list(np.ravel(ax)) if isinstance(ax, (list, tuple, np.ndarray)) else [ax]
    if len(axes) == 2:  # plot histogram and pdf on two separate axes
        # Honor `min_num_bins`; this also prevents `bins=0` when there are fewer spikes than
        # `spks_per_bin`.
        num_bins = max(int(np.size(feat) / spks_per_bin), min_num_bins, 1)
        axes[0].hist(feat, bins=num_bins)
        axes[0].set_xlabel('{0}'.format(feat_name))
        axes[0].set_ylabel('Count')
        axes[0].set_title('Histogram of {0}'.format(feat_name))
        _plot_pdf_cutoff(axes[1], pdf, cutoff_idx, fraction_missing)
    else:  # just plot pdf
        _plot_pdf_cutoff(axes[0], pdf, cutoff_idx, fraction_missing)
    return fraction_missing
[docs]
def wf_comp(ephys_file, ts1, ts2, ch, sr=30000, n_ch_probe=385, dtype='int16', car=True,
            col=('b', 'r'), ax=None):
    '''
    Plots two different sets of waveforms across specified channels after (optionally)
    common-average-referencing. In this way, waveforms can be compared to see if there is,
    e.g. drift during the recording, or if two units should be merged, or one unit should be split.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts1 : array_like
        A set of timestamps for which to compare waveforms with `ts2`.
    ts2: array_like
        A set of timestamps for which to compare waveforms with `ts1`.
    ch : array-like
        The channels to use for extracting and plotting the waveforms.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    dtype: str (optional)
        The datatype represented by the bytes in `ephys_file`.
    car: bool (optional)
        A flag for whether or not to perform common-average-referencing before extracting waveforms
    col: sequence of strings or float arrays (optional)
        Two elements, where each specifies the color the `ts1` and `ts2` waveforms
        will be plotted in, respectively.
    ax : 2-D array-like of axessubplot (optional)
        The axes to plot on, one row per channel and two columns (all waveforms, mean waveform).
        (if `None`, a new figure and axes are created)

    Returns
    -------
    wf1 : ndarray
        The waveforms for the spikes in `ts1`: an array of shape (#spikes, #samples, #channels).
    wf2 : ndarray
        The waveforms for the spikes in `ts2`: an array of shape (#spikes, #samples, #channels).
    s : float
        The similarity score between the two sets of waveforms, calculated by
        `single_units.wf_similarity`

    See Also
    --------
    io.extract_waveforms
    single_units.wf_similarity

    Examples
    --------
    1) Compare first and last 100 spike waveforms for unit1, across 20 channels around the channel
    of max amplitude, and compare the waveforms in the first minute to the waveforms in the fourth
    minute for unit2, across 10 channels around the mean.
    # Get first and last 100 spikes, and 20 channels around channel of max amp for unit 1:
    >>> n_ch_probe = 385
    >>> n_c_ch = 20
    >>> ts1 = units_b['times']['1'][:100]
    >>> ts2 = units_b['times']['1'][-100:]
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> wf1, wf2, s = bb.plot.wf_comp(path_to_ephys_file, ts1, ts2, ch)
    # Plot waveforms for unit2 from the first and fourth minutes across 10 channels.
    >>> n_c_ch = 10
    >>> ts = units_b['times']['2']
    >>> ts1_2 = ts[np.where(ts<60)[0]]
    >>> ts2_2 = ts[np.where(ts>180)[0][:len(ts1_2)]]
    >>> max_ch = clstrs_b['channels'][2]
    >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 10)
    >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 5, max_ch + 5)
    >>> wf1_2, wf2_2, s_2 = bb.plot.wf_comp(path_to_ephys_file, ts1_2, ts2_2, ch)
    '''
    # Ensure `ch` is ndarray; a single channel is passed on with shape (1, 1).
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch
    # Extract the waveforms for these timestamps and compute similarity score.
    wf1 = extract_waveforms(ephys_file, ts1, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype,
                            car=car)
    wf2 = extract_waveforms(ephys_file, ts2, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype,
                            car=car)
    s = single_units.wf_similarity(wf1, wf2)
    # Plot these waveforms against each other: left col is all waveforms, right col is the mean.
    n_ch = ch.size
    if ax is None:
        # `squeeze=False` keeps `ax` 2-D even when only one channel is plotted.
        fig, ax = plt.subplots(nrows=n_ch, ncols=2, squeeze=False)
    else:
        ax = np.atleast_2d(np.asarray(ax, dtype=object))
    for i, cur_ch in enumerate(ch.ravel()):
        ax[i][0].plot(wf1[:, :, i].T, c=col[0])
        ax[i][0].plot(wf2[:, :, i].T, c=col[1])
        ax[i][1].plot(np.mean(wf1[:, :, i], axis=0), c=col[0])
        ax[i][1].plot(np.mean(wf2[:, :, i], axis=0), c=col[1])
        ax[i][0].set_ylabel('Ch {0}'.format(cur_ch))
    ax[0][0].set_title('All Waveforms. S = {:.2f}'.format(s))
    ax[0][1].set_title('Mean Waveforms')
    # Put the legend on the last (mean-waveform) axis rather than whatever axis is current.
    ax[-1][-1].legend(['1st spike set', '2nd spike set'])
    return wf1, wf2, s
[docs]
def _spatial_median_noise(file_m, samples, ch, n_chunk_samples=5e6):
    '''
    Estimate per-channel spatial noise for common-average-referencing.

    The span between the smallest and largest sample index in `samples` is split into chunks of
    `n_chunk_samples` samples; the median of each channel in `ch` is taken within every chunk,
    and the median across chunks is returned.
    '''
    first, last = np.min(samples), np.max(samples)
    n_chunks = int(np.ceil((last - first) / n_chunk_samples))
    # Chunk boundaries: chunk k spans `chunk_sample[k]:chunk_sample[k + 1]`.
    chunk_sample = np.append(np.arange(first, last, n_chunk_samples, dtype=int), last)
    if n_chunks == 0:  # all samples coincide: use that single sample as the only chunk
        n_chunks = 1
        chunk_sample = np.array([first, first + 1])
    # Stored as int16, so fractional medians are truncated.
    noise_s_chunks = np.zeros((n_chunks, ch.size), dtype=np.int16)
    # Time the first chunk to give an estimate of the total time, and keep its result.
    t0 = time.perf_counter()
    noise_s_chunks[0, :] = np.median(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0)
    dt = time.perf_counter() - t0
    print('Performing spatial CAR before waveform extraction. Estimated time is {:.2f} mins.'
          ' ({})'.format(dt * n_chunks / 60, time.ctime()))
    for chunk in range(1, n_chunks):
        noise_s_chunks[chunk, :] = np.median(
            file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0)
    return np.median(noise_s_chunks, axis=0)


def amp_heatmap(ephys_file, ts, ch, sr=30000, n_ch_probe=385, dtype='int16', cmap_name='RdBu',
                car=True, ax=None):
    '''
    Plots a heatmap of the normalized voltage values over time and space for given timestamps and
    channels, after (optionally) common-average-referencing.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts: array_like
        A set of timestamps (in s) for which to get the voltage values. The first and last
        timestamps define the x-extent of the heatmap.
    ch : array-like
        The channels to use for extracting the voltage values.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at; used to convert `ts` to
        sample indices.
    n_ch_probe : int (optional)
        The number of channels of the recording. (Currently unused by this function.)
    dtype: str (optional)
        The datatype represented by the bytes in `ephys_file`. (Currently unused by this
        function.)
    cmap_name : string (optional)
        The name of the colormap associated with the plot.
    car: bool (optional)
        A flag for whether or not to subtract the per-channel median (computed in chunks over the
        span of `ts`) from the voltage values.
    ax : axessubplot (optional)
        The axis handle to plot the heatmap on. (if `None`, a new figure and axis is created)

    Returns
    -------
    v_vals : ndarray
        The voltage values, of shape (#timestamps, #channels).

    Examples
    --------
    1) Plot a heatmap of the spike amplitudes across 20 channels around the channel of max
    amplitude for all spikes in unit 1.
    >>> n_c_ch = 20
    >>> ts = units_b['times']['1']
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < n_c_ch:  # take only channels greater than `max_ch`.
    >>>     ch = np.arange(max_ch, max_ch + 20)
    >>> elif (max_ch + n_c_ch) > n_ch_probe:  # take only channels less than `max_ch`.
    >>>     ch = np.arange(max_ch - 20, max_ch)
    >>> else:  # take `n_c_ch` around `max_ch`.
    >>>     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> bb.plot.amp_heatmap(path_to_ephys_file, ts, ch)
    '''
    # Ensure `ch` is ndarray; a single channel is reshaped to (1, 1).
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch
    ch_flat = ch.ravel()  # 1-D channels for the axis extent and ticks
    # Sample index of each timestamp.
    max_amp_samples = (np.asarray(ts) * sr).astype(int)
    s_reader = spikeglx.Reader(ephys_file, open=True)
    try:
        file_m = s_reader.data
        # Read one sample at a time: indexing with multiple values is currently unsupported.
        v_vals = np.zeros((max_amp_samples.size, ch.size))
        for i, sample in enumerate(max_amp_samples):
            v_vals[i] = file_m[sample:sample + 1, ch]
        if car:  # compute spatial noise in chunks, and subtract from `v_vals`.
            v_vals -= _spatial_median_noise(file_m, max_amp_samples, ch)[None, :]
            print('Done. ({})'.format(time.ctime()))
    finally:
        # Close the reader even if reading or CAR fails.
        s_reader.close()
    # Plot heatmap.
    if ax is None:
        fig, ax = plt.subplots()
    v_vals_norm = (v_vals / np.max(np.abs(v_vals))).T
    cbar_map = ax.imshow(v_vals_norm, cmap=cmap_name, aspect='auto',
                         extent=[ts[0], ts[-1], ch_flat[0], ch_flat[-1]], origin='lower')
    ax.set_yticks(np.arange(ch_flat[0], ch_flat[-1], 5))
    ax.set_ylabel('Channel Numbers')
    ax.set_xlabel('Time (s)')
    ax.set_title('Voltage Heatmap')
    fig = ax.figure
    cbar = fig.colorbar(cbar_map, ax=ax)
    cbar.set_label('V', rotation=-90)
    return v_vals
[docs]
def firing_rate(ts, hist_win=0.01, fr_win=0.5, n_bins=10, show_fr_cv=True, ax=None):
    '''
    Plots the instantaneous firing rate for given spike timestamps over time, and optionally
    overlays the value of the coefficient of variation of the firing rate for a specified number
    of bins.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float (optional)
        The time window (in s) to use for computing spike counts.
    fr_win : float (optional)
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.
    n_bins : int (optional)
        The number of bins in which to compute coefficients of variation of the firing rate.
    show_fr_cv : bool (optional)
        A flag for whether or not to compute and show the coefficients of variation of the firing
        rate for `n_bins`.
    ax : axessubplot (optional)
        The axis handle to plot the firing rate on. (if `None`, a new figure and axis is created)

    Returns
    -------
    fr: ndarray
        The instantaneous firing rate over time (in hz).
    cv: float
        The mean coefficient of variation of the firing rate of the `n_bins` number of coefficients
        computed. Only returned if `show_fr_cv` is True.
    cvs: ndarray
        The coefficients of variation of the firing for each bin of `n_bins`. Only returned
        if `show_fr_cv` is True.

    See Also
    --------
    single_units.firing_rate_coeff_var
    singlecell.firing_rate

    Examples
    --------
    1) Plot the firing rate for unit 1 from the time of its first to last spike, showing the cv
    of the firing rate for 10 evenly spaced bins.
    >>> ts = units_b['times']['1']
    >>> fr, cv, cvs = bb.plot.firing_rate(ts)
    '''
    if ax is None:
        fig, ax = plt.subplots()
    if show_fr_cv:  # compute firing rate and coefficients of variation
        cv, cvs, fr = single_units.firing_rate_coeff_var(ts, hist_win=hist_win, fr_win=fr_win,
                                                         n_bins=n_bins)
    else:  # compute just the firing rate
        fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)
    x = np.arange(fr.size) * hist_win
    ax.plot(x, fr)
    ax.set_title('Firing Rate')
    ax.set_xlabel('Time (s)')
    # Braces are needed for mathtext to superscript the whole "-1".
    ax.set_ylabel('Rate (s$^{-1}$)')
    if not show_fr_cv:
        return fr
    # Show coefficients of variation: dashed lines separate the `n_bins` bins, and each bin's cv
    # is written at the top right of that bin.
    y_max = np.max(fr) * 1.05
    x_l = x[int(x.size / n_bins)]
    for i in range(1, n_bins):
        ax.vlines((x_l * i), 0, y_max, linestyles='dashed', linewidth=2)
    for i in range(n_bins):
        ax.text(x_l * (i + 1), y_max, 'cv={0:.2f}'.format(cvs[i]), fontsize=9, ha='right')
    return fr, cv, cvs
[docs]
def peri_event_time_histogram(
        spike_times, spike_clusters, events, cluster_id,  # Everything you need for a basic plot
        t_before=0.2, t_after=0.5, bin_size=0.025, smoothing=0.025, as_rate=True,
        include_raster=False, n_rasters=None, error_bars='std', ax=None,
        pethline_kwargs=None, errbar_kwargs=None, eventline_kwargs=None, raster_kwargs=None,
        **kwargs):
    """
    Plot peri-event time histograms, with the mean firing rate of units centered on a given
    series of events. Can optionally add a raster underneath the PETH plot of individual spike
    trains about the events.

    Parameters
    ----------
    spike_times : array_like
        Spike times (in seconds)
    spike_clusters : array-like
        Cluster identities for each element of spikes
    events : array-like
        Times to align the histogram(s) to. At least two finite values are required.
    cluster_id : int
        Identity of the cluster for which to plot a PETH
    t_before : float, optional
        Time before event to plot (default: 0.2s)
    t_after : float, optional
        Time after event to plot (default: 0.5s)
    bin_size :float, optional
        Width of bin for histograms (default: 0.025s)
    smoothing : float, optional
        Sigma of gaussian smoothing to use in histograms. (default: 0.025s)
    as_rate : bool, optional
        Whether to use spike counts or rates in the plot (default: `True`, uses rates)
    include_raster : bool, optional
        Whether to put a raster below the PETH of individual spike trains (default: `False`)
    n_rasters : int, optional
        If include_raster is True, the number of rasters to include. If `None`
        will default to plotting rasters around all provided events. (default: `None`)
    error_bars : {'std', 'sem', 'none'}, optional
        Defines which type of error bars to plot. Options are:
        -- `'std'` for 1 standard deviation
        -- `'sem'` for standard error of the mean
        -- `'none'` for only plotting the mean value
        (default: `'std'`)
    ax : matplotlib axes, optional
        If passed, the function will plot on the passed axes (existing content is not cleared).
        (default: `None`, a new figure is created)
    pethline_kwargs : dict, optional
        Dict containing line properties to define PETH plot line. See matplotlib plot
        documentation for more options.
        (default: `{'color': 'blue', 'lw': 2}`)
    errbar_kwargs : dict, optional
        Dict containing fill-between properties to define PETH error bars. See matplotlib
        fill_between documentation for more options.
        (default: `{'color': 'blue', 'alpha': 0.5}`)
    eventline_kwargs : dict, optional
        Dict containing vlines properties to define the line at the event. See matplotlib
        vlines documentation for more options.
        (default: `{'color': 'black', 'alpha': 0.5}`)
    raster_kwargs : dict, optional
        Dict containing properties defining lines in the raster plot. See matplotlib vlines for
        more options.
        (default: `{'color': 'black', 'lw': 0.5}`)

    Returns
    -------
    ax : matplotlib axes
        Axes with all of the plots requested.

    Raises
    ------
    ValueError
        If spike times and clusters differ in length, fewer than two events are given, the
        error bar type is invalid, or any event is NaN/inf.
    """
    # Fill in the default styles here rather than with shared mutable default dicts.
    if pethline_kwargs is None:
        pethline_kwargs = {'color': 'blue', 'lw': 2}
    if errbar_kwargs is None:
        errbar_kwargs = {'color': 'blue', 'alpha': 0.5}
    if eventline_kwargs is None:
        eventline_kwargs = {'color': 'black', 'alpha': 0.5}
    if raster_kwargs is None:
        raster_kwargs = {'color': 'black', 'lw': 0.5}
    # Arrays are needed for the boolean masking used by the raster.
    spike_times = np.asarray(spike_times)
    spike_clusters = np.asarray(spike_clusters)
    # Check to make sure if we fail, we fail in an informative way
    if len(spike_times) != len(spike_clusters):
        raise ValueError('Spike times and clusters are not of the same shape')
    if len(events) < 2:
        raise ValueError('Cannot make a PETH with fewer than two events.')
    if error_bars not in ('std', 'sem', 'none'):
        raise ValueError('Invalid error bar type was passed.')
    if not np.all(np.isfinite(events)):
        raise ValueError('There are NaN or inf values in the list of events passed. '
                         ' Please remove non-finite data points and try again.')
    # Compute peths
    peths, _ = singlecell.calculate_peths(spike_times, spike_clusters, [cluster_id],
                                          events, t_before, t_after, bin_size,
                                          smoothing, as_rate)
    # Construct an axis object if none passed
    if ax is None:
        plt.figure()
        ax = plt.gca()
    # Plot the curve and add error bars
    mean = peths.means[0, :]
    ax.plot(peths.tscale, mean, **pethline_kwargs)
    if error_bars == 'std':
        bars = peths.stds[0, :]
    elif error_bars == 'sem':
        bars = peths.stds[0, :] / np.sqrt(len(events))
    else:
        bars = np.zeros_like(mean)
    if error_bars != 'none':
        ax.fill_between(peths.tscale, mean - bars, mean + bars, **errbar_kwargs)
    # Plot the event marker line. Extends to 5% higher than max value of means plus any error bar.
    plot_edge = (mean.max() + bars[mean.argmax()]) * 1.05
    if not plot_edge > 0:
        # No spikes around the events (or a non-finite PETH): use a unit height so the axis
        # limits and the raster spacing below stay valid.
        plot_edge = 1.
    ax.vlines(0., 0., plot_edge, **eventline_kwargs)
    # Set the limits on the axes to t_before and t_after. Either set the ylim to the 0 and max
    # values of the PETH, or if we want to plot a spike raster below, create an equal amount of
    # blank space below the zero where the raster will go.
    ax.set_xlim([-t_before, t_after])
    ax.set_ylim([-plot_edge if include_raster else 0., plot_edge])
    # Put y ticks only at min, max, and zero
    if mean.min() != 0:
        ax.set_yticks([0, mean.min(), mean.max()])
    else:
        ax.set_yticks([0., mean.max()])
    # Move the x axis line from the bottom of the plotting space to zero if including a raster,
    # Then plot the raster
    if include_raster:
        if n_rasters is None:
            n_rasters = len(events)
        if n_rasters > 60:
            warn("Number of raster traces is greater than 60. This might look bad on the plot.")
        ax.axhline(0., color='black')
        raster_events = events[:n_rasters]
        # Evenly divide the blank space below zero into one strip per trace.
        tickedges = np.linspace(0., -plot_edge, len(raster_events) + 1)
        clu_spks = spike_times[spike_clusters == cluster_id]
        for i, t in enumerate(raster_events):
            in_window = (clu_spks >= t - t_before) & (clu_spks <= t + t_after)
            ax.vlines(clu_spks[in_window] - t, tickedges[i + 1], tickedges[i], **raster_kwargs)
        ax.set_ylabel('Firing Rate' if as_rate else 'Number of spikes', y=0.75)
    else:
        ax.set_ylabel('Firing Rate' if as_rate else 'Number of spikes')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Time (s) after event')
    return ax
[docs]
def driftmap(ts, feat, ax=None, plot_style='bincount',
             t_bin=0.01, d_bin=20, weights=None, vmax=None, **kwargs):
    """
    Plots the values of a spike feature array (y-axis) over time (x-axis).

    Two arguments can be given for the plot_style of the drift map:
        - 'scatter' : whereby each value is plotted as a marker (only used for fewer than
          100'000 data points; otherwise 'bincount' is used)
        - 'bincount' : whereby the values are binned (optimised to represent spike raster)

    Parameters
    ----------
    ts : ndarray
        The spike timestamps.
    feat : ndarray
        The spikes' feature values (e.g. amplitudes or depths).
    ax : axessubplot (optional)
        The axis handle to plot on. (if `None`, a new figure and axis is created)
    plot_style : {'scatter', 'bincount'} (optional)
        The plot style (see above).
    t_bin : float (optional)
        Time bin used when plot_style='bincount'.
    d_bin : float (optional)
        Feature (depth) bin used when plot_style='bincount'.
    weights : ndarray (optional)
        Per-spike weights passed to `bincount2D` when plot_style='bincount'.
    vmax : float (optional)
        Upper color limit of the binned image (if `None`, 4 times the standard deviation of the
        binned map).
    **kwargs
        Passed to `ax.plot` ('scatter') or `ax.imshow` ('bincount').

    Returns
    -------
    ax : axessubplot
        The axis the driftmap was plotted on.

    Examples
    --------
    1) Plot the amplitude driftmap for unit 1.
    >>> ts = units_b['times']['1']
    >>> amps = units_b['amps']['1']
    >>> ax = bb.plot.driftmap(ts, amps)
    2) Plot the depth driftmap for unit 1.
    >>> ts = units_b['times']['1']
    >>> depths = units_b['depths']['1']
    >>> ax = bb.plot.driftmap(ts, depths)
    """
    if ax is None:
        fig, ax = plt.subplots()
    if plot_style == 'scatter' and len(ts) < 100000:
        # One marker per spike, no connecting line, unless the caller overrides these.
        if 'color' not in kwargs:
            kwargs['color'] = 'k'
        if 'marker' not in kwargs:
            kwargs['marker'] = '.'
        if 'linestyle' not in kwargs and 'ls' not in kwargs:
            kwargs['linestyle'] = 'none'
        ax.plot(ts, feat, **kwargs)
    else:
        # compute raster map as a function of site depth, ignoring NaN feature values
        iok = ~np.isnan(feat)
        R, times, depths = bincount2D(
            ts[iok], feat[iok], t_bin, d_bin, weights=weights[iok] if weights is not None else None)
        # plot raster map
        ax.imshow(R, aspect='auto', cmap='binary', vmin=0,
                  vmax=vmax if vmax is not None else np.std(R) * 4,
                  extent=np.r_[times[[0, -1]], depths[[0, -1]]], origin='lower', **kwargs)
    ax.set_xlabel('time (secs)')
    ax.set_ylabel('depth (um)')
    return ax
[docs]
def pres_ratio(ts, hist_win=10, ax=None):
    '''
    Plots spike presence across fixed-width time bins: each bin is drawn as 1 when it holds at
    least one spike and 0 otherwise. The presence ratio is the fraction of bins with a spike.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the presence ratio.
    hist_win : float
        The bin width (in s) used for computing the presence ratio.
    ax : axessubplot (optional)
        The axis handle to plot on. (if `None`, a new figure and axis is created)

    Returns
    -------
    pr : float
        The presence ratio.
    spks_bins : ndarray
        The number of spikes in each bin.

    See Also
    --------
    single_units.pres_ratio

    Examples
    --------
    1) Plot the presence ratio for unit 1, given a window of 10 s.
    >>> ts = units_b['times']['1']
    >>> pr, spks_bins = bb.plot.pres_ratio(ts)
    '''
    pr, spks_bins = single_units.pres_ratio(ts, hist_win)
    # 1 where a bin contains at least one spike, 0 elsewhere.
    is_present = (spks_bins > 0).astype(int)
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(is_present)
    ax.set_xlabel('Bin Number (width={:.1f}s)'.format(hist_win))
    ax.set_ylabel('Presence')
    ax.set_title('Presence Ratio')
    return pr, spks_bins
[docs]
def driftmap_color(
        clusters_depths, spikes_times,
        spikes_amps, spikes_depths, spikes_clusters,
        ax=None, axesoff=False, return_lims=False):
    '''
    Plots the driftmap of a session or a trial

    The plot shows the spike times vs spike depths.
    Each dot is a spike, whose color indicates the cluster
    and opacity indicates the spike amplitude.

    Parameters
    -------------
    clusters_depths: ndarray
        depths of all clusters
    spikes_times: ndarray
        spike times of all clusters
    spikes_amps: ndarray
        amplitude of each spike
    spikes_depths: ndarray
        depth of each spike
    spikes_clusters: ndarray
        cluster idx of each spike
    ax: matplotlib.axes.Axes object (optional)
        The axis object to plot the driftmap on
        (if `None`, a new figure and axis is created, and the figure is saved to
        'driftmap.png' in the current working directory)
    axesoff: bool (optional)
        If True, hide the axis lines, ticks and labels.
    return_lims: bool (optional)
        If True, also return the x and y axis limits.

    Return
    ---
    ax: matplotlib.axes.Axes object
        The axis object with driftmap plotted
    x_lim: list of two elements
        range of x axis (only returned if `return_lims` is True)
    y_lim: list of two elements
        range of y axis (only returned if `return_lims` is True)
    '''
    # 500 colors from the "hls" palette, reordered so that consecutive entries are 100 palette
    # steps apart (neighbouring depth ranks get clearly different colors).
    color_bins = sns.color_palette("hls", 500)
    new_color_bins = np.vstack(
        np.transpose(np.reshape(color_bins, [5, 100, 3]), [1, 0, 2]))
    # Rank clusters by depth, then give each spike the color of its cluster's depth rank.
    # Indexing per spike does not require the spikes to be sorted by cluster.
    # NOTE(review): assumes clusters_depths[k] belongs to the k-th smallest cluster id present in
    # spikes_clusters -- confirm against callers.
    sorted_idx = np.argsort(np.argsort(clusters_depths))
    _, spike_cluster_pos = np.unique(spikes_clusters, return_inverse=True)
    colors = new_color_bins[np.mod(sorted_idx[np.ravel(spike_cluster_pos)], 500), :]
    # Opacity scales linearly from the 10th (transparent) to the 90th (opaque) amplitude
    # percentile.
    max_amp = np.percentile(spikes_amps, 90)
    min_amp = np.percentile(spikes_amps, 10)
    amp_range = max_amp - min_amp
    if amp_range > 0:
        opacity = np.clip((spikes_amps - min_amp) / amp_range, 0, 1)
    else:  # degenerate amplitude spread: avoid NaN alphas and draw every spike opaque
        opacity = np.ones(len(spikes_amps))
    colorvec = np.zeros([len(opacity), 4], dtype='float16')
    colorvec[:, 3] = opacity.astype('float16')
    colorvec[:, 0:3] = colors.astype('float16')
    x = spikes_times.astype('float32')
    y = spikes_depths.astype('float32')
    args = dict(color=colorvec, edgecolors='none')
    savefig = False  # only save when this function created the figure
    if ax is None:
        fig = plt.Figure(dpi=200, frameon=False, figsize=[10, 10])
        ax = plt.Axes(fig, [0.1, 0.1, 0.9, 0.9])
        ax.set_xlabel('Time (sec)')
        ax.set_ylabel('Distance from the probe tip (um)')
        savefig = True
        args.update(s=0.1)
    ax.scatter(x, y, **args)
    # Vectorized extrema (the builtin min/max iterate element by element in Python).
    x_min, x_max = np.nanmin(x), np.nanmax(x)
    y_min, y_max = np.nanmin(y), np.nanmax(y)
    x_edge = (x_max - x_min) * 0.05
    x_lim = [x_min - x_edge, x_max + x_edge]
    y_lim = [y_min - 50, y_max + 100]
    ax.set_xlim(x_lim[0], x_lim[1])
    ax.set_ylim(y_lim[0], y_lim[1])
    if axesoff:
        ax.axis('off')
    if savefig:
        fig.add_axes(ax)
        fig.savefig('driftmap.png')
    if return_lims:
        return ax, x_lim, y_lim
    else:
        return ax