import numpy as np
from brainbox.metrics import electrode_drift
from brainbox.metrics.single_units import quick_unit_metrics, noise_cutoff
from iblutil.numerical import ismember
# Module-level defaults: recording length (seconds) and firing rate (spikes per second),
# matching the defaults of generate_spike_train.
REC_LEN_SECS, fr = 1000, 200
[docs]
def multiple_spike_trains(firing_rates=None, rec_len_secs=1000, cluster_ids=None,
                          amplitude_noise=20 * 1e-6):
    """
    Simulate several Poisson spike trains, merged and sorted in time.

    :param firing_rates: list or np.array of firing rates (spikes per second), one per
     cluster; a scalar is treated as a single cluster. Defaults to 10 random integer
     rates in [150, 600).
    :param rec_len_secs: recording length in seconds
    :param cluster_ids: id of each cluster, aligned with firing_rates; defaults to
     0 .. n_clusters - 1. Ids may be arbitrary (non-contiguous).
    :param amplitude_noise: standard deviation of the Gaussian noise added to each
     spike amplitude, in V
    :return: spike_times, spike_amps, spike_clusters: spike times (secs, sorted), spike
     amplitudes (V, floored at 25e-6) and int32 cluster ids, all of the same length
    """
    if firing_rates is None:
        firing_rates = np.random.randint(150, 600, 10)
    # accept lists (as documented) and scalars: `.size` and iteration need a 1D array
    firing_rates = np.atleast_1d(firing_rates)
    if cluster_ids is None:
        cluster_ids = np.arange(firing_rates.size)
    # one log-normal mean amplitude per cluster, output is in V
    ca = np.exp(np.random.normal(5.5, 0.5, firing_rates.size)) / 1e6
    # Gather per-cluster arrays and concatenate once (growing with np.r_ inside the loop
    # is quadratic). The empty seeds keep np.concatenate valid when there are no clusters.
    times = [np.empty(0)]
    owners = [np.empty(0, dtype=int)]
    for i, firing_rate in enumerate(firing_rates):
        t = generate_spike_train(firing_rate=firing_rate, rec_len_secs=rec_len_secs)
        times.append(t)
        # positional index of the owning cluster, so arbitrary ids need no re-indexing
        owners.append(np.full(t.size, i))
    st = np.concatenate(times)
    ordre = st.argsort()
    st = st[ordre]
    owner = np.concatenate(owners)[ordre]
    sc = np.asarray(cluster_ids)[owner].astype(np.int32)
    sa = np.maximum(ca[owner] + np.random.randn(st.size) * amplitude_noise, 25 * 1e-6)
    return st, sa, sc
[docs]
def generate_spike_train(firing_rate=200, rec_len_secs=1000):
    """
    Basic spike train generator following a homogeneous Poisson process.

    Inter-spike intervals are drawn from an exponential distribution with mean
    1 / firing_rate and accumulated into spike times. Times at or beyond
    rec_len_secs are discarded.

    :param firing_rate: mean firing rate (spikes per second)
    :param rec_len_secs: recording length in seconds
    :return: spike_times (secs), a sorted 1D float array with values in [0, rec_len_secs)
    :raises ValueError: from numpy when rec_len_secs * firing_rate * 1.5 truncates to a
     negative draw size
    """
    # spike times: cumulative sum of exponentially distributed inter-spike intervals
    n_draw = int(rec_len_secs * firing_rate * 1.5)
    st = np.cumsum(- np.log(np.random.rand(n_draw)) / firing_rate)
    # Drawing 1.5x the expected count almost always reaches the end of the recording,
    # but not for small expected counts (e.g. 1 Hz over 1 s draws a single interval),
    # which would silently truncate the train. Top up until the last spike reaches
    # rec_len_secs; nothing extra is drawn in the common case, so seeded results match.
    if firing_rate > 0 and rec_len_secs > 0:
        n_extra = max(n_draw // 2, 16)
        while st.size == 0 or st[-1] < rec_len_secs:
            t0 = st[-1] if st.size else 0.
            st = np.r_[st, t0 + np.cumsum(- np.log(np.random.rand(n_extra)) / firing_rate)]
    st = st[:np.searchsorted(st, rec_len_secs)]
    return st
[docs]
def test_clusters_metrics():
    """
    Check quick_unit_metrics on synthetic data where one requested cluster has no spikes.

    Four clusters (ids 0, 1, 3, 4) are simulated for 1000 s with a sinusoidal depth
    shift whose frequency scales with the cluster id. Metrics are requested for ids
    0..4 over the [100, 900] s window. The rows of the clusters that do have spikes
    are compared loosely (rtol=1.1) against the generating parameters.
    """
    np.random.seed(54)
    duration = 1000
    rates = np.array([3, 100, 80, 40])  # firing rates (spikes per second)
    populated_ids = [0, 1, 3, 4]  # id 2 is left out so one requested cluster has no spike
    times, amps, clusters = multiple_spike_trains(
        firing_rates=rates, rec_len_secs=duration, cluster_ids=populated_ids)
    # sinusoidal depth shift whose frequency is driven by the cluster id
    depths = np.sin(2 * np.pi * clusters / duration * times) * 100

    # request metrics for every id in 0..4, including the empty cluster 2
    metrics = quick_unit_metrics(clusters, times, amps, depths,
                                 cluster_ids=np.arange(5), tbounds=[100, 900])
    # boolean mask over the requested ids: True where the cluster actually has spikes
    populated, _ = ismember(np.arange(5), populated_ids)

    assert np.allclose(metrics['amp_median'][populated] / np.exp(5.5) * 1e6, 1, rtol=1.1)
    assert np.allclose(metrics['amp_std_dB'][populated] / 20 * np.log10(np.exp(0.5)), 1,
                       rtol=1.1)
    assert np.allclose(metrics['drift'][populated], np.array(populated_ids) * 100 * 4 * 3.6,
                       rtol=1.1)
    assert np.allclose(metrics['firing_rate'][populated], rates, rtol=1.1)
    assert np.allclose(metrics['cluster_id'], np.arange(5))
    # expected label: 1 - (number of set bits in bitwise_fail) / 3
    n_failed = np.unpackbits(metrics['bitwise_fail']).reshape(-1, 8).sum(axis=1)
    assert np.allclose(metrics['label'], 1 - n_failed / 3)
[docs]
def test_drift_estimate():
    """
    From spike depths, xcorrelate drift maps to find a drift estimate.

    200 cells at random depths fire for 200 s. With static depths, and spike times
    shifted by -2 s so that some are negative, the estimate must stay below 0.01.
    With a sinusoidal depth modulation injected, the estimate must track it within
    4 um, ignoring the first two estimates.
    """
    np.random.seed(42)
    ncells = 200
    cells_depth = np.random.random(ncells) * 3800 + 50
    frs = np.random.randn(ncells) * 50 + 200
    t, a, c = multiple_spike_trains(firing_rates=frs, rec_len_secs=200)
    # test negative times, no drift
    drift, ts = electrode_drift.estimate_drift(t - 2, a, cells_depth[c])
    assert np.all(np.abs(drift) < 0.01)

    # test drift recovery: sinusoid of 50 um peak amplitude, 2 cycles over the recording.
    # One helper builds both the injected shift and the expected estimate so they agree.
    peak_um, n_cycles = 50, 2
    t_max = np.max(t)

    def sinusoid(x):
        return np.sin(2 * np.pi * x / t_max * n_cycles) * peak_um

    dcor = sinusoid(t)
    drift, ts = electrode_drift.estimate_drift(t, a, cells_depth[c] + dcor, display=False)
    drift_ = sinusoid(ts)
    assert np.all(np.abs(drift - drift_)[2:] < 4)
[docs]
def test_noise_cut_off():
    """Smoke-test noise_cutoff on 400 normally distributed amplitudes (mean 4, std 1.2)."""
    np.random.seed(45)
    n_spikes, amp_std, amp_mean = 400, 1.2, 4
    amplitudes = np.random.randn(n_spikes) * amp_std + amp_mean
    result = noise_cutoff(amplitudes)
    # NOTE(review): this only checks the truthiness of the return value. If noise_cutoff
    # returns a tuple, the assertion is vacuous; confirm its return type and assert on
    # the relevant field explicitly.
    assert result