import numpy as np
from ibldsp import smooth, utils, fourier
from iblutil.numerical import bincount2D
[docs]
def estimate_drift(spike_times, spike_amps, spike_depths, display=False):
"""
Electrode drift for spike sorted data.
:param spike_times:
:param spike_amps:
:param spike_depths:
:param display:
:return: drift (ntimes vector) in input units (usually um)
:return: ts (ntimes vector) time scale in seconds
"""
# binning parameters
DT_SECS = 1 # output sampling rate of the depth estimation (seconds)
DEPTH_BIN_UM = 2 # binning parameter for depth
AMP_BIN_LOG10 = [1.25, 3.25] # binning parameter for amplitudes (log10 in uV)
N_AMP = 1 # number of amplitude bins
NXCORR = 50 # positive and negative lag in depth samples to look for depth
NT_SMOOTH = 9 # length of the Gaussian smoothing window in samples (DT_SECS rate)
# experimental: try the amp with a log scale
nd = int(np.ceil(np.nanmax(spike_depths) / DEPTH_BIN_UM))
tmin, tmax = (np.min(spike_times), np.max(spike_times))
nt = int((np.ceil(tmax) - np.floor(tmin)) / DT_SECS)
# 3d histogram of spikes along amplitude, depths and time
atd_hist = np.zeros((N_AMP, nt, nd), dtype=np.single)
abins = (np.log10(spike_amps * 1e6) - AMP_BIN_LOG10[0]) / np.diff(AMP_BIN_LOG10) * N_AMP
abins = np.minimum(np.maximum(0, np.floor(abins)), N_AMP - 1)
for i, abin in enumerate(np.unique(abins)):
inds = np.where(np.logical_and(abins == abin, ~np.isnan(spike_depths)))[0]
a, _, _ = bincount2D(spike_depths[inds], spike_times[inds], DEPTH_BIN_UM, DT_SECS,
[0, nd * DEPTH_BIN_UM], [np.floor(tmin), np.ceil(tmax)])
atd_hist[i] = a[:-1, :-1]
fdscale = np.abs(np.fft.fftfreq(nd, d=DEPTH_BIN_UM))
# k-filter along the depth direction
lp = fourier._freq_vector(fdscale, np.array([1 / 16, 1 / 8]), typ='lp')
# compute the depth lag by xcorr
# to experiment: LP the fft for a better tracking ?
atd_ = np.fft.fft(atd_hist, axis=-1)
# xcorrelation against reference
xcorr = np.real(np.fft.ifft(lp * atd_ * np.conj(np.median(atd_, axis=1))[:, np.newaxis, :]))
xcorr = np.sum(xcorr, axis=0)
xcorr = np.c_[xcorr[:, -NXCORR:], xcorr[:, :NXCORR + 1]]
xcorr = xcorr - np.mean(xcorr, 1)[:, np.newaxis]
# import easyqc
# easyqc.viewdata(xcorr - np.mean(xcorr, 1)[:, np.newaxis], DEPTH_BIN_UM, title='xcor')
# to experiment: parabolic fit to get max values
raw_drift = (utils.parabolic_max(xcorr)[0] - NXCORR) * DEPTH_BIN_UM
drift = smooth.rolling_window(raw_drift, window_len=NT_SMOOTH, window='hanning')
drift = drift - np.mean(drift)
ts = DT_SECS * np.arange(drift.size)
if display: # pragma: no cover
import matplotlib.pyplot as plt
from brainbox.plot import driftmap
fig1, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]},
sharex=True, figsize=(20, 10))
axs[0].plot(ts, drift)
driftmap(spike_times, spike_depths, t_bin=0.1, d_bin=5, ax=axs[1])
axs[1].set_ylim([- NXCORR * 2, 3840 + NXCORR * 2])
fig2, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [.15, .85]},
sharex=True, figsize=(20, 10))
axs[0].plot(ts, drift)
dd = np.interp(spike_times, ts, drift)
driftmap(spike_times, spike_depths - dd, t_bin=0.1, d_bin=5, ax=axs[1])
axs[1].set_ylim([- NXCORR * 2, 3840 + NXCORR * 2])
return drift, ts, [fig1, fig2]
return drift, ts