Source code for brainbox.metrics.electrode_drift

import numpy as np

from ibldsp import smooth, utils, fourier
from iblutil.numerical import bincount2D


def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
    """
    Electrode drift for spike sorted data.

    Spikes are histogrammed along amplitude, time and depth. For every time bin, the depth
    profile is filtered in the depth-frequency domain and cross-correlated (via FFT) with the
    median across time bins of the depth spectra; the cross-correlations are summed over
    amplitude bins. The lag of the cross-correlation maximum, refined by a parabolic fit, gives
    the raw drift, which is then smoothed with a Hanning window and shifted to zero mean.

    :param spike_times: spike times (seconds)
    :param spike_amps: spike amplitudes; multiplied by 1e6 before the log10 amplitude binning
     (bin limits are in log10 uV), so presumably given in V
    :param spike_depths: spike depths (usually um); spikes with a NaN depth are ignored
    :param display: if True, also plot the drift and driftmaps before / after drift correction
    :return: drift (ntimes vector) in input units (usually um)
    :return: ts (ntimes vector) time scale in seconds: start time of each DT_SECS bin, in the
     same time frame as spike_times
    :return: [fig1, fig2] the two matplotlib figures, only returned when display is True
    """
    # binning parameters
    DT_SECS = 1  # output sampling rate of the depth estimation (seconds)
    DEPTH_BIN_UM = 2  # binning parameter for depth
    AMP_BIN_LOG10 = [1.25, 3.25]  # binning parameter for amplitudes (log10 in uV)
    N_AMP = 1  # number of amplitude bins
    NXCORR = 50  # positive and negative lag in depth samples to look for depth
    NT_SMOOTH = 9  # length of the Hanning smoothing window in samples (DT_SECS rate)

    # experimental: try the amp with a log scale
    nd = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
    tmin, tmax = (np.min(spike_times), np.max(spike_times))
    t0 = np.floor(tmin)  # start of the first time bin
    nt = int((np.ceil(tmax) - t0) / DT_SECS)

    # 3d histogram of spikes along amplitude, depths and time
    atd_hist = np.zeros((N_AMP, nt, nd), dtype=np.single)
    abins = (np.log10(spike_amps * 1e6) - AMP_BIN_LOG10[0]) / np.diff(AMP_BIN_LOG10) * N_AMP
    abins = np.minimum(np.maximum(0, np.floor(abins)), N_AMP - 1)
    depth_ok = ~np.isnan(spike_depths)
    for iamp in range(N_AMP):
        # index the histogram by the amplitude bin itself (not by an enumeration counter) so
        # empty bins stay at zero; NaN amplitude bins never compare equal and are skipped
        inds = np.where(depth_ok & (abins == iamp))[0]
        if inds.size == 0:
            continue
        a, _, _ = bincount2D(spike_depths[inds], spike_times[inds], DEPTH_BIN_UM, DT_SECS,
                             [0, nd * DEPTH_BIN_UM], [t0, np.ceil(tmax)])
        # bincount2D yields one extra bin per axis here: trim to the (nt, nd) histogram shape
        atd_hist[iamp] = a[:-1, :-1]

    # k-filter along the depth direction: fftfreq with d=DEPTH_BIN_UM is in cycles/um;
    # presumably _freq_vector builds a low-pass taper with corners at 1/16 and 1/8 cycles/um
    fdscale = np.abs(np.fft.fftfreq(nd, d=DEPTH_BIN_UM))
    lp = fourier._freq_vector(fdscale, np.array([1 / 16, 1 / 8]), typ='lp')

    # compute the depth lag by xcorr against a reference: the median over time bins of the
    # depth spectra, conjugated, so that the inverse FFT yields a circular cross-correlation
    atd_ = np.fft.fft(atd_hist, axis=-1)
    ref = np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]
    xcorr = np.real(np.fft.ifft(lp * atd_ * ref))
    xcorr = np.sum(xcorr, axis=0)  # sum over amplitude bins -> (nt, nd)
    # keep lags -NXCORR..NXCORR: negative lags wrap around to the end of the circular xcorr
    xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]
    xcorr = xcorr - np.mean(xcorr, 1)[:, np.newaxis]

    # parabolic fit to get the sub-sample position of the maximum of each row
    # (parabolic_max presumably returns (positions, values)); column NXCORR is lag 0
    raw_drift = (utils.parabolic_max(xcorr)[0] - NXCORR) * DEPTH_BIN_UM
    drift = smooth.rolling_window(raw_drift, window_len=NT_SMOOTH, window='hanning')
    drift = drift - np.mean(drift)
    # the histogram starts at floor(tmin), so the time scale must too (not at 0)
    ts = t0 + DT_SECS * np.arange(drift.size)

    if display:  # pragma: no cover
        # NOTE(review): 3840 looks like a hard-coded probe extent in um - confirm
        ylim = [- NXCORR * 2, 3840 + NXCORR * 2]
        return drift, ts, _display_drift(spike_times, spike_depths, drift, ts, ylim)
    return drift, ts


def _display_drift(spike_times, spike_depths, drift, ts, ylim):  # pragma: no cover
    """
    Plot the drift trace above driftmaps of the raw and of the drift-corrected spike depths.

    :param spike_times: spike times (seconds)
    :param spike_depths: spike depths
    :param drift: drift vector sampled at ts
    :param ts: time scale of the drift, in the time frame of spike_times
    :param ylim: depth limits of the driftmap axes
    :return: [fig1, fig2]: raw driftmap figure, drift-corrected driftmap figure
    """
    import matplotlib.pyplot as plt
    from brainbox.plot import driftmap

    corrected_depths = spike_depths - np.interp(spike_times, ts, drift)
    figs = []
    for depths in (spike_depths, corrected_depths):
        fig, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]}, sharex=True,
                                figsize=(20, 10))
        axs[0].plot(ts, drift)
        driftmap(spike_times, depths, t_bin=0.1, d_bin=5, ax=axs[1])
        axs[1].set_ylim(ylim)
        figs.append(fig)
    return figs