Source code for brainbox.population.decode

"""
Population functions.

Code from https://github.com/cortex-lab/phylib/blob/master/phylib/stats/ccg.py by C. Rossant.
Code for decoding by G. Meijer
Code for sigtest_pseudosessions and sigtest_linshift by B. Benson
"""

import numpy as np
import scipy as sp
import scipy.stats
import types
from itertools import groupby
from sklearn.linear_model import LinearRegression, Lasso, Ridge
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import KFold, LeaveOneOut, LeaveOneGroupOut
from sklearn.metrics import accuracy_score


def get_spike_counts_in_bins(spike_times, spike_clusters, intervals):
    """
    Return the number of spikes in a sequence of time intervals, for each neuron.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each event in `spikes`
    intervals : 2D array of shape (n_events, 2)
        the start and end times of the events

    Returns
    -------
    counts : 2D array of shape (n_neurons, n_events)
        the spike counts of all neurons for all events
        value (i, j) is the number of spikes of neuron `neurons[i]` in interval #j
    cluster_ids : 1D array
        list of cluster ids
    """
    # Check input
    assert intervals.ndim == 2
    assert intervals.shape[1] == 2
    assert np.all(np.diff(spike_times) >= 0), "Spike times need to be sorted"

    intervals_idx = np.searchsorted(spike_times, intervals)

    # For each neuron and each interval, the number of spikes in the interval.
    cluster_ids = np.unique(spike_clusters)
    n_neurons = len(cluster_ids)
    n_intervals = intervals.shape[0]
    counts = np.zeros((n_neurons, n_intervals), dtype=np.uint32)
    for j in range(n_intervals):
        i0, i1 = intervals_idx[j, :]
        # Count the number of spikes in the window, for each neuron.
        x = np.bincount(spike_clusters[i0:i1], minlength=cluster_ids.max() + 1)
        counts[:, j] = x[cluster_ids]
    return counts, cluster_ids
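
# Illustrative usage sketch (not part of the original module): a minimal example of
# counting spikes in trial windows on synthetic data. All variable names and data
# below are hypothetical, for demonstration only.
def _example_get_spike_counts_in_bins():
    rng = np.random.default_rng(0)
    spike_times = np.sort(rng.uniform(0, 100, size=5000))    # sorted spike times (s)
    spike_clusters = rng.integers(0, 20, size=5000)          # 20 hypothetical clusters
    starts = np.arange(0, 100, 2.0)                          # 50 hypothetical trial onsets
    intervals = np.column_stack((starts, starts + 0.5))      # 0.5 s counting windows
    counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, intervals)
    return counts, cluster_ids                               # counts: (n_clusters, n_trials)
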
def _index_of(arr, lookup):
    """Replace scalars in an array by their indices in a lookup table.

    Implicitly assume that:

    * All elements of arr and lookup are non-negative integers.
    * All elements of arr belong to lookup.

    This is not checked for performance reasons.

    """
    # Equivalent of np.digitize(arr, lookup) - 1, but much faster.
    # TODO: assertions to disable in production for performance reasons.
    # TODO: np.searchsorted(lookup, arr) is faster on small arrays with large
    # values
    lookup = np.asarray(lookup, dtype=np.int32)
    m = (lookup.max() if len(lookup) else 0) + 1
    tmp = np.zeros(m + 1, dtype=int)
    # Ensure that -1 values are kept.
    tmp[-1] = -1
    if len(lookup):
        tmp[lookup] = np.arange(len(lookup))
    return tmp[arr]


def _increment(arr, indices):
    """Increment some indices in a 1D vector of non-negative integers.
    Repeated indices are taken into account."""
    bbins = np.bincount(indices)
    arr[:len(bbins)] += bbins
    return arr


def _diff_shifted(arr, steps=1):
    """Return the difference between each element and the element `steps` positions later."""
    return arr[steps:] - arr[:len(arr) - steps]


def _create_correlograms_array(n_clusters, winsize_bins):
    """Allocate an empty (n_clusters, n_clusters, n_bins) correlogram array."""
    return np.zeros((n_clusters, n_clusters, winsize_bins // 2 + 1), dtype=np.int32)


def _symmetrize_correlograms(correlograms):
    """Return the symmetrized version of the CCG arrays."""
    n_clusters, _, n_bins = correlograms.shape
    assert n_clusters == _

    # We symmetrize c[i, j, 0].
    # This is necessary because the algorithm in correlograms()
    # is sensitive to the order of identical spikes.
    correlograms[..., 0] = np.maximum(
        correlograms[..., 0], correlograms[..., 0].T)

    sym = correlograms[..., 1:][..., ::-1]
    sym = np.transpose(sym, (1, 0, 2))

    return np.dstack((sym, correlograms))
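
# Illustrative sketch (not part of the original module): how _index_of maps raw cluster
# ids onto 0..n_clusters-1 indices, as used in xcorr(). The arrays below are hypothetical.
def _example_index_of():
    arr = np.array([7, 3, 7, 12])       # raw cluster ids per spike
    lookup = np.array([3, 7, 12])       # unique cluster ids
    return _index_of(arr, lookup)       # -> array([1, 0, 1, 2])
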
def xcorr(spike_times, spike_clusters, bin_size=None, window_size=None):
    """
    Compute all pairwise cross-correlograms among the clusters appearing in
    `spike_clusters`.

    Parameters
    ----------
    spike_times : array-like
        Spike times in seconds.
    spike_clusters : array-like
        Spike-cluster mapping.
    bin_size : float
        Size of the bin, in seconds.
    window_size : float
        Size of the window, in seconds.

    Returns
    -------
    correlograms : 3D array
        An `(n_clusters, n_clusters, winsize_samples)` array with all pairwise
        cross-correlograms.
    """
    assert np.all(np.diff(spike_times) >= 0), "The spike times must be increasing."
    assert spike_times.ndim == 1
    assert spike_times.shape == spike_clusters.shape

    # Find `bin_size`.
    bin_size = np.clip(bin_size, 1e-5, 1e5)  # in seconds

    # Find `winsize_bins`.
    window_size = np.clip(window_size, 1e-5, 1e5)  # in seconds
    winsize_bins = 2 * int(.5 * window_size / bin_size) + 1

    # Take the cluster order into account.
    clusters = np.unique(spike_clusters)
    n_clusters = len(clusters)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    # Shift between the two copies of the spike trains.
    shift = 1

    # At a given shift, the mask specifies which spikes have matching spikes
    # within the correlogram time window.
    mask = np.ones_like(spike_times, dtype=bool)

    correlograms = _create_correlograms_array(n_clusters, winsize_bins)

    # The loop continues as long as there is at least one spike with
    # a matching spike.
    while mask[:-shift].any():
        # Interval between spike i and spike i+shift.
        spike_diff = _diff_shifted(spike_times, shift)

        # Binarize the delays between spike i and spike i+shift.
        spike_diff_b = np.round(spike_diff / bin_size).astype(np.int64)

        # Spikes with no matching spikes are masked.
        mask[:-shift][spike_diff_b > (winsize_bins / 2)] = False

        # Cache the masked spike delays.
        m = mask[:-shift].copy()
        d = spike_diff_b[m]

        # Find the indices in the raveled correlograms array that need
        # to be incremented, taking into account the spike clusters.
        indices = np.ravel_multi_index(
            (spike_clusters_i[:-shift][m], spike_clusters_i[+shift:][m], d),
            correlograms.shape)

        # Increment the matching spikes in the correlograms array.
        _increment(correlograms.ravel(), indices)

        shift += 1

    return _symmetrize_correlograms(correlograms)
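
# Illustrative usage sketch (not part of the original module): cross-correlograms on
# synthetic spikes. The bin and window sizes below are hypothetical choices.
def _example_xcorr():
    rng = np.random.default_rng(0)
    spike_times = np.sort(rng.uniform(0, 60, size=2000))     # one minute of spiking (s)
    spike_clusters = rng.integers(0, 5, size=2000)           # 5 hypothetical clusters
    # 1 ms bins over a +/- 50 ms window
    correlograms = xcorr(spike_times, spike_clusters, bin_size=0.001, window_size=0.1)
    return correlograms                 # shape (n_clusters, n_clusters, winsize_bins)
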
def classify(population_activity, trial_labels, classifier, cross_validation=None,
             return_training=False):
    """
    Classify trial identity (e.g. stim left/right) from neural population activity.

    Parameters
    ----------
    population_activity : 2D array (trials x neurons)
        population activity of all neurons in the population for each trial.
    trial_labels : 1D or 2D array
        identities of the trials, can be any number of groups, accepts integers and strings
    classifier : scikit-learn object
        which decoder to use, for example a naive Bayes classifier with a multinomial
        likelihood:
                from sklearn.naive_bayes import MultinomialNB
                classifier = MultinomialNB()
    cross_validation : None or scikit-learn object
        which cross-validation method to use, for example 5-fold:
                from sklearn.model_selection import KFold
                cross_validation = KFold(n_splits=5)
    return_training : bool
        if set to True the classifier will also return the performance on the training set

    Returns
    -------
    accuracy : float
        accuracy of the classifier
    pred : 1D array
        predictions of the classifier
    prob : 1D array
        probability of classification
    training_accuracy : float
        accuracy of the classifier on the training set (only if return_training is True)
    """
    # Check input
    if (cross_validation is None) and (return_training is True):
        raise RuntimeError('cannot return training accuracy without cross-validation')
    if population_activity.shape[0] != trial_labels.shape[0]:
        raise ValueError('trial_labels is not the same length as the first dimension of '
                         'population_activity')

    if cross_validation is None:
        # Fit the model on all the data
        classifier.fit(population_activity, trial_labels)
        pred = classifier.predict(population_activity)
        prob = classifier.predict_proba(population_activity)
        prob = prob[:, 1]
    else:
        pred = np.empty(trial_labels.shape[0])
        prob = np.empty(trial_labels.shape[0])
        if return_training:
            pred_training = np.empty(trial_labels.shape[0])
        for train_index, test_index in cross_validation.split(population_activity):
            # Fit the model to the training data
            classifier.fit(population_activity[train_index], trial_labels[train_index])

            # Predict the held-out test data
            pred[test_index] = classifier.predict(population_activity[test_index])
            proba = classifier.predict_proba(population_activity[test_index])
            prob[test_index] = proba[:, 1]

            # Predict the training data
            if return_training:
                pred_training[train_index] = classifier.predict(population_activity[train_index])

    # Calculate accuracy
    accuracy = accuracy_score(trial_labels, pred)
    if return_training:
        training_accuracy = accuracy_score(trial_labels, pred_training)
        return accuracy, pred, prob, training_accuracy
    else:
        return accuracy, pred, prob
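
# Illustrative usage sketch (not part of the original module): decode a binary trial label
# from synthetic spike counts with the multinomial naive Bayes classifier suggested in the
# docstring and 5-fold cross-validation. All data below are hypothetical.
def _example_classify():
    from sklearn.naive_bayes import MultinomialNB
    rng = np.random.default_rng(0)
    population_activity = rng.poisson(2.0, size=(100, 30))   # 100 trials x 30 neurons
    trial_labels = np.tile([0, 1], 50)                        # alternating binary labels
    accuracy, pred, prob = classify(population_activity, trial_labels, MultinomialNB(),
                                    cross_validation=KFold(n_splits=5))
    return accuracy, pred, prob
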
def regress(population_activity, trial_targets, regularization=None, cross_validation=None,
            return_training=False):
    """
    Perform linear regression to predict a continuous variable from neural data.

    Parameters
    ----------
    population_activity : 2D array (trials x neurons)
        population activity of all neurons in the population for each trial.
    trial_targets : 1D or 2D array
        the decoding target per trial as a continuous variable
    regularization : None or string
        None = no regularization, using ordinary least squares linear regression
        'L1' = L1 regularization using Lasso
        'L2' = L2 regularization using Ridge regression
    cross_validation : None or scikit-learn object
        which cross-validation method to use, for example 5-fold:
                from sklearn.model_selection import KFold
                cross_validation = KFold(n_splits=5)
    return_training : bool
        if set to True the regression will also return the predictions on the training set

    Returns
    -------
    pred : 1D array
        array with predictions
    pred_training : 1D array
        array with predictions for the training set (only if return_training is True)
    """
    # Check input
    if (cross_validation is None) and (return_training is True):
        raise RuntimeError('cannot return training accuracy without cross-validation')
    if population_activity.shape[0] != trial_targets.shape[0]:
        raise ValueError('trial_targets is not the same length as the first dimension of '
                         'population_activity')

    # Initialize regression
    if regularization is None:
        reg = LinearRegression()
    elif regularization == 'L1':
        reg = Lasso()
    elif regularization == 'L2':
        reg = Ridge()
    else:
        raise ValueError('regularization must be None, "L1" or "L2"')

    if cross_validation is None:
        # Fit the model on all the data
        reg.fit(population_activity, trial_targets)
        pred = reg.predict(population_activity)
    else:
        pred = np.empty(trial_targets.shape[0])
        if return_training:
            pred_training = np.empty(trial_targets.shape[0])
        for train_index, test_index in cross_validation.split(population_activity):
            # Fit the model to the training data
            reg.fit(population_activity[train_index], trial_targets[train_index])

            # Predict the held-out test data
            pred[test_index] = reg.predict(population_activity[test_index])

            # Predict the training data
            if return_training:
                pred_training[train_index] = reg.predict(population_activity[train_index])

    if return_training:
        return pred, pred_training
    else:
        return pred
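
# Illustrative usage sketch (not part of the original module): predict a continuous target
# with ridge (L2) regression and 5-fold cross-validation. All data below are hypothetical.
def _example_regress():
    rng = np.random.default_rng(0)
    population_activity = rng.normal(size=(100, 30))          # 100 trials x 30 neurons
    trial_targets = population_activity[:, 0] + rng.normal(scale=0.1, size=100)
    pred, pred_training = regress(population_activity, trial_targets, regularization='L2',
                                  cross_validation=KFold(n_splits=5), return_training=True)
    return pred, pred_training
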
def lda_project(spike_times, spike_clusters, event_times, event_groups, pre_time=0,
                post_time=0.5, cross_validation='kfold', num_splits=5, prob_left=None,
                custom_validation=None):
    """
    Use linear discriminant analysis to project population vectors onto the line that best
    separates the two groups. When cross-validation is used, the LDA projection is fitted on
    the training data, after which the test data are projected onto this projection.

    Parameters
    ----------
    spike_times : 1D array
        spike times (in seconds)
    spike_clusters : 1D array
        cluster ids corresponding to each event in `spikes`
    event_times : 1D array
        times (in seconds) of the events from the two groups
    event_groups : 1D array
        group identities of the events, can be any number of groups, accepts integers and strings
    pre_time : float
        time (in seconds) before the event times to include in the population vector
    post_time : float
        time (in seconds) after the event times to include in the population vector
    cross_validation : string
        which cross-validation method to use, options are:
            'none'            No cross-validation
            'kfold'           K-fold cross-validation
            'leave-one-out'   Leave out the trial that is being decoded
            'block'           Leave out the block the to-be-decoded trial is in
            'custom'          Any custom cross-validation provided by the user
    num_splits : integer
        ** only for 'kfold' cross-validation **
        Number of splits to use for k-fold cross-validation; a value of 5 means that the
        decoder will be trained on 4/5th of the data and used to predict the remaining 1/5th.
        This process is repeated five times so that all data have been used as both training
        and test set.
    prob_left : 1D array
        ** only for 'block' cross-validation **
        the probability of the stimulus appearing on the left for each trial in event_times
    custom_validation : generator
        ** only for 'custom' cross-validation **
        a generator object with the splits to be used for cross-validation, using this format:
            (
                (split1_train_idxs, split1_test_idxs),
                (split2_train_idxs, split2_test_idxs),
                (split3_train_idxs, split3_test_idxs),
             ...)

    Returns
    -------
    lda_projection : 1D array
        the position along the LDA projection axis for the population vector of each trial
    """
    # Check input
    assert cross_validation in ['none', 'kfold', 'leave-one-out', 'block', 'custom']
    assert event_times.shape[0] == event_groups.shape[0]
    if cross_validation == 'block':
        assert event_times.shape[0] == prob_left.shape[0]
    if cross_validation == 'custom':
        assert isinstance(custom_validation, types.GeneratorType)

    # Get matrix of all neuronal responses
    times = np.column_stack(((event_times - pre_time), (event_times + post_time)))
    pop_vector, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times)
    pop_vector = pop_vector.T

    # Initialize
    lda = LinearDiscriminantAnalysis()
    lda_projection = np.zeros(event_groups.shape)

    if cross_validation == 'none':
        # Find the best LDA projection on all data and transform those data
        lda_projection = lda.fit_transform(pop_vector, event_groups)

    else:
        # Perform cross-validation
        if cross_validation == 'leave-one-out':
            cv = LeaveOneOut().split(pop_vector)
        elif cross_validation == 'kfold':
            cv = KFold(n_splits=num_splits).split(pop_vector)
        elif cross_validation == 'block':
            block_lengths = [sum(1 for i in g) for k, g in groupby(prob_left)]
            blocks = np.repeat(np.arange(len(block_lengths)), block_lengths)
            cv = LeaveOneGroupOut().split(pop_vector, groups=blocks)
        elif cross_validation == 'custom':
            cv = custom_validation

        # Loop over the splits into train and test
        for train_index, test_index in cv:
            # Find LDA projection on the training data
            lda.fit(pop_vector[train_index], [event_groups[j] for j in train_index])

            # Project the held-out test data onto the projection
            lda_projection[test_index] = lda.transform(pop_vector[test_index]).T[0]

    return lda_projection
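
# Illustrative usage sketch (not part of the original module): project synthetic population
# vectors onto an LDA axis separating two hypothetical event groups, with k-fold
# cross-validation. All data below are hypothetical.
def _example_lda_project():
    rng = np.random.default_rng(0)
    spike_times = np.sort(rng.uniform(0, 200, size=20000))    # sorted spike times (s)
    spike_clusters = rng.integers(0, 30, size=20000)          # 30 hypothetical clusters
    event_times = np.arange(5, 195, 2.0)                      # hypothetical trial onsets
    event_groups = np.tile([0, 1], event_times.shape[0])[:event_times.shape[0]]  # two groups
    projection = lda_project(spike_times, spike_clusters, event_times, event_groups,
                             pre_time=0, post_time=0.5, cross_validation='kfold', num_splits=5)
    return projection
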
def sigtest_pseudosessions(X, y, fStatMeas, genPseudo, npseuds=200):
    """
    Estimate the significance level of any statistical measure following Harris, bioRxiv, 2021
    (https://www.biorxiv.org/content/10.1101/2020.11.29.402719v2).

    fStatMeas computes a scalar statistical measure (e.g. R^2) between the data, X, and the
    decoded variable, y. Pseudosessions are generated npseuds times to create a null
    distribution of statistical measures. The significance level is reported relative to this
    null distribution.

    Parameters
    ----------
    X : 2-d array
        Data of size (elements, timetrials)
    y : 1-d array
        predicted variable of size (timetrials)
    fStatMeas : function
        takes arguments (X, y) and returns a statistical measure relating how well X decodes y
    genPseudo : function
        takes no arguments () and returns a pseudosession (same shape as y) drawn from the
        experimentally known null distribution of y
    npseuds : int
        the number of pseudosessions used to estimate the significance level

    Returns
    -------
    alpha : float
        p-value, e.g. at a significance level of b, if alpha <= b then reject the null
        hypothesis.
    statms_real : float
        the value of the statistical measure evaluated on X and y
    statms_pseuds : 1-d array
        statistical measures evaluated on pseudosessions
    """
    statms_real = fStatMeas(X, y)

    statms_pseuds = np.zeros(npseuds)
    for i in range(npseuds):
        statms_pseuds[i] = fStatMeas(X, genPseudo())

    alpha = 1 - (0.01 * sp.stats.percentileofscore(statms_pseuds, statms_real, kind='weak'))

    return alpha, statms_real, statms_pseuds
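
# Illustrative usage sketch (not part of the original module): significance of a toy
# statistical measure (a squared correlation) against pseudosessions drawn from a
# hypothetical null distribution. The measure, generator and data below are hypothetical.
def _example_sigtest_pseudosessions():
    rng = np.random.default_rng(0)
    y = rng.normal(size=200)                                   # decoded variable per trial
    X = np.vstack((y + rng.normal(scale=0.5, size=200),        # one informative element
                   rng.normal(size=200)))                      # one noise element

    def stat_measure(X, y):
        return np.corrcoef(X[0], y)[0, 1] ** 2                 # toy R^2-like measure

    def gen_pseudo():
        return rng.normal(size=200)                            # pseudosession of y

    return sigtest_pseudosessions(X, y, stat_measure, gen_pseudo, npseuds=100)
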
def sigtest_linshift(X, y, fStatMeas, D=300):
    """
    Use a provably conservative linear shift technique (Harris, Kenneth, arXiv, 2021,
    https://arxiv.org/ftp/arxiv/papers/2012/2012.06862.pdf) to estimate the significance
    level of a statistical measure.

    fStatMeas computes a scalar statistical measure (e.g. R^2) from the data matrix, X, and
    the variable, y. A central window of X and y of size D is linearly shifted to generate a
    null distribution of statistical measures. The significance level is reported relative to
    this null distribution.

    Parameters
    ----------
    X : 2-d array
        Data of size (elements, timetrials)
    y : 1-d array
        predicted variable of size (timetrials)
    fStatMeas : function
        takes arguments (X, y) and returns a scalar statistical measure of how well X decodes y
    D : int
        the window length along the center of y used to compute the statistical measure;
        must have room to shift both right and left: len(y) >= D + 2

    Returns
    -------
    alpha : float
        conservative p-value, e.g. at a significance level of b, if alpha <= b then reject
        the null hypothesis.
    statms_real : float
        the value of the statistical measure evaluated on X and y
    statms_pseuds : 1-d array
        statistical measures evaluated on shifted versions of y
    """
    assert len(y) >= D + 2

    T = len(y)
    N = int((T - D) / 2)
    shifts = np.arange(-N, N + 1)

    # compute all statistical measures
    statms_real = fStatMeas(X[:, N:T - N], y[N:T - N])

    statms_pseuds = np.zeros(len(shifts))
    for si in range(len(shifts)):
        s = shifts[si]
        statms_pseuds[si] = fStatMeas(np.copy(X[:, N:T - N]), np.copy(y[s + N:s + T - N]))

    M = np.sum(statms_pseuds >= statms_real)
    alpha = M / (N + 1)

    return alpha, statms_real, statms_pseuds
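
# Illustrative usage sketch (not part of the original module): the same toy measure tested
# against the linear-shift null; y must be at least D + 2 samples long. The measure and
# data below are hypothetical.
def _example_sigtest_linshift():
    rng = np.random.default_rng(0)
    y = rng.normal(size=400)                                   # decoded variable per trial
    X = np.vstack((y + rng.normal(scale=0.5, size=400),        # one informative element
                   rng.normal(size=400)))                      # one noise element

    def stat_measure(X, y):
        return np.corrcoef(X[0], y)[0, 1] ** 2                 # toy R^2-like measure

    return sigtest_linshift(X, y, stat_measure, D=300)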