'''
Computes properties of single-cells, e.g. the autocorrelation and firing rate.
'''
import numpy as np
from scipy.signal import convolve
from scipy.signal.windows import gaussian
from iblutil.util import Bunch
from brainbox.population.decode import xcorr
[docs]
def acorr(spike_times, bin_size=None, window_size=None):
    """Compute the auto-correlogram of a single neuron.

    The auto-correlogram is obtained from the cross-correlogram routine by
    assigning every spike to the same cluster (label 0) and reading out the
    only cluster pair, (0, 0).

    :param spike_times: Spike times in seconds.
    :type spike_times: array-like
    :param bin_size: Size of the bin, in seconds.
    :type bin_size: float
    :param window_size: Size of the window, in seconds.
    :type window_size: float
    :return: an `(winsize_samples,)` array with the auto-correlogram.
    """
    # One shared label for all spikes: the "cross"-correlogram of cluster 0
    # with itself is the auto-correlogram.
    single_cluster = np.zeros_like(spike_times, dtype=np.int32)
    correlograms = xcorr(
        spike_times,
        single_cluster,
        bin_size=bin_size,
        window_size=window_size,
    )
    return correlograms[0, 0, :]
[docs]
def bin_spikes(times, align_times, pre_time=0.4, post_time=1, bin_size=0.01, weights=None):
    """
    Event aligned raster for a single cluster.

    For every alignment event, the spikes falling in the window
    `[event - n_pre * bin_size, event + n_post * bin_size)` are counted (or
    their weights summed) in consecutive bins of width `bin_size`, where
    `n_pre = ceil(pre_time / bin_size)` and `n_post = ceil(post_time / bin_size)`.

    :param times: spike times (in seconds); must be sorted in ascending order because
     the window limits are located with `np.searchsorted`
    :type times: array-like
    :param align_times: event times (in seconds) to align the raster to
    :type align_times: array-like
    :param pre_time: time (in seconds) to include before each event
    :type pre_time: float
    :param post_time: time (in seconds) to include after each event
    :type post_time: float
    :param bin_size: width (in seconds) of each bin
    :type bin_size: float
    :param weights: optional per-spike weights, same length as `times`; when given, each
     bin holds the sum of weights instead of the spike count
    :type weights: array-like or None
    :return: bins, tscale
    :rtype: bins: np.array (n_align_times, n_bins)
    :rtype: tscale: np.array (n_bins,) with the bin centres relative to the event
    """
    # Accept any array-like input (lists, tuples, ...), not only ndarrays.
    times = np.asarray(times)
    align_times = np.asarray(align_times)
    if weights is not None:
        weights = np.asarray(weights)

    n_bins_pre = int(np.ceil(pre_time / bin_size))
    n_bins_post = int(np.ceil(post_time / bin_size))
    n_bins = n_bins_pre + n_bins_post
    # Bin edges relative to the event: n_bins + 1 values.
    edges = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size

    # Only the first and last edge of every trial are needed, so avoid building the full
    # (n_trials, n_edges) matrix of absolute edge times.
    starts = align_times + edges[0]
    stops = align_times + edges[-1]
    first = np.searchsorted(times, starts)
    last = np.searchsorted(times, stops)

    bins = np.zeros(shape=(align_times.shape[0], n_bins))
    for i, (lo, hi, t0) in enumerate(zip(first, last, starts)):
        xind = np.floor((times[lo:hi] - t0) / bin_size).astype(np.int64)
        w = weights[lo:hi] if weights is not None else None
        # minlength is the number of edges; the extra trailing slot is dropped so that
        # exactly n_bins values remain.
        counts = np.bincount(xind, minlength=edges.size, weights=w)
        bins[i, :] = counts[:-1]

    # Report bin centres rather than edges.
    tscale = (edges[:-1] + edges[1:]) / 2
    return bins, tscale
[docs]
def bin_spikes2D(spike_times, spike_clusters, cluster_ids, align_times, pre_time=0.4, post_time=1,
                 bin_size=0.01, weights=None):
    """
    Event aligned raster for multiple clusters.

    For every alignment event and every requested cluster, the spikes falling in the
    window `[event - n_pre * bin_size, event + n_post * bin_size)` are counted (or their
    weights summed) in consecutive bins of width `bin_size`, where
    `n_pre = ceil(pre_time / bin_size)` and `n_post = ceil(post_time / bin_size)`.

    Spikes belonging to clusters that are not listed in `cluster_ids` are ignored, and
    row `k` of the output always corresponds to `cluster_ids[k]`, whatever the order of
    `cluster_ids`.

    :param spike_times: spike times (in seconds); must be sorted in ascending order
     because the window limits are located with `np.searchsorted`
    :type spike_times: array-like
    :param spike_clusters: cluster id of each spike, same length as `spike_times`
    :type spike_clusters: array-like
    :param cluster_ids: cluster ids to compute the raster for (expected to be unique)
    :type cluster_ids: array-like
    :param align_times: event times (in seconds) to align the raster to
    :type align_times: array-like
    :param pre_time: time (in seconds) to include before each event
    :type pre_time: float
    :param post_time: time (in seconds) to include after each event
    :type post_time: float
    :param bin_size: width (in seconds) of each bin
    :type bin_size: float
    :param weights: optional per-spike weights, same length as `spike_times`; when given,
     each bin holds the sum of weights instead of the spike count
    :type weights: array-like or None
    :return: bins, tscale
    :rtype: bins: np.array (n_align_times, n_clusters, n_bins)
    :rtype: tscale: np.array (n_bins,) with the bin centres relative to the event
    """
    spike_times = np.asarray(spike_times)
    spike_clusters = np.asarray(spike_clusters)
    cluster_ids = np.asarray(cluster_ids)
    align_times = np.asarray(align_times)
    if weights is not None:
        weights = np.asarray(weights)

    n_bins_pre = int(np.ceil(pre_time / bin_size))
    n_bins_post = int(np.ceil(post_time / bin_size))
    n_bins = n_bins_pre + n_bins_post
    # Bin edges relative to the event: n_bins + 1 values.
    edges = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size
    n_edges = edges.size
    n_clusters = cluster_ids.shape[0]

    # Absolute window limits of every trial and the matching spike index ranges.
    starts = align_times + edges[0]
    stops = align_times + edges[-1]
    first = np.searchsorted(spike_times, starts)
    last = np.searchsorted(spike_times, stops)

    bins = np.zeros(shape=(align_times.shape[0], n_clusters, n_bins))
    tscale = (edges[:-1] + edges[1:]) / 2
    if n_clusters == 0:
        return bins, tscale

    # Lookup from cluster id to its row in `cluster_ids`: search in the sorted ids, then
    # translate back to the caller's ordering through `order`.
    order = np.argsort(cluster_ids, kind='stable')
    sorted_ids = cluster_ids[order]

    for i, (lo, hi, t0) in enumerate(zip(first, last, starts)):
        clu = spike_clusters[lo:hi]
        # Clip so that ids larger than every requested id still index safely; they are
        # rejected by the equality test below.
        pos = np.minimum(np.searchsorted(sorted_ids, clu), n_clusters - 1)
        keep = sorted_ids[pos] == clu
        rows = order[pos[keep]]
        cols = np.floor((spike_times[lo:hi][keep] - t0) / bin_size).astype(np.int64)
        w = weights[lo:hi][keep] if weights is not None else None
        flat = np.ravel_multi_index((rows, cols), dims=(n_clusters, n_edges))
        counts = np.bincount(flat, minlength=n_clusters * n_edges, weights=w)
        # The trailing column (one per edge, not per bin) is dropped.
        bins[i, :, :] = counts.reshape(n_clusters, n_edges)[:, :-1]

    return bins, tscale
[docs]
def calculate_peths(
        spike_times, spike_clusters, cluster_ids, align_times, pre_time=0.2,
        post_time=0.5, bin_size=0.025, smoothing=0.025, return_fr=True):
    """
    Calculate peri-event time histograms; return the mean and standard deviation across
    alignment events for each time point and each requested cluster.

    :param spike_times: spike times (in seconds)
    :type spike_times: array-like
    :param spike_clusters: cluster ids corresponding to each event in `spikes`
    :type spike_clusters: array-like
    :param cluster_ids: subset of cluster ids for calculating peths; the outputs are
     ordered by the sorted unique ids (returned as `cscale`)
    :type cluster_ids: array-like
    :param align_times: times (in seconds) to align peths to
    :type align_times: array-like
    :param pre_time: time (in seconds) to precede align times in peth
    :type pre_time: float
    :param post_time: time (in seconds) to follow align times in peth
    :type post_time: float
    :param bin_size: width of time windows (in seconds) to bin spikes
    :type bin_size: float
    :param smoothing: standard deviation (in seconds) of Gaussian kernel for
        smoothing peths; use `smoothing=0` to skip smoothing
    :type smoothing: float
    :param return_fr: `True` to return (estimated) firing rate, `False` to return spike counts
    :type return_fr: bool
    :return: peths, binned_spikes
    :rtype: peths: Bunch({'means': peth_means, 'stds': peth_stds, 'tscale': ts, 'cscale': ids})
    :rtype: binned_spikes: np.array (n_align_times, n_clusters, n_bins) of raw (unsmoothed)
     spike counts
    """
    spike_times = np.asarray(spike_times)
    spike_clusters = np.asarray(spike_clusters)
    # Rows of every output follow the sorted unique cluster ids.
    ids = np.unique(cluster_ids)

    # initialize containers
    n_offset = 5 * int(np.ceil(smoothing / bin_size))  # get rid of boundary effects for smoothing
    n_bins_pre = int(np.ceil(pre_time / bin_size)) + n_offset
    n_bins_post = int(np.ceil(post_time / bin_size)) + n_offset
    n_bins = n_bins_pre + n_bins_post
    # Sized with the unique ids so that it always matches the boolean row mask built from
    # `ids` below, even when `cluster_ids` contains duplicates.
    binned_spikes = np.zeros(shape=(len(align_times), len(ids), n_bins))

    # build gaussian kernel if requested
    if smoothing > 0:
        # use an odd kernel length so the kernel is centred on a bin
        w = n_bins - 1 if n_bins % 2 == 0 else n_bins
        window = gaussian(w, std=smoothing / bin_size)
        window /= np.sum(window)
        binned_spikes_conv = np.copy(binned_spikes)

    # filter spikes outside of the loop
    idxs = np.bitwise_and(spike_times >= np.min(align_times) - (n_bins_pre + 1) * bin_size,
                          spike_times <= np.max(align_times) + (n_bins_post + 1) * bin_size)
    idxs = np.bitwise_and(idxs, np.isin(spike_clusters, ids))
    spike_times = spike_times[idxs]
    spike_clusters = spike_clusters[idxs]

    # compute floating tscale (bin edges relative to the alignment time)
    tscale = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size

    # bin spikes
    for i, t_0 in enumerate(align_times):
        # define bin edges
        ts = tscale + t_0
        # filter spikes
        idxs = np.bitwise_and(spike_times >= ts[0], spike_times <= ts[-1])
        i_spikes = spike_times[idxs]
        i_clusters = spike_clusters[idxs]

        # bin spikes similar to bincount2D: x = spike times, y = spike clusters
        xind = (np.floor((i_spikes - np.min(ts)) / bin_size)).astype(np.int64)
        yscale, yind = np.unique(i_clusters, return_inverse=True)
        nx, ny = ts.size, yscale.size
        ind2d = np.ravel_multi_index(np.c_[yind, xind].transpose(), dims=(ny, nx))
        r = np.bincount(ind2d, minlength=nx * ny, weights=None).reshape(ny, nx)

        # store (ts represent bin edges, so there are one fewer bins)
        bs_idxs = np.isin(ids, yscale)
        binned_spikes[i, bs_idxs, :] = r[:, :-1]

        # smooth
        if smoothing > 0:
            # rows of r are in the same (sorted) order as the True entries of bs_idxs
            for j, row in zip(np.flatnonzero(bs_idxs), r):
                binned_spikes_conv[i, j, :] = convolve(
                    row, window, mode='same', method='auto')[:-1]

    # average
    if smoothing > 0:
        binned_spikes_ = np.copy(binned_spikes_conv)
    else:
        binned_spikes_ = np.copy(binned_spikes)
    if return_fr:
        binned_spikes_ /= bin_size

    peth_means = np.mean(binned_spikes_, axis=0)
    peth_stds = np.std(binned_spikes_, axis=0)

    # drop the padding bins that were only added to absorb smoothing edge effects
    if smoothing > 0:
        peth_means = peth_means[:, n_offset:-n_offset]
        peth_stds = peth_stds[:, n_offset:-n_offset]
        binned_spikes = binned_spikes[:, :, n_offset:-n_offset]
        tscale = tscale[n_offset:-n_offset]

    # package output (tscale is converted from bin edges to bin centres)
    tscale = (tscale[:-1] + tscale[1:]) / 2
    peths = Bunch({'means': peth_means, 'stds': peth_stds, 'tscale': tscale, 'cscale': ids})
    return peths, binned_spikes
[docs]
def firing_rate(ts, hist_win=0.01, fr_win=0.5):
    '''
    Computes the instantaneous firing rate of a unit over time by computing a histogram of spike
    counts over a specified window of time, and summing this histogram over a sliding window of
    specified time over a specified period of total time.

    Parameters
    ----------
    ts : ndarray
        The spike timestamps from which to compute the firing rate. The first and last
        elements are taken as the start and end of the recording.
    hist_win : float
        The time window (in s) to use for computing spike counts.
    fr_win : float
        The time window (in s) to use as a moving slider to compute the instantaneous firing rate.

    Returns
    -------
    fr : ndarray
        The instantaneous firing rate over time (in Hz).

    Raises
    ------
    ValueError
        If `ts` is empty, or if the time spanned by `ts` is too short to hold at least one
        `hist_win` bin and one `fr_win` window, or if `fr_win` is so small relative to
        `hist_win` that the sliding window would contain no bin.

    Notes
    -----
    The number of histogram bins and the number of bins per sliding window are both
    truncated to integers, so the effective window duration can differ slightly from
    `fr_win` when the durations are not integer multiples of each other; the summed counts
    are nevertheless divided by `fr_win`.

    See Also
    --------
    metrics.firing_rate_cv
    metrics.firing_rate_fano_factor
    plot.firing_rate

    Examples
    --------
    1) Compute the firing rate of a unit from the time of its first to last spike, given the
    spike times of all units `spike_times` and their cluster labels `spike_clusters`.
        >>> import numpy as np
        >>> import brainbox as bb
        >>> unit_idxs = np.where(spike_clusters == 1)[0]
        >>> ts = spike_times[unit_idxs]
        >>> fr = bb.singlecell.firing_rate(ts)
    '''
    ts = np.asarray(ts)
    if ts.size == 0:
        raise ValueError('`ts` is empty: cannot compute a firing rate without spikes.')

    # Compute histogram of spike counts.
    t_tot = ts[-1] - ts[0]
    n_bins_hist = int(t_tot / hist_win)
    n_bins_fr = int(t_tot / fr_win)
    if n_bins_hist < 1 or n_bins_fr < 1:
        # Without this guard the failure would be an opaque ZeroDivisionError (or a
        # numpy "bins" error) further down.
        raise ValueError(
            'The time spanned by `ts` is shorter than `hist_win` or `fr_win`: '
            'cannot compute a firing rate.')
    counts = np.histogram(ts, n_bins_hist)[0]

    # Compute moving average of spike counts to get instantaneous firing rate in s.
    step_sz = len(counts) // n_bins_fr  # number of histogram bins per sliding window
    if step_sz < 1:
        raise ValueError('`fr_win` is too small relative to `hist_win`.')
    fr = np.convolve(counts, np.ones(step_sz)) / fr_win
    # Keep only positions where the sliding window fully overlaps the histogram.
    fr = fr[step_sz - 1:- step_sz]
    return fr