Source code for brainbox.plot

"""
Plots metrics that assess quality of single units. Some functions here generate plots for the
output of functions in the brainbox `single_units.py` module.

Run the following to set-up the workspace to run the docstring examples:
>>> from brainbox import processing
>>> import one.alf.io as alfio
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> import ibllib.ephys.spikes as e_spks
# (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
# Load the alf spikes bunch and clusters bunch, and get a units bunch.
>>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
>>> clstrs_b = alfio.load_object(path_to_alf_out, 'clusters')
>>> units_b = processing.get_units_bunch(spks_b)  # may take a few mins to compute
"""

import time
from warnings import warn

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# from matplotlib.ticker import StrMethodFormatter
from brainbox import singlecell
from brainbox.metrics import single_units
from brainbox.io.spikeglx import extract_waveforms
from iblutil.numerical import bincount2D
import spikeglx


[docs] def feat_vars(units_b, units=None, feat_name='amps', dist='norm', test='ks', cmap_name='coolwarm', ax=None): ''' Plots the coefficients of variation of a particular spike feature for all units as a bar plot, where each bar is color-coded corresponding to the depth of the max amplitude channel of the respective unit. Parameters ---------- units_b : bunch A units bunch containing fields with spike information (e.g. cluster IDs, times, features, etc.) for all units. units : array-like (optional) A subset of all units for which to create the bar plot. (If `None`, all units are used) feat_name : string (optional) The spike feature to plot. dist : string (optional) The type of hypothetical null distribution from which the empirical spike feature distributions are presumed to belong to. test : string (optional) The statistical test used to calculate the probability that the empirical spike feature distributions come from `dist`. cmap_name : string (optional) The name of the colormap associated with the plot. ax : axessubplot (optional) The axis handle to plot the histogram on. (if `None`, a new figure and axis is created) Returns ------- cv_vals : ndarray The coefficients of variation of `feat_name` for each unit. p_vals : ndarray The probabilites that the distribution for `feat_name` for each unit comes from a `dist` distribution based on the `test` statistical test. See Also -------- metrics.unit_stability Examples -------- 1) Create a bar plot of the coefficients of variation of the spike amplitudes for all units. >>> fig, var_vals, p_vals = bb.plot.feat_vars(units_b) ''' # Get units. if not (units is None): # we're using a subset of all units unit_list = list(units_b['depths'].keys()) # For each unit in `unit_list`, remove unit from `units_b` if not in `units`. [units_b['depths'].pop(unit) for unit in unit_list if not (int(unit) in units)] unit_list = list(units_b['depths'].keys()) # get new `unit_list` after removing unit # Calculate coefficients of variation for all units p_vals_b, cv_b = single_units.unit_stability( units_b, units=units, feat_names=[feat_name], dist=dist, test=test) cv_vals = np.array(tuple(cv_b[feat_name].values())) cv_vals = cv_vals * 1e6 if feat_name == 'amps' else cv_vals # convert to uV if amps p_vals = np.array(tuple(p_vals_b[feat_name].values())) # Remove any empty units. This must be done AFTER the above calculations for ALL units so that # we can keep direct indexing. empty_unit_idxs = np.where([len(units_b['times'][unit]) == 0 for unit in unit_list])[0] good_units = [unit for unit in unit_list if unit not in empty_unit_idxs.astype(str)] # Get mean depths of spikes for good units depths = np.asarray([np.mean(units_b['depths'][str(unit)]) for unit in good_units]) # Create unit normalized colormap based on `depths`, sorted by depth. cmap = plt.cm.get_cmap(cmap_name) depths_norm = depths / np.max(depths) rgba = np.asarray([cmap(depth) for depth in np.sort(np.flip(depths_norm))]) # Plot depth-color-coded h bar plot of CVs for `feature` for each unit, where units are # sorted descendingly by depth along y-axis. if ax is None: fig, ax = plt.subplots() ax.barh(y=[int(unit) for unit in good_units], width=cv_vals[np.argsort(depths)], color=rgba) fig = ax.figure cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax) max_d = np.max(depths) tick_labels = [int(max_d * tick) for tick in (0, 0.2, 0.4, 0.6, 0.8, 1.0)] cbar.set_ticks(cbar.get_ticks()) # must call `set_ticks` to call `set_ticklabels` cbar.set_ticklabels(tick_labels) ax.set_title('CV of {feat}'.format(feat=feat_name)) ax.set_ylabel('Unit Number (sorted by depth)') ax.set_xlabel('CV') cbar.set_label('Depth', rotation=-90) return cv_vals, p_vals
[docs] def missed_spikes_est(feat, feat_name, spks_per_bin=20, sigma=5, min_num_bins=50, ax=None): ''' Plots the pdf of an estimated symmetric spike feature distribution, with a vertical cutoff line that indicates the approximate fraction of spikes missing from the distribution, assuming the true distribution is symmetric. Parameters ---------- feat : ndarray The spikes' feature values. feat_name : string The spike feature to plot. spks_per_bin : int (optional) The number of spikes per bin from which to compute the spike feature histogram. sigma : int (optional) The standard deviation for the gaussian kernel used to compute the pdf from the spike feature histogram. min_num_bins : int (optional) The minimum number of bins used to compute the spike feature histogram. ax : axessubplot (optional) The axis handle to plot the histogram on. (if `None`, a new figure and axis is created) Returns ------- fraction_missing : float The fraction of missing spikes (0-0.5). *Note: If more than 50% of spikes are missing, an accurate estimate isn't possible. See Also -------- single_units.feature_cutoff Examples -------- 1) Plot cutoff line indicating the fraction of spikes missing from a unit based on the recorded unit's spike amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric. >>> feat = units_b['amps']['1'] >>> fraction_missing = bb.plot.missed_spikes_est(feat, feat_name='amps', unit=1) ''' # Calculate the feature distribution histogram and fraction of spikes missing. fraction_missing, pdf, cutoff_idx = \ single_units.missed_spikes_est(feat, spks_per_bin, sigma, min_num_bins) # Plot. if ax is None: # create two axes fig, ax = plt.subplots(nrows=1, ncols=2) if ax is None or len(ax) == 2: # plot histogram and pdf on two separate axes num_bins = int(feat.size / spks_per_bin) ax[0].hist(feat, bins=num_bins) ax[0].set_xlabel('{0}'.format(feat_name)) ax[0].set_ylabel('Count') ax[0].set_title('Histogram of {0}'.format(feat_name)) ax[1].plot(pdf) ax[1].vlines(cutoff_idx, 0, np.max(pdf), colors='r') ax[1].set_xlabel('Bin Number') ax[1].set_ylabel('Density') ax[1].set_title('PDF Symmetry Cutoff\n' '(estimated {:.2f}% missing spikes)'.format(fraction_missing * 100)) else: # just plot pdf ax = ax[0] ax.plot(pdf) ax.vlines(cutoff_idx, 0, np.max(pdf), colors='r') ax.set_xlabel('Bin Number') ax.set_ylabel('Density') ax.set_title('PDF Symmetry Cutoff\n' '(estimated {:.2f}% missing spikes)'.format(fraction_missing * 100)) return fraction_missing
[docs] def wf_comp(ephys_file, ts1, ts2, ch, sr=30000, n_ch_probe=385, dtype='int16', car=True, col=['b', 'r'], ax=None): ''' Plots two different sets of waveforms across specified channels after (optionally) common-average-referencing. In this way, waveforms can be compared to see if there is, e.g. drift during the recording, or if two units should be merged, or one unit should be split. Parameters ---------- ephys_file : string The file path to the binary ephys data. ts1 : array_like A set of timestamps for which to compare waveforms with `ts2`. ts2: array_like A set of timestamps for which to compare waveforms with `ts1`. ch : array-like The channels to use for extracting and plotting the waveforms. sr : int (optional) The sampling rate (in hz) that the ephys data was acquired at. n_ch_probe : int (optional) The number of channels of the recording. dtype: str (optional) The datatype represented by the bytes in `ephys_file`. car: bool (optional) A flag for whether or not to perform common-average-referencing before extracting waveforms col: list of strings or float arrays (optional) Two elements in the list, where each specifies the color the `ts1` and `ts2` waveforms will be plotted in, respectively. ax : axessubplot (optional) The axis handle to plot the histogram on. (if `None`, a new figure and axis is created) Returns ------- wf1 : ndarray The waveforms for the spikes in `ts1`: an array of shape (#spikes, #samples, #channels). wf2 : ndarray The waveforms for the spikes in `ts2`: an array of shape (#spikes, #samples, #channels). s : float The similarity score between the two sets of waveforms, calculated by `single_units.wf_similarity` See Also -------- io.extract_waveforms single_units.wf_similarity Examples -------- 1) Compare first and last 100 spike waveforms for unit1, across 20 channels around the channel of max amplitude, and compare the waveforms in the first minute to the waveforms in the fourth minutes for unit2, across 10 channels around the mean. # Get first and last 100 spikes, and 20 channels around channel of max amp for unit 1: >>> ts1 = units_b['times']['1'][:100] >>> ts2 = units_b['times']['1'][-100:] >>> max_ch = clstrs_b['channels'][1] >>> if max_ch < n_c_ch: # take only channels greater than `max_ch`. >>> ch = np.arange(max_ch, max_ch + 20) >>> elif (max_ch + n_c_ch) > n_ch_probe: # take only channels less than `max_ch`. >>> ch = np.arange(max_ch - 20, max_ch) >>> else: # take `n_c_ch` around `max_ch`. >>> ch = np.arange(max_ch - 10, max_ch + 10) >>> wf1, wf2, s = bb.plot.wf_comp(path_to_ephys_file, ts1, ts2, ch) # Plot waveforms for unit2 from the first and fourth minutes across 10 channels. >>> ts = units_b['times']['2'] >>> ts1_2 = ts[np.where(ts<60)[0]] >>> ts2_2 = ts[np.where(ts>180)[0][:len(ts1)]] >>> max_ch = clstrs_b['channels'][2] >>> if max_ch < n_c_ch: # take only channels greater than `max_ch`. >>> ch = np.arange(max_ch, max_ch + 10) >>> elif (max_ch + n_c_ch) > n_ch_probe: # take only channels less than `max_ch`. >>> ch = np.arange(max_ch - 10, max_ch) >>> else: # take `n_c_ch` around `max_ch`. >>> ch = np.arange(max_ch - 5, max_ch + 5) >>> wf1_2, wf2_2, s_2 = bb.plot.wf_comp(path_to_ephys_file, ts1_2, ts2_2, ch) ''' # Ensure `ch` is ndarray ch = np.asarray(ch) ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch # Extract the waveforms for these timestamps and compute similarity score. wf1 = extract_waveforms(ephys_file, ts1, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype, car=car) wf2 = extract_waveforms(ephys_file, ts2, ch, sr=sr, n_ch_probe=n_ch_probe, dtype=dtype, car=car) s = single_units.wf_similarity(wf1, wf2) # Plot these waveforms against each other. n_ch = ch.size if ax is None: fig, ax = plt.subplots(nrows=n_ch, ncols=2) # left col is all waveforms, right col is mean for cur_ax, cur_ch in enumerate(ch): ax[cur_ax][0].plot(wf1[:, :, cur_ax].T, c=col[0]) ax[cur_ax][0].plot(wf2[:, :, cur_ax].T, c=col[1]) ax[cur_ax][1].plot(np.mean(wf1[:, :, cur_ax], axis=0), c=col[0]) ax[cur_ax][1].plot(np.mean(wf2[:, :, cur_ax], axis=0), c=col[1]) ax[cur_ax][0].set_ylabel('Ch {0}'.format(cur_ch)) ax[0][0].set_title('All Waveforms. S = {:.2f}'.format(s)) ax[0][1].set_title('Mean Waveforms') plt.legend(['1st spike set', '2nd spike set']) return wf1, wf2, s
[docs] def amp_heatmap(ephys_file, ts, ch, sr=30000, n_ch_probe=385, dtype='int16', cmap_name='RdBu', car=True, ax=None): ''' Plots a heatmap of the normalized voltage values over time and space for given timestamps and channels, after (optionally) common-average-referencing. Parameters ---------- ephys_file : string The file path to the binary ephys data. ts: array_like A set of timestamps for which to get the voltage values. ch : array-like The channels to use for extracting the voltage values. sr : int (optional) The sampling rate (in hz) that the ephys data was acquired at. n_ch_probe : int (optional) The number of channels of the recording. dtype: str (optional) The datatype represented by the bytes in `ephys_file`. cmap_name : string (optional) The name of the colormap associated with the plot. car: bool (optional) A flag for whether or not to perform common-average-referencing before extracting waveforms ax : axessubplot (optional) The axis handle to plot the histogram on. (if `None`, a new figure and axis is created) Returns ------- v_vals : ndarray The voltage values. Examples -------- 1) Plot a heatmap of the spike amplitudes across 20 channels around the channel of max amplitude for all spikes in unit 1. >>> ts = units_b['times']['1'] >>> max_ch = clstrs_b['channels'][1] >>> if max_ch < n_c_ch: # take only channels greater than `max_ch`. >>> ch = np.arange(max_ch, max_ch + 20) >>> elif (max_ch + n_c_ch) > n_ch_probe: # take only channels less than `max_ch`. >>> ch = np.arange(max_ch - 20, max_ch) >>> else: # take `n_c_ch` around `max_ch`. >>> ch = np.arange(max_ch - 10, max_ch + 10) >>> bb.plot.amp_heatmap(path_to_ephys_file, ts, ch) ''' # Ensure `ch` is ndarray ch = np.asarray(ch) ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch # Get memmapped array of `ephys_file` s_reader = spikeglx.Reader(ephys_file, open=True) file_m = s_reader.data # Get voltage values for each peak amplitude sample for `ch`. max_amp_samples = (ts * sr).astype(int) # Currently this is an annoying way to calculate `v_vals` b/c indexing with multiple values # is currently unsupported. v_vals = np.zeros((max_amp_samples.size, ch.size)) for sample in range(max_amp_samples.size): v_vals[sample] = file_m[max_amp_samples[sample]:max_amp_samples[sample] + 1, ch] if car: # compute spatial noise in chunks, and subtract from `v_vals`. # Get subset of time (from first to last max amp sample) n_chunk_samples = 5e6 # number of samples per chunk n_chunks = np.ceil((max_amp_samples[-1] - max_amp_samples[0]) / n_chunk_samples).astype('int') # Get samples that make up each chunk. e.g. `chunk_sample[1] - chunk_sample[0]` are the # samples that make up the first chunk. chunk_sample = np.arange(max_amp_samples[0], max_amp_samples[-1], n_chunk_samples, dtype=int) chunk_sample = np.append(chunk_sample, max_amp_samples[-1]) noise_s_chunks = np.zeros((n_chunks, ch.size), dtype=np.int16) # spatial noise array # Give time estimate for computing `noise_s_chunks`. t0 = time.perf_counter() np.median(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0) dt = time.perf_counter() - t0 print('Performing spatial CAR before waveform extraction. Estimated time is {:.2f} mins.' ' ({})'.format(dt * n_chunks / 60, time.ctime())) # Compute noise for each chunk, then take the median noise of all chunks. for chunk in range(n_chunks): noise_s_chunks[chunk, :] = np.median( file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0) noise_s = np.median(noise_s_chunks, axis=0) v_vals -= noise_s[None, :] print('Done. ({})'.format(time.ctime())) s_reader.close() # Plot heatmap. if ax is None: fig, ax = plt.subplots() v_vals_norm = (v_vals / np.max(abs(v_vals))).T cbar_map = ax.imshow(v_vals_norm, cmap=cmap_name, aspect='auto', extent=[ts[0], ts[-1], ch[0], ch[-1]], origin='lower') ax.set_yticks(np.arange(ch[0], ch[-1], 5)) ax.set_ylabel('Channel Numbers') ax.set_xlabel('Time (s)') ax.set_title('Voltage Heatmap') fig = ax.figure cbar = fig.colorbar(cbar_map, ax=ax) cbar.set_label('V', rotation=-90) return v_vals
[docs] def firing_rate(ts, hist_win=0.01, fr_win=0.5, n_bins=10, show_fr_cv=True, ax=None): ''' Plots the instantaneous firing rate of for given spike timestamps over time, and optionally overlays the value of the coefficient of variation of the firing rate for a specified number of bins. Parameters ---------- ts : ndarray The spike timestamps from which to compute the firing rate. hist_win : float (optional) The time window (in s) to use for computing spike counts. fr_win : float (optional) The time window (in s) to use as a moving slider to compute the instantaneous firing rate. n_bins : int (optional) The number of bins in which to compute coefficients of variation of the firing rate. show_fr_cv : bool (optional) A flag for whether or not to compute and show the coefficients of variation of the firing rate for `n_bins`. ax : axessubplot (optional) The axis handle to plot the histogram on. (if `None`, a new figure and axis is created) Returns ------- fr: ndarray The instantaneous firing rate over time (in hz). cv: float The mean coefficient of variation of the firing rate of the `n_bins` number of coefficients computed. Can only be returned if `show_fr_cv` is True. cvs: ndarray The coefficients of variation of the firing for each bin of `n_bins`. Can only be returned if `show_fr_cv` is True. See Also -------- single_units.firing_rate_cv singecell.firing_rate Examples -------- 1) Plot the firing rate for unit 1 from the time of its first to last spike, showing the cv of the firing rate for 10 evenly spaced bins. >>> ts = units_b['times']['1'] >>> fr, cv, cvs = bb.plot.firing_rate(ts) ''' if ax is None: fig, ax = plt.subplots() if not (show_fr_cv): # compute just the firing rate fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win) else: # compute firing rate and coefficients of variation cv, cvs, fr = single_units.firing_rate_coeff_var(ts, hist_win=hist_win, fr_win=fr_win, n_bins=n_bins) x = np.arange(fr.size) * hist_win ax.plot(x, fr) ax.set_title('Firing Rate') ax.set_xlabel('Time (s)') ax.set_ylabel('Rate (s$^-1$)') if not (show_fr_cv): return fr else: # show coefficients of variation y_max = np.max(fr) * 1.05 x_l = x[int(x.size / n_bins)] # Plot vertical lines separating plots into `n_bins`. [ax.vlines((x_l * i), 0, y_max, linestyles='dashed', linewidth=2) for i in range(1, n_bins)] # Plot text with cv of firing rate for each bin. [ax.text(x_l * (i + 1), y_max, 'cv={0:.2f}'.format(cvs[i]), fontsize=9, ha='right') for i in range(n_bins)] return fr, cv, cvs
[docs] def peri_event_time_histogram( spike_times, spike_clusters, events, cluster_id, # Everything you need for a basic plot t_before=0.2, t_after=0.5, bin_size=0.025, smoothing=0.025, as_rate=True, include_raster=False, n_rasters=None, error_bars='std', ax=None, pethline_kwargs={'color': 'blue', 'lw': 2}, errbar_kwargs={'color': 'blue', 'alpha': 0.5}, eventline_kwargs={'color': 'black', 'alpha': 0.5}, raster_kwargs={'color': 'black', 'lw': 0.5}, **kwargs): """ Plot peri-event time histograms, with the meaning firing rate of units centered on a given series of events. Can optionally add a raster underneath the PETH plot of individual spike trains about the events. Parameters ---------- spike_times : array_like Spike times (in seconds) spike_clusters : array-like Cluster identities for each element of spikes events : array-like Times to align the histogram(s) to cluster_id : int Identity of the cluster for which to plot a PETH t_before : float, optional Time before event to plot (default: 0.2s) t_after : float, optional Time after event to plot (default: 0.5s) bin_size :float, optional Width of bin for histograms (default: 0.025s) smoothing : float, optional Sigma of gaussian smoothing to use in histograms. (default: 0.025s) as_rate : bool, optional Whether to use spike counts or rates in the plot (default: `True`, uses rates) include_raster : bool, optional Whether to put a raster below the PETH of individual spike trains (default: `False`) n_rasters : int, optional If include_raster is True, the number of rasters to include. If `None` will default to plotting rasters around all provided events. (default: `None`) error_bars : {'std', 'sem', 'none'}, optional Defines which type of error bars to plot. Options are: -- `'std'` for 1 standard deviation -- `'sem'` for standard error of the mean -- `'none'` for only plotting the mean value (default: `'std'`) ax : matplotlib axes, optional If passed, the function will plot on the passed axes. Note: current behavior causes whatever was on the axes to be cleared before plotting! (default: `None`) pethline_kwargs : dict, optional Dict containing line properties to define PETH plot line. Default is a blue line with weight of 2. Needs to have color. See matplotlib plot documentation for more options. (default: `{'color': 'blue', 'lw': 2}`) errbar_kwargs : dict, optional Dict containing fill-between properties to define PETH error bars. Default is a blue fill with 50 percent opacity.. Needs to have color. See matplotlib fill_between documentation for more options. (default: `{'color': 'blue', 'alpha': 0.5}`) eventline_kwargs : dict, optional Dict containing fill-between properties to define line at event. Default is a black line with 50 percent opacity.. Needs to have color. See matplotlib vlines documentation for more options. (default: `{'color': 'black', 'alpha': 0.5}`) raster_kwargs : dict, optional Dict containing properties defining lines in the raster plot. Default is black lines with line width of 0.5. See matplotlib vlines for more options. (default: `{'color': 'black', 'lw': 0.5}`) Returns ------- ax : matplotlib axes Axes with all of the plots requested. """ # Check to make sure if we fail, we fail in an informative way if not len(spike_times) == len(spike_clusters): raise ValueError('Spike times and clusters are not of the same shape') if len(events) == 1: raise ValueError('Cannot make a PETH with only one event.') if error_bars not in ('std', 'sem', 'none'): raise ValueError('Invalid error bar type was passed.') if not all(np.isfinite(events)): raise ValueError('There are NaN or inf values in the list of events passed. ' ' Please remove non-finite data points and try again.') # Compute peths peths, binned_spikes = singlecell.calculate_peths(spike_times, spike_clusters, [cluster_id], events, t_before, t_after, bin_size, smoothing, as_rate) # Construct an axis object if none passed if ax is None: plt.figure() ax = plt.gca() # Plot the curve and add error bars mean = peths.means[0, :] ax.plot(peths.tscale, mean, **pethline_kwargs) if error_bars == 'std': bars = peths.stds[0, :] elif error_bars == 'sem': bars = peths.stds[0, :] / np.sqrt(len(events)) else: bars = np.zeros_like(mean) if error_bars != 'none': ax.fill_between(peths.tscale, mean - bars, mean + bars, **errbar_kwargs) # Plot the event marker line. Extends to 5% higher than max value of means plus any error bar. plot_edge = (mean.max() + bars[mean.argmax()]) * 1.05 ax.vlines(0., 0., plot_edge, **eventline_kwargs) # Set the limits on the axes to t_before and t_after. Either set the ylim to the 0 and max # values of the PETH, or if we want to plot a spike raster below, create an equal amount of # blank space below the zero where the raster will go. ax.set_xlim([-t_before, t_after]) ax.set_ylim([-plot_edge if include_raster else 0., plot_edge]) # Put y ticks only at min, max, and zero if mean.min() != 0: ax.set_yticks([0, mean.min(), mean.max()]) else: ax.set_yticks([0., mean.max()]) # Move the x axis line from the bottom of the plotting space to zero if including a raster, # Then plot the raster if include_raster: if n_rasters is None: n_rasters = len(events) if n_rasters > 60: warn("Number of raster traces is greater than 60. This might look bad on the plot.") ax.axhline(0., color='black') tickheight = plot_edge / len(events[:n_rasters]) # How much space per trace tickedges = np.arange(0., -plot_edge - 1e-5, -tickheight) clu_spks = spike_times[spike_clusters == cluster_id] for i, t in enumerate(events[:n_rasters]): idx = np.bitwise_and(clu_spks >= t - t_before, clu_spks <= t + t_after) event_spks = clu_spks[idx] ax.vlines(event_spks - t, tickedges[i + 1], tickedges[i], **raster_kwargs) ax.set_ylabel('Firing Rate' if as_rate else 'Number of spikes', y=0.75) else: ax.set_ylabel('Firing Rate' if as_rate else 'Number of spikes') ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.set_xlabel('Time (s) after event') return ax
[docs] def driftmap(ts, feat, ax=None, plot_style='bincount', t_bin=0.01, d_bin=20, weights=None, vmax=None, **kwargs): """ Plots the values of a spike feature array (y-axis) over time (x-axis). Two arguments can be given for the plot_style of the drift map: - 'scatter' : whereby each value is plotted as a marker (up to 100'000 data point) - 'bincount' : whereby the values are binned (optimised to represent spike raster) Parameters ---------- feat : ndarray The spikes' feature values. ts : ndarray The spike timestamps from which to compute the firing rate. ax : axessubplot (optional) The axis handle to plot the histogram on. (if `None`, a new figure and axis is created) t_bin: time bin used when plot_style='bincount' d_bin: depth bin used when plot_style='bincount' plot_style: 'scatter', 'bincount' **kwargs: matplotlib.imshow arguments Returns ------- cd: float The cumulative drift of `feat`. md: float The maximum drift of `feat`. See Also -------- metrics.cum_drift metrics.max_drift Examples -------- 1) Plot the amplitude driftmap for unit 1. >>> ts = units_b['times']['1'] >>> amps = units_b['amps']['1'] >>> ax = bb.plot.driftmap(ts, amps) 2) Plot the depth driftmap for unit 1. >>> ts = units_b['times']['1'] >>> depths = units_b['depths']['1'] >>> ax = bb.plot.driftmap(ts, depths) """ iok = ~np.isnan(feat) if ax is None: fig, ax = plt.subplots() if plot_style == 'scatter' and len(ts) < 100000: print('here todo') if 'color' not in kwargs.keys(): kwargs['color'] = 'k' ax.plot(ts, feat, **kwargs) else: # compute raster map as a function of site depth R, times, depths = bincount2D( ts[iok], feat[iok], t_bin, d_bin, weights=weights) # plot raster map ax.imshow(R, aspect='auto', cmap='binary', vmin=0, vmax=vmax or np.std(R) * 4, extent=np.r_[times[[0, -1]], depths[[0, -1]]], origin='lower', **kwargs) ax.set_xlabel('time (secs)') ax.set_ylabel('depth (um)') return ax
[docs] def pres_ratio(ts, hist_win=10, ax=None): ''' Plots the presence ratio of spike counts: the number of bins where there is at least one spike, over the total number of bins, given a specified bin width. Parameters ---------- ts : ndarray The spike timestamps from which to compute the presence ratio. hist_win : float The time window (in s) to use for computing the presence ratio. ax : axessubplot (optional) The axis handle to plot the histogram on. (if `None`, a new figure and axis is created) Returns ------- pr : float The presence ratio. spks_bins : ndarray The number of spks in each bin. See Also -------- metrics.pres_ratio Examples -------- 1) Plot the presence ratio for unit 1, given a window of 10 s. >>> ts = units_b['times']['1'] >>> pr, pr_bins = bb.plot.pres_ratio(ts) ''' pr, spks_bins = single_units.pres_ratio(ts, hist_win) pr_bins = np.where(spks_bins > 0, 1, 0) if ax is None: fig, ax = plt.subplots() ax.plot(pr_bins) ax.set_xlabel('Bin Number (width={:.1f}s)'.format(hist_win)) ax.set_ylabel('Presence') ax.set_title('Presence Ratio') return pr, spks_bins
[docs] def driftmap_color( clusters_depths, spikes_times, spikes_amps, spikes_depths, spikes_clusters, ax=None, axesoff=False, return_lims=False): ''' Plots the driftmap of a session or a trial The plot shows the spike times vs spike depths. Each dot is a spike, whose color indicates the cluster and opacity indicates the spike amplitude. Parameters ------------- clusters_depths: ndarray depths of all clusters spikes_times: ndarray spike times of all clusters spikes_amps: ndarray amplitude of each spike spikes_depths: ndarray depth of each spike spikes_clusters: ndarray cluster idx of each spike ax: matplotlib.axes.Axes object (optional) The axis object to plot the driftmap on (if `None`, a new figure and axis is created) Return --- ax: matplotlib.axes.Axes object The axis object with driftmap plotted x_lim: list of two elements range of x axis y_lim: list of two elements range of y axis ''' color_bins = sns.color_palette("hls", 500) new_color_bins = np.vstack( np.transpose(np.reshape(color_bins, [5, 100, 3]), [1, 0, 2])) # get the sorted idx of each depth, and create colors based on the idx sorted_idx = np.argsort(np.argsort(clusters_depths)) colors = np.vstack( [np.repeat( new_color_bins[np.mod(idx, 500), :][np.newaxis, ...], n_spikes, axis=0) for (idx, n_spikes) in zip(sorted_idx, np.unique(spikes_clusters, return_counts=True)[1])]) max_amp = np.percentile(spikes_amps, 90) min_amp = np.percentile(spikes_amps, 10) opacity = np.divide(spikes_amps - min_amp, max_amp - min_amp) opacity[opacity > 1] = 1 opacity[opacity < 0] = 0 colorvec = np.zeros([len(opacity), 4], dtype='float16') colorvec[:, 3] = opacity.astype('float16') colorvec[:, 0:3] = colors.astype('float16') x = spikes_times.astype('float32') y = spikes_depths.astype('float32') args = dict(color=colorvec, edgecolors='none') if ax is None: fig = plt.Figure(dpi=200, frameon=False, figsize=[10, 10]) ax = plt.Axes(fig, [0.1, 0.1, 0.9, 0.9]) ax.set_xlabel('Time (sec)') ax.set_ylabel('Distance from the probe tip (um)') savefig = True args.update(s=0.1) ax.scatter(x, y, **args) x_edge = (max(x) - min(x)) * 0.05 x_lim = [min(x) - x_edge, max(x) + x_edge] y_lim = [min(y) - 50, max(y) + 100] ax.set_xlim(x_lim[0], x_lim[1]) ax.set_ylim(y_lim[0], y_lim[1]) if axesoff: ax.axis('off') if savefig: fig.add_axes(ax) fig.savefig('driftmap.png') if return_lims: return ax, x_lim, y_lim else: return ax