Source code for brainbox.metrics.single_units

"""
Computes metrics for assessing quality of single units.

Run the following to set-up the workspace to run the docstring examples:
>>> import brainbox as bb
>>> import one.alf.io as aio
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> import ibllib.ephys.spikes as e_spks
# Note: if there is no 'alf' directory, create one from the 'ks2' output directory:
>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
# Load the alf spikes bunch and clusters bunch, and get a units bunch.
>>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
>>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
>>> units_b = bb.processing.get_units_bunch(spks_b)  # may take a few mins to compute
"""

import time
import logging

import numpy as np
from scipy.ndimage import gaussian_filter1d
import scipy.stats as stats
import pandas as pd

import spikeglx
from phylib.stats import correlograms
from iblutil.util import Bunch
from iblutil.numerical import ismember, between_sorted, bincount2D
from slidingRP import metrics

from brainbox import singlecell
from brainbox.io.spikeglx import extract_waveforms
from brainbox.metrics import electrode_drift


_logger = logging.getLogger('ibllib')

# Default parameter set used by `quick_unit_metrics`. The two nested entries hold keyword
# arguments grouped per metric; the remaining entries are scalar settings.
METRICS_PARAMS = {
    'noise_cutoff': {
        'quantile_length': 0.25,
        'n_bins': 100,
        'nc_threshold': 5,
        'percent_threshold': 0.10,
    },
    'missed_spikes_est': {
        'spks_per_bin': 10,
        'sigma': 4,
        'min_num_bins': 50,
    },
    'acceptable_contamination': 0.1,
    'bin_size': 0.25,
    # units below this threshold are considered noise
    'med_amp_thresh_uv': 50,
    'min_isi': 0.0001,
    'presence_window': 10,
    'refractory_period': 0.0015,
    'RPslide_thresh': 0.1,
    # a unit needs to pass with at least this confidence percentage (0 - 100)
    'RPmax_confidence': 90,
}


def unit_stability(units_b, units=None, feat_names=None, dist='norm', test='ks'):
    """
    Computes the probability that the empirical spike feature distribution(s), for specified
    feature(s), for all units, comes from a specific theoretical distribution, based on a
    specified statistical test. Also computes a dispersion statistic (variance divided by
    mean) of the spike feature(s) for all units.

    Parameters
    ----------
    units_b : bunch
        A units bunch containing fields with spike information (e.g. cluster IDs, times,
        features, etc.) for all units. It must contain a 'times' field (used to detect units
        without spikes) in addition to every field named in `feat_names`. The bunch is not
        modified by this function.
    units : array-like (optional)
        A subset of all units to process. (If `None`, all units are used)
    feat_names : list of strings (optional)
        A list of names of spike features that can be found in `units_b` to specify which
        features to use for calculating unit stability. Defaults to `['amps']` when `None`.
    dist : string (optional)
        The type of hypothetical null distribution for which the empirical spike feature
        distributions are presumed to belong to. Passed as the second argument to the test.
    test : string (optional)
        The statistical test used to compute the probability that the empirical spike feature
        distributions come from `dist`. Only 'ks' (`scipy.stats.kstest`) is available.

    Returns
    -------
    p_vals_b : bunch
        A bunch with `feat_names` as keys, containing a bunch of p-values (the probabilities
        that the empirical spike feature distribution for each unit comes from `dist` based on
        `test`) keyed by unit, for all `feat_names`.
    cv_b : bunch
        A bunch with `feat_names` as keys, containing for each unit the variance of its
        empirical spike feature distribution divided by its mean (`np.var / np.mean`).
        NOTE(review): the name suggests a coefficient of variation (std / mean), but the
        value computed is variance / mean; left unchanged so existing results are preserved.

    Raises
    ------
    KeyError
        If `test` is not one of the supported test names, or if a requested feature/unit is
        missing from `units_b`.

    See Also
    --------
    plot.feat_vars

    Examples
    --------
    1) Compute 1) the p-values obtained from running a one-sample ks test on the spike amplitudes
    for each unit, and 2) the dispersion of the empirical spike amplitudes distribution for each
    unit. Get the units whose dispersion is greater than 50.
        >>> p_vals_b, cv_b = bb.metrics.unit_stability(units_b)
        # Plot histograms of variances color-coded by depth of channel of max amplitudes
        >>> fig = bb.plot.feat_vars(units_b, feat_name='amps')
        # Get all units which have an amps dispersion > 50
        >>> cv_vals = np.array(tuple(cv_b['amps'].values()))
        >>> bad_units = np.where(cv_vals > 50)
    """
    # `None` stands in for the historical default `['amps']` (avoids a shared mutable default).
    if feat_names is None:
        feat_names = ['amps']

    # Work out which units to process. The selection is done on a local list so the caller's
    # `units_b` is left untouched (the previous implementation popped entries from it).
    unit_list = list(units_b[feat_names[0]].keys())
    if units is not None:  # we're using a subset of all units
        unit_list = [unit for unit in unit_list if int(unit) in units]

    # Available statistical tests (in future, more tests can be added to this dict).
    tests = {'ks': stats.kstest}
    test_fun = tests[test]

    # For each feature, compute each unit's p-value and dispersion, store them in per-feature
    # bunches keyed by unit, then attach those to the parent bunches.
    p_vals_b = Bunch()
    cv_b = Bunch()
    for feat in feat_names:
        p_vals_feat = Bunch()
        cv_feat = Bunch()
        for unit in unit_list:
            if len(units_b['times'][str(unit)]) == 0:
                # Unit has no spikes: record NaN placeholders.
                p_val = np.nan
                cv = np.nan
            else:
                values = units_b[feat][unit]
                _, p_val = test_fun(values, dist)
                cv = np.var(values) / np.mean(values)
            p_vals_feat[str(unit)] = p_val
            cv_feat[str(unit)] = cv
        p_vals_b[feat] = p_vals_feat
        cv_b[feat] = cv_feat

    return p_vals_b, cv_b
def missed_spikes_est(feat, spks_per_bin=20, sigma=5, min_num_bins=50):
    """
    Computes the approximate fraction of spikes missing from a spike feature distribution for a
    given unit, assuming the distribution is symmetric.
    Inspired by metric described in Hill et al. (2011) J Neurosci 31: 8699-8705.

    Parameters
    ----------
    feat : array-like
        The spikes' feature values (e.g. amplitudes). Converted with `np.asarray`, so lists
        are accepted as well as ndarrays.
    spks_per_bin : int (optional)
        The number of spikes per bin from which to compute the spike feature histogram.
    sigma : int (optional)
        The standard deviation for the gaussian kernel used to compute the pdf from the spike
        feature histogram.
    min_num_bins : int (optional)
        The minimum number of bins used to compute the spike feature histogram.

    Returns
    -------
    fraction_missing : float
        The fraction of missing spikes (0-0.5). *Note: If more than 50% of spikes are missing, an
        accurate estimate isn't possible. `np.nan` when `feat` holds no more than
        `spks_per_bin * min_num_bins` values.
    pdf : ndarray
        The computed pdf of the spike feature histogram (`None` when there are too few spikes).
    cutoff_idx : int
        The index for `pdf` at which point `pdf` is no longer symmetrical around the peak. (This
        is returned for plotting purposes; `None` when there are too few spikes).

    See Also
    --------
    plot.feat_cutoff

    Examples
    --------
    1) Determine the fraction of spikes missing from unit 1 based on the recorded unit's spike
    amplitudes, assuming the distribution of the unit's spike amplitudes is symmetric.
        # Get unit 1 amplitudes from a unit bunch, and compute fraction spikes missing.
        >>> feat = units_b['amps']['1']
        >>> fraction_missing, pdf, cutoff_idx = bb.metrics.missed_spikes_est(feat)
    """
    feat = np.asarray(feat)

    # Ensure minimum number of spikes requirement is met, return NaN otherwise
    if feat.size <= (spks_per_bin * min_num_bins):
        return np.nan, None, None

    # Compute the spike feature histogram and smooth it into a pdf.
    num_bins = int(feat.size / spks_per_bin)
    hist, _ = np.histogram(feat, num_bins, density=True)
    pdf = gaussian_filter1d(hist, sigma)

    # Find where the distribution stops being symmetric around the peak: the point at or
    # after the peak whose density is closest to the density of the first (lowest) bin.
    peak_idx = np.argmax(pdf)
    max_idx_sym_around_peak = np.argmin(np.abs(pdf[peak_idx:] - pdf[0]))
    cutoff_idx = peak_idx + max_idx_sym_around_peak

    # The fraction missing is the share of the pdf lying beyond the cutoff, capped at 0.5.
    fraction_missing = np.sum(pdf[cutoff_idx:]) / np.sum(pdf)
    fraction_missing = min(fraction_missing, 0.5)

    return fraction_missing, pdf, cutoff_idx
def wf_similarity(wf1, wf2):
    """
    Computes a unit normalized spatiotemporal similarity score between two sets of waveforms.
    This score is based on how waveform shape correlates for each pair of spikes between the
    two sets of waveforms across space and time. The shapes of the arrays of the two sets of
    waveforms must be equal.

    Parameters
    ----------
    wf1 : ndarray
        An array of shape (#spikes, #samples, #channels).
    wf2 : ndarray
        An array of shape (#spikes, #samples, #channels).

    Returns
    -------
    s: float
        The unit normalized spatiotemporal similarity score.

    Raises
    ------
    AssertionError
        If `wf1` and `wf2` do not have the same shape.

    See Also
    --------
    io.extract_waveforms
    plot.single_unit_wf_comp

    Examples
    --------
    1) Compute the similarity between the first and last 100 waveforms for unit1, across the 20
    channels around the channel of max amplitude.
        # Get the channels around the max amp channel for the unit, two sets of timestamps for the
        # unit, and the two corresponding sets of waveforms for those two sets of timestamps.
        # Then compute `s`.
        >>> max_ch = clstrs_b['channels'][1]
        >>> if max_ch < 10:  # take only channels greater than `max_ch`.
        ...     ch = np.arange(max_ch, max_ch + 20)
        ... elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
        ...     ch = np.arange(max_ch - 20, max_ch)
        ... else:  # take `n_c_ch` around `max_ch`.
        ...     ch = np.arange(max_ch - 10, max_ch + 10)
        >>> ts1 = units_b['times']['1'][:100]
        >>> ts2 = units_b['times']['1'][-100:]
        >>> wf1 = bb.io.extract_waveforms(path_to_ephys_file, ts1, ch)
        >>> wf2 = bb.io.extract_waveforms(path_to_ephys_file, ts2, ch)
        >>> s = bb.metrics.wf_similarity(wf1, wf2)

    TODO check `s` calculation:
    take median of waveforms
    xcorr all waveforms with median, and divide by autocorr of all waveforms
    profile
    for two sets of units: xcorr(cl1, cl2) / (sqrt autocorr(cl1) * autocorr(cl2))
    """
    assert wf1.shape == wf2.shape, ('The shapes of the sets of waveforms are inconsistent ({})'
                                    '({})'.format(wf1.shape, wf2.shape))

    # Get number of spikes, samples, and channels of waveforms.
    n_spks, n_samples, n_ch = wf1.shape[:3]

    # `similarity_matrix[i, j]` holds the similarity of spike `i` of `wf1` to spike `j` of `wf2`.
    # Each row is computed at once by broadcasting one `wf1` spike against every `wf2` spike.
    similarity_matrix = np.zeros((n_spks, n_spks))
    for spk1 in range(n_spks):
        numer = wf1[spk1, :, :] * wf2
        denom = np.sqrt(wf1[spk1, :, :] ** 2 * wf2 ** 2)
        # 0 / 0 yields NaN where a sample is zero; silence that warning locally (rather than
        # installing a process-wide warnings filter) and let `np.nan_to_num` map it to 0.
        with np.errstate(invalid='ignore'):
            ratio = numer / denom
        similarity_matrix[spk1, :] = (
            np.sum(np.nan_to_num(ratio), axis=(1, 2)) / (n_samples * n_ch))

    # Return mean of similarity matrix
    return np.mean(similarity_matrix)
def firing_rate_coeff_var(ts, hist_win=0.01, fr_win=0.5, n_bins=10):
    """
    Computes the coefficient of variation of the firing rate: the ratio of the standard
    deviation to the mean.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float (optional)
        The time window (in s) to use for computing spike counts.
    fr_win : float (optional)
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.
    n_bins : int (optional)
        The number of bins in which to compute a coefficient of variation of the firing rate.

    Returns
    -------
    cv : float
        The mean coefficient of variation of the firing rate of the `n_bins` number of coefficients
        computed.
    cvs : ndarray
        The coefficients of variation of the firing for each bin of `n_bins`.
    fr : ndarray
        The instantaneous firing rate over time (in hz).

    See Also
    --------
    singlecell.firing_rate
    plot.firing_rate

    Examples
    --------
    1) Compute the coefficient of variation of the firing rate for unit 1 from the time of its
    first to last spike, and compute the coefficient of variation of the firing rate for unit 2
    from the first to second minute.
        >>> ts_1 = units_b['times']['1']
        >>> ts_2 = units_b['times']['2']
        >>> ts_2 = ts_2[(ts_2 > 60) & (ts_2 < 120)]
        >>> cv, cvs, fr = bb.metrics.firing_rate_coeff_var(ts_1)
        >>> cv_2, cvs_2, fr_2 = bb.metrics.firing_rate_coeff_var(ts_2)
    """
    # Instantaneous firing rate over the whole spike train.
    fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)

    # Cut the rate trace into `n_bins` equally sized rows; samples that do not fill a whole
    # bin at the end of the trace are left out.
    samples_per_bin = int(fr.size / n_bins)
    fr_binned = fr[:samples_per_bin * n_bins].reshape(n_bins, samples_per_bin)

    # Per-bin coefficient of variation, then its mean. NaNs are NOT replaced here: a bin in
    # which the neuron doesn't spike yields NaN, which propagates into `cv`.
    cvs = fr_binned.std(axis=1) / fr_binned.mean(axis=1)
    cv = cvs.mean()

    return cv, cvs, fr
def firing_rate_fano_factor(ts, hist_win=0.01, fr_win=0.5, n_bins=10):
    """
    Computes the fano factor of the firing rate: the ratio of the variance to the mean.
    (Almost identical to coeff. of variation)

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate.
    hist_win : float
        The time window (in s) to use for computing spike counts.
    fr_win : float
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.
    n_bins : int (optional)
        The number of bins in which to compute a fano factor of the firing rate.

    Returns
    -------
    ff : float
        The mean fano factor of the firing rate of the `n_bins` number of factors
        computed.
    ffs : ndarray
        The fano factors of the firing for each bin of `n_bins`.
    fr : ndarray
        The instantaneous firing rate over time (in hz).

    See Also
    --------
    singlecell.firing_rate
    plot.firing_rate

    Examples
    --------
    1) Compute the fano factor of the firing rate for unit 1 from the time of its
    first to last spike, and compute the fano factor of the firing rate for unit 2
    from the first to second minute.
        >>> ts_1 = units_b['times']['1']
        >>> ts_2 = units_b['times']['2']
        >>> ts_2 = ts_2[(ts_2 > 60) & (ts_2 < 120)]
        >>> ff, ffs, fr = bb.metrics.firing_rate_fano_factor(ts_1)
        >>> ff_2, ffs_2, fr_2 = bb.metrics.firing_rate_fano_factor(ts_2)
    """
    # Instantaneous firing rate over the whole spike train.
    fr = singlecell.firing_rate(ts, hist_win=hist_win, fr_win=fr_win)

    # Cut the rate trace into `n_bins` equally sized rows. This can cut off data at the end,
    # up to the last `n_bins` timesteps.
    samples_per_bin = int(fr.size / n_bins)
    fr_binned = fr[:samples_per_bin * n_bins].reshape(n_bins, samples_per_bin)

    # Per-bin fano factor, then its mean. NaNs are not replaced: they can happen if the
    # neuron doesn't spike in a bin, and they propagate into `ff`.
    ffs = fr_binned.var(axis=1) / fr_binned.mean(axis=1)
    ff = ffs.mean()

    return ff, ffs, fr
def average_drift(feat, times):
    """
    Computes the cumulative drift (normalized by the total number of spikes) of a spike feature
    array.

    Parameters
    ----------
    feat : ndarray
        The spike feature values from which to compute the drift. Usually amplitudes.
    times : ndarray
        The values by which successive feature differences are divided, one per entry of
        `feat` (presumably the spike times in s; confirm against callers).

    Returns
    -------
    cd : float
        The sum of the absolute rates of change of `feat` with respect to `times` between
        consecutive spikes, divided by the number of spikes.

    See Also
    --------
    max_drift

    Examples
    --------
    1) Get the cumulative depth drift for unit 1.
        >>> unit_idxs = np.where(spks_b['clusters'] == 1)[0]
        >>> times = spks_b['times'][unit_idxs]
        >>> depths = spks_b['depths'][unit_idxs]
        >>> amps = spks_b['amps'][unit_idxs]
        >>> depth_cd = bb.metrics.average_drift(depths, times)
        >>> amp_cd = bb.metrics.average_drift(amps, times)
    """
    # Rate of change of the feature between each pair of consecutive spikes.
    rate_of_change = np.diff(feat) / np.diff(times)
    total_drift = np.abs(rate_of_change).sum()
    return total_drift / len(feat)
def pres_ratio(ts, hist_win=10):
    """
    Computes the presence ratio of spike counts: the number of bins where there is at least one
    spike, over the total number of bins, given a specified bin width.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the presence ratio. The last entry is used
        as the end of the binned range, so it is expected to be the latest timestamp.
    hist_win : float (optional)
        The time window (in s) to use for computing the presence ratio.

    Returns
    -------
    pr : float
        The presence ratio.
    spks_bins : ndarray
        The number of spks in each bin.

    See Also
    --------
    plot.pres_ratio

    Examples
    --------
    1) Compute the presence ratio for unit 1, given a window of 10 s.
        >>> ts = units_b['times']['1']
        >>> pr, pr_bins = bb.metrics.pres_ratio(ts)
    """
    # Bin edges start at 0 and advance in steps of `hist_win` up to the last timestamp.
    edges = np.arange(0, ts[-1] + hist_win, hist_win)
    spks_bins = np.histogram(ts, edges)[0]
    n_occupied = np.count_nonzero(spks_bins)
    pr = n_occupied / len(spks_bins)
    return pr, spks_bins
def ptp_over_noise(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, car=True):
    """
    For specified channels, for specified timestamps, computes the mean (peak-to-peak amplitudes /
    the MADs of the background noise).

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    ch : ndarray_like
        The channels on which to extract the waveforms.
    t : numeric (optional)
        The time (in ms) of the waveforms to extract to compute the ptp.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    car: bool (optional)
        A flag to perform common-average-referencing before extracting waveforms.

    Returns
    -------
    ptp_sigma : ndarray
        An array containing the mean ptp_over_noise values for the specified `ts` and `ch`.

    Notes
    -----
    The MAD is computed with `scipy.stats.median_abs_deviation` when available, falling back to
    the legacy `scipy.stats.median_absolute_deviation` (removed from recent SciPy releases).
    Progress and a time estimate are printed to stdout.

    Examples
    --------
    1) Compute ptp_over_noise for all spikes on 20 channels around the channel of max amplitude
    for unit 1.
        >>> ts = units_b['times']['1']
        >>> max_ch = clstrs_b['channels'][1]
        >>> if max_ch < 10:  # take only channels greater than `max_ch`.
        ...     ch = np.arange(max_ch, max_ch + 20)
        ... elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
        ...     ch = np.arange(max_ch - 20, max_ch)
        ... else:  # take `n_c_ch` around `max_ch`.
        ...     ch = np.arange(max_ch - 10, max_ch + 10)
        >>> p = bb.metrics.ptp_over_noise(ephys_file, ts, ch)
    """
    # Ensure `ch` is ndarray
    ch = np.asarray(ch)
    ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch

    # Get waveforms (indexed below as [spike, sample, channel]).
    wf = extract_waveforms(ephys_file, ts, ch, t=t, sr=sr, n_ch_probe=n_ch_probe, car=car)

    # Initialize `mean_ptp` based on `ch`, and compute mean ptp of all spikes for each ch.
    mean_ptp = np.zeros((ch.size,))
    for cur_ch in range(ch.size):
        ch_wf = wf[:, :, cur_ch]
        mean_ptp[cur_ch] = np.mean(np.max(ch_wf, axis=1) - np.min(ch_wf, axis=1))

    # `median_absolute_deviation` no longer exists in current SciPy; prefer its replacement
    # and only fall back to the legacy name on old installations. Both are called with an
    # explicit `scale=1` below because their default scales differ.
    mad_fun = getattr(stats, 'median_abs_deviation', None) or stats.median_absolute_deviation

    # Compute MAD for `ch` in chunks.
    with spikeglx.Reader(ephys_file) as s_reader:
        file_m = s_reader.data  # the memmapped array
        n_chunk_samples = 5e6  # number of samples per chunk
        n_chunks = np.ceil(file_m.shape[0] / n_chunk_samples).astype('int')
        # Get samples that make up each chunk. e.g. `chunk_sample[1] - chunk_sample[0]` are the
        # samples that make up the first chunk.
        chunk_sample = np.arange(0, file_m.shape[0], n_chunk_samples, dtype=int)
        chunk_sample = np.append(chunk_sample, file_m.shape[0])

        # Give time estimate for computing MAD: time the first chunk (result discarded) and
        # extrapolate to all chunks.
        t0 = time.perf_counter()
        mad_fun(file_m[chunk_sample[0]:chunk_sample[1], ch], axis=0, scale=1)
        dt = time.perf_counter() - t0
        print('Performing MAD computation. Estimated time is {:.2f} mins.'
              ' ({})'.format(dt * n_chunks / 60, time.ctime()))

        # Compute MAD for each chunk, then take the median MAD of all chunks.
        # NOTE(review): int16 storage truncates fractional MAD values - confirm intended.
        mad_chunks = np.zeros((n_chunks, ch.size), dtype=np.int16)
        for chunk in range(n_chunks):
            mad_chunks[chunk, :] = mad_fun(
                file_m[chunk_sample[chunk]:chunk_sample[chunk + 1], ch], axis=0, scale=1)
        print('Done. ({})'.format(time.ctime()))

    # Return `mean_ptp` over `mad`
    mad = np.median(mad_chunks, axis=0)
    ptp_sigma = mean_ptp / mad
    return ptp_sigma
def contamination_alt(ts, rp=0.002):
    """
    An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on
    the number of spikes, number of isi violations, and time between the first and last spike.
    (see Hill et al. (2011) J Neurosci 31: 8699-8705).

    Parameters
    ----------
    ts : ndarray_like
        The timestamps (in s) of the spikes. Converted with `np.asarray`, so lists are accepted
        as well as ndarrays.
    rp : float (optional)
        The refractory period (in s).

    Returns
    -------
    ce : float
        An estimate of the fraction of contamination.

    See Also
    --------
    contamination

    Examples
    --------
    1) Compute contamination estimate for unit 1.
        >>> ts = units_b['times']['1']
        >>> ce = bb.metrics.contamination_alt(ts)
    """
    ts = np.asarray(ts)

    # Get number of spikes, number of isi violations, and time from first to final spike.
    n_spks = ts.size
    n_isi_viol = np.count_nonzero(np.diff(ts) < rp)
    t = ts[-1] - ts[0]

    # `ce` is the root of smallest magnitude of the quadratic `-x**2 + x + c = 0`.
    c = (t * n_isi_viol) / (2 * rp * n_spks ** 2)  # 3rd term in quadratic
    ce = np.min(np.abs(np.roots([-1, 1, c])))  # solve quadratic
    return ce
def contamination(ts, min_time, max_time, rp=0.002, min_isi=0.0001):
    """
    An estimate of the contamination of the unit (i.e. a pseudo false positive measure) based on
    the number of spikes, number of isi violations, and time between the first and last spike.
    (see Hill et al. (2011) J Neurosci 31: 8699-8705).

    Modified by Dan Denman from cortex-lab/sortingQuality GitHub by Nick Steinmetz.

    Parameters
    ----------
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    min_time : float
        The minimum time (in s) that a potential spike occurred.
    max_time : float
        The maximum time (in s) that a potential spike occurred.
    rp : float (optional)
        The refractory period (in s).
    min_isi : float (optional)
        The minimum interspike-interval (in s) for counting duplicate spikes.

    Returns
    -------
    ce : float
        An estimate of the contamination.
        A perfect unit has a ce = 0
        A unit with some contamination has a ce < 0.5
        A unit with lots of contamination has a ce > 1.0
    num_violations : int
        The total number of isi violations.

    See Also
    --------
    contamination_alt

    Examples
    --------
    1) Compute contamination estimate for unit 1, with a minimum isi for counting duplicate
    spikes of 0.1 ms.
        >>> ts = units_b['times']['1']
        # `min_time` and `max_time` bound the period in which spikes could have occurred.
        >>> ce, num_violations = bb.metrics.contamination(ts, min_time, max_time, min_isi=0.0001)
    """
    # Drop the later spike of every pair separated by `min_isi` or less (duplicate spikes).
    dup_idx = np.flatnonzero(np.diff(ts) <= min_isi)
    ts = np.delete(ts, dup_idx + 1)
    n_spikes = ts.size

    # Count the remaining inter-spike intervals shorter than the refractory period.
    num_violations = np.sum(np.diff(ts) < rp)

    # Ratio of the violation rate (violations per unit of violation time) to the overall
    # firing rate within [min_time, max_time].
    total_rate = n_spikes / (max_time - min_time)
    violation_time = 2 * n_spikes * (rp - min_isi)
    ce = (num_violations / violation_time) / total_rate

    return ce, num_violations
def _max_acceptable_cont(FR, RP, rec_duration, acceptableCont, thresh):
    """
    Computes the maximum acceptable number of refractory period violations; called during
    slidingRP_viol.

    The expected violation count at the acceptable limit is
    `acceptableCont * (RP * 2 * FR * rec_duration)`, and the returned value is the `thresh`
    quantile of a Poisson distribution with that mean. When that quantile is 0 while observing
    zero violations still has non-zero probability, -1 is returned instead.
    """
    expected_viol_at_limit = acceptableCont * (RP * 2 * FR * rec_duration)
    limit = stats.poisson.ppf(thresh, expected_viol_at_limit)
    if limit == 0 and stats.poisson.pmf(0, expected_viol_at_limit) > 0:
        return -1
    return limit
def slidingRP_viol(ts, bin_size=0.25, thresh=0.1, acceptThresh=0.1):
    """
    A binary metric which determines whether there is an acceptable level of refractory period
    violations by using a sliding refractory period:

    The firing rate of the neuron is used to compute a maximum acceptable number of violations
    at several candidate values of the refractory period. The unit passes as soon as its number
    of violations is at or below that maximum for any one of the candidate values. Because the
    maximum depends on the firing rate, this metric takes into account both firing rate and
    refractory period violations.

    Parameters
    ----------
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    bin_size : float
        The size (in ms) of the bins of the autocorrelogram.
    thresh : float
        The quantile of the Poisson distribution used to compute the maximum acceptable
        contamination (see _max_acceptable_cont).
    acceptThresh : float
        The fraction of contamination we are willing to accept (default value
        set to 0.1, or 10% contamination)

    Returns
    -------
    didpass : int
        0 if unit didn't pass
        1 if unit did pass

    See Also
    --------
    contamination

    Examples
    --------
    1) Compute whether a unit has too much refractory period contamination at
    any possible value of a refractory period, for a 0.25 ms bin, with a
    threshold of 10% acceptable contamination
        >>> ts = units_b['times']['1']
        >>> didpass = bb.metrics.slidingRP_viol(ts, bin_size=0.25, thresh=0.1,
                                                acceptThresh=0.1)
    """
    # candidate refractory periods (in s), picked from the autocorrelogram bins
    test_bin_idx = [5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 32, 36, 40]
    bins_s = np.arange(0, 10.25, bin_size) / 1000 + 1e-6
    test_rps = [bins_s[i] for i in test_bin_idx]

    # units without spikes, or without any elapsed time between first and last spike, fail
    if len(ts) == 0 or not ts[-1] > ts[0]:
        return 0

    rec_duration = ts[-1] - ts[0]
    n_spikes = len(ts)

    # autocorrelogram of the unit
    acg = correlograms(ts, np.zeros(n_spikes, dtype='int8'), cluster_ids=[0],
                       bin_size=bin_size / 1000, sample_rate=20000, window_size=2,
                       symmetrize=False)[0, 0]

    # cumulative spike count from lag 0 up to each of the tested bins
    viol_counts = np.cumsum(acg)[test_bin_idx]

    # the firing rate is estimated from the mean of the second half of the autocorrelogram,
    # with each bin normalised by the total spike count and the bin size
    acg_rate = acg / n_spikes / bin_size * 1000
    n_bins_total = len(acg)
    n_bins_half = int(n_bins_total / 2)
    fr = np.sum(acg_rate[n_bins_half:n_bins_total]) / n_bins_half

    # maximum allowed number of violations for each candidate refractory period
    max_allowed = np.vectorize(_max_acceptable_cont)(
        fr, test_rps, rec_duration, fr * acceptThresh, thresh)

    # pass if the observed count is within the allowance for at least one candidate
    return int(np.any(viol_counts <= max_allowed))
def noise_cutoff(amps, quantile_length=.25, n_bins=100, nc_threshold=5, percent_threshold=0.10):
    """
    A metric to determine whether a unit's amplitude distribution is cut off (at floor),
    without assuming a Gaussian distribution.

    The amplitudes are histogrammed, the mean and standard deviation of the bin counts in the
    upper part of the distribution are computed, and the count of the second non-empty bin is
    expressed as a number of standard deviations away from that mean.

    Parameters
    ----------
    amps : ndarray_like
        The amplitudes (in uV) of the spikes.
    quantile_length : float
        The size of the upper quartile of the amplitude distribution; the upper part starts at
        `2 * quantile_length` of the non-empty bins counted from the histogram peak.
    n_bins : int
        The number of bin edges, spanning 0 to the maximum amplitude, used to compute the
        histogram of the amplitude distribution.
    nc_threshold : float
        The cutoff has to be greater than this for a unit to fail.
    percent_threshold : float
        The low bin count has to be greater than this fraction of the peak bin count for a unit
        to fail.

    Returns
    -------
    nc_pass : bool
        False when both failure criteria are met, or when no cutoff could be computed.
    cutoff : float
        Number of standard deviations that the low bin count lies away from the mean count of
        the upper part of the distribution (nan if it could not be computed).
    first_low_quantile : float
        Count of the second non-empty bin of the histogram (nan if it could not be computed).

    See Also
    --------
    missed_spikes_est

    Examples
    --------
    1) Compute whether a unit's amplitude distribution is cut off
        >>> amps = spks_b['amps'][unit_idxs]
        >>> nc_pass, cutoff, _ = bb.metrics.noise_cutoff(amps, quantile_length=.25, n_bins=100)
    """
    # defaults returned whenever the cutoff cannot be computed: nan values and a failed unit
    cutoff = np.float64(np.nan)
    first_low_quantile = np.float64(np.nan)
    failed = np.bool_(True)

    if amps.size > 1:  # ensure there are amplitudes available to analyze
        counts, _ = np.histogram(amps, bins=np.linspace(0, np.max(amps), n_bins))
        peak_idx = np.argmax(counts)  # peak of the amplitude distribution

        # number of non-empty bins from the peak upwards (the last bin is left out)
        n_nonzero_above_peak = np.count_nonzero(counts[peak_idx:-1] > 0)

        # first bin of the upper part of the distribution the low bin is compared to
        upper_start = int(np.ceil(2 * quantile_length * n_nonzero_above_peak + peak_idx))
        upper_counts = counts[upper_start:]

        if len(upper_counts) > 0:  # ensure there are bins in the upper part
            # only non-empty bins contribute to the statistics
            upper_nonzero = upper_counts[upper_counts >= 1]
            upper_mean = np.mean(upper_nonzero)
            upper_std = np.std(upper_nonzero)

            if upper_std > 0:
                first_low_quantile = counts[counts != 0][1]  # take the second bin
                cutoff = (first_low_quantile - upper_mean) / upper_std
                percent_of_peak = percent_threshold * np.max(counts)
                failed = (cutoff > nc_threshold) & (first_low_quantile > percent_of_peak)

    return ~failed, cutoff, first_low_quantile
def spike_sorting_metrics(times, clusters, amps, depths, cluster_ids=None, params=METRICS_PARAMS):
    """
    Computes:
    -   cell level metrics (cf quick_unit_metrics)
    -   label the metrics according to quality thresholds
    -   estimates drift as a function of time

    :param times: vector of spike times
    :param clusters: vector of the cluster id of each spike
    :param amps: vector of spike amplitudes
    :param depths: vector of spike depths
    :param cluster_ids (optional): set of clusters (if None the output dataframe will match
     the unique set of clusters represented in spike clusters)
    :param params: dict (optional) parameters for qc computation (
     see constant at the top of the module for default values and keys)
    :return: data_frame of metrics (cluster records, columns are qc attributes)
    :return: dictionary of recording qc (keys 'time_scale' and 'drift_um')
    """
    # per-unit metrics, converted to a `DataFrame`
    unit_metrics = quick_unit_metrics(
        clusters, times, amps, depths, cluster_ids=cluster_ids, params=params)
    df_units = pd.DataFrame(unit_metrics)

    # recording-level qc: drift as a function of time
    drift, ts = electrode_drift.estimate_drift(times, amps, depths)
    return df_units, {'time_scale': ts, 'drift_um': drift}
def quick_unit_metrics(spike_clusters, spike_times, spike_amps, spike_depths,
                       params=METRICS_PARAMS, cluster_ids=None, tbounds=None):
    """
    Computes single unit metrics from only the spike times, amplitudes, and depths for a set of
    units.

    Metrics computed:
        'amp_max',
        'amp_min',
        'amp_median',
        'amp_std_dB',
        'contamination',
        'contamination_alt',
        'drift',
        'firing_rate',
        'missed_spikes_est',
        'noise_cutoff',
        'presence_ratio',
        'presence_ratio_std',
        'slidingRP_viol',
        'slidingRP_viol_forced',
        'max_confidence',
        'min_contamination',
        'n_spikes_below2',
        'spike_count',
        'label',
        'bitwise_fail'

    Parameters (see the METRICS_PARAMS constant)
    ----------
    spike_clusters : ndarray_like
        A vector of the unit ids for a set of spikes.
    spike_times : ndarray_like
        A vector of the timestamps for a set of spikes. The first and last elements are taken
        as the time bounds of the recording.
    spike_amps : ndarray_like
        A vector of the amplitudes for a set of spikes.
    spike_depths : ndarray_like
        A vector of the depths for a set of spikes.
    params : dict (optional)
        Parameters used for computing some of the metrics in the function:
            'presence_window': float
                The time window (in s) used to look for spikes when computing the presence ratio.
            'refractory_period': float
                The refractory period used when computing isi violations and the contamination
                estimate.
            'min_isi': float
                The minimum interspike-interval (in s) for counting duplicate spikes when
                computing the contamination estimate.
            'noise_cutoff': dict
                Keyword arguments forwarded to `noise_cutoff`.
            'missed_spikes_est': dict
                Keyword arguments forwarded to `missed_spikes_est`.
    cluster_ids : ndarray_like (optional)
        Cluster ids to report on. If not all clusters are represented in `spike_clusters` (ie.
        a cluster has no spike), this will ensure the output size is consistent with the input
        arrays. Defaults to the unique values of `spike_clusters`.
    tbounds : list or 2 elements array (optional)
        A time-selection to perform the metrics computation on.

    Returns
    -------
    r : bunch
        A bunch whose keys are the computed spike metrics, each an array with one entry per
        cluster id. Entries that could not be computed are left as nan.

    Notes
    -----
    This function is called by `ephysqc.unit_metrics_ks2` which is called by `spikes.ks2_to_alf`
    during alf extraction of an ephys dataset in the ibl ephys extraction pipeline.

    Examples
    --------
    1) Compute quick metrics from a ks2 output directory:
        >>> from ibllib.ephys.ephysqc import phy_model_from_ks2_path
        >>> m = phy_model_from_ks2_path(path_to_ks2_out)
        >>> cluster_ids = m.spike_clusters
        >>> ts = m.spike_times
        >>> amps = m.amplitudes
        >>> depths = m.depths
        >>> r = bb.metrics.quick_unit_metrics(cluster_ids, ts, amps, depths)
    """
    metrics_list = [
        'cluster_id',
        'amp_max',
        'amp_min',
        'amp_median',
        'amp_std_dB',
        'contamination',
        'contamination_alt',
        'drift',
        'missed_spikes_est',
        'noise_cutoff',
        'presence_ratio',
        'presence_ratio_std',
        'slidingRP_viol',
        'spike_count',
        'slidingRP_viol_forced',
        'max_confidence',
        'min_contamination',
        'n_spikes_below2'
    ]
    # explicit None / length test: the truth value of a 2-element numpy array is ambiguous and
    # would raise, although arrays are a documented input for `tbounds`
    if tbounds is not None and len(tbounds) > 0:
        ispi = between_sorted(spike_times, tbounds)
        spike_times = spike_times[ispi]
        spike_clusters = spike_clusters[ispi]
        spike_amps = spike_amps[ispi]
        spike_depths = spike_depths[ispi]

    if cluster_ids is None:
        cluster_ids = np.unique(spike_clusters)
    nclust = cluster_ids.size

    r = Bunch({k: np.full((nclust,), np.nan) for k in metrics_list})
    r['cluster_id'] = cluster_ids

    # vectorized computation of basic metrics such as presence ratio and firing rate
    tmin = spike_times[0]
    tmax = spike_times[-1]
    presence_ratio = bincount2D(spike_times, spike_clusters,
                                xbin=params['presence_window'],
                                ybin=cluster_ids, xlim=[tmin, tmax])[0]
    r.presence_ratio = np.sum(presence_ratio > 0, axis=1) / presence_ratio.shape[1]
    r.presence_ratio_std = np.std(presence_ratio, axis=1)
    r.spike_count = np.sum(presence_ratio, axis=1)
    r.firing_rate = r.spike_count / (tmax - tmin)

    # computing amplitude statistical indicators by aggregating over cluster id
    camp = pd.DataFrame(np.c_[spike_amps, 20 * np.log10(spike_amps), spike_clusters],
                        columns=['amps', 'log_amps', 'clusters'])
    camp = camp.groupby('clusters')
    # `ir` indexes the requested clusters that actually have spikes
    ir, ib = ismember(r.cluster_id, camp.clusters.unique())
    r.amp_min[ir] = np.array(camp['amps'].min())
    r.amp_max[ir] = np.array(camp['amps'].max())
    # this is the geometric median
    r.amp_median[ir] = np.array(10 ** (camp['log_amps'].median() / 20))
    r.amp_std_dB[ir] = np.array(camp['log_amps'].std())

    # sliding refractory period metrics, computed for all clusters at once
    srp = metrics.slidingRP_all(spikeTimes=spike_times, spikeClusters=spike_clusters,
                                sampleRate=30000, binSizeCorr=1 / 30000)
    r.slidingRP_viol[ir] = srp['value']
    r.slidingRP_viol_forced[ir] = srp['value_forced']
    r.max_confidence[ir] = srp['max_confidence']
    r.min_contamination[ir] = srp['min_contamination']
    r.n_spikes_below2[ir] = srp['n_spikes_below2']

    # loop over each cluster to compute the rest of the metrics
    for ic in np.arange(nclust):
        # slice the spike_times array
        ispikes = spike_clusters == cluster_ids[ic]
        if np.all(~ispikes):  # if this cluster has no spikes, continue
            continue
        ts = spike_times[ispikes]
        amps = spike_amps[ispikes]
        depths = spike_depths[ispikes]
        # compute metrics
        r.contamination_alt[ic] = contamination_alt(ts, rp=params['refractory_period'])
        r.contamination[ic], _ = contamination(
            ts, tmin, tmax, rp=params['refractory_period'], min_isi=params['min_isi'])
        _, r.noise_cutoff[ic], _ = noise_cutoff(amps, **params['noise_cutoff'])
        r.missed_spikes_est[ic], _, _ = missed_spikes_est(amps, **params['missed_spikes_est'])
        # cumulative absolute depth change scaled by 3600 / recording span
        # (presumably depth units per hour when times are in s - TODO confirm)
        # wonder if there is a need to low-cut this
        r.drift[ic] = np.sum(np.abs(np.diff(depths))) / (tmax - tmin) * 3600

    # NOTE(review): labels are computed with the default parameters of `compute_labels`, not
    # with the `params` passed to this function - confirm whether this is intended
    r.label, r.bitwise_fail = compute_labels(r, return_bitwise=True)
    return r
def compute_labels(r, params=METRICS_PARAMS, return_bitwise=False):
    """
    From a dataframe or a dictionary of unit metrics, compute a label

    Three criteria are evaluated for each unit:
    -   'max_confidence' is at least `params['RPmax_confidence']`
    -   'noise_cutoff' is below `params['noise_cutoff']['nc_threshold']`
    -   'amp_median' is above `params['med_amp_thresh_uv'] / 1e6`

    :param r: dictionary or pandas dataframe containing unit qcs; the keys 'max_confidence',
     'noise_cutoff' and 'amp_median' are read
    :param params: dict (optional) thresholds for the criteria (see constant at the top of the
     module for default values and keys)
    :param return_bitwise: if True, also returns a bit field of the failed criteria
    :return: vector of proportion of qcs passed between 0 and 1, where 1 denotes an all pass
    :return: (only if return_bitwise) uint8 vector in which bit i is set when criterion i failed
    """
    # right now the score is a value between 0 and 1 denoting the proportion of passing qcs,
    # where 1 means passing and 0 means failing
    # metrics are read by key so that plain dictionaries work as well as bunches and dataframes
    labels = np.c_[
        r['max_confidence'] >= params['RPmax_confidence'],  # this is the least significant bit
        r['noise_cutoff'] < params['noise_cutoff']['nc_threshold'],
        r['amp_median'] > params['med_amp_thresh_uv'] / 1e6,
        # add a new metric here on higher significant bits
    ]
    score = np.mean(labels, axis=1)
    if not return_bitwise:
        return score

    # The first column takes binary values 001 or 000 to represent fail or pass,
    # the second, 010 or 000, the third, 100 or 000 etc.
    # The bitwise or "sum" produces 111 if all metrics fail, or 000 if all metrics pass
    # All other permutations are also captured, i.e. 110 == 000 || 010 || 100 means
    # the second and third metrics failed and the first metric was a pass
    # a nan metric compares as False in the criteria above, so it is marked as a failure here
    n_criteria = labels.shape[1]
    failed = (~ labels.astype(bool)).astype(np.uint8)
    bitwise = np.bitwise_or.reduce(2 ** np.arange(n_criteria) * failed, axis=1)
    return score, bitwise.astype(np.uint8)