Source code for brainbox.population.cca

import numpy as np
import matplotlib.pylab as plt
from scipy.signal.windows import gaussian
from scipy.signal import convolve
from sklearn.decomposition import PCA

from iblutil.numerical import bincount2D


def _smooth(data, sd):

    n_bins = data.shape[0]
    w = n_bins - 1 if n_bins % 2 == 0 else n_bins
    window = gaussian(w, std=sd)
    for j in range(data.shape[1]):
        data[:, j] = convolve(data[:, j], window, mode='same', method='auto')
    return data


def _pca(data, n_pcs):
    pca = PCA(n_components=n_pcs)
    pca.fit(data)
    data_pc = pca.transform(data)
    return data_pc


[docs] def preprocess(data, smoothing_sd=25, n_pcs=20): """ Preprocess neural data for cca analysis with smoothing and pca :param data: array of shape (n_samples, n_features) :type data: array-like :param smoothing_sd: gaussian smoothing kernel standard deviation (ms) :type smoothing_sd: float :param n_pcs: number of pca dimensions to retain :type n_pcs: int :return: preprocessed neural data :rtype: array-like, shape (n_samples, pca_dims) """ if smoothing_sd > 0: data = _smooth(data, sd=smoothing_sd) if n_pcs > 0: data = _pca(data, n_pcs=n_pcs) return data
[docs] def split_trials(trial_ids, n_splits=5, rng_seed=0): """ Assign each trial to testing or training fold :param trial_ids: :type trial_ids: array-like :param n_splits: one split used for testing; remaining splits used for training :type n_splits: int :param rng_seed: set random state for shuffling trials :type rng_seed: int :return: list of dicts of indices with keys `train` and `test` """ from sklearn.model_selection import KFold shuffle = True if rng_seed is not None else False kf = KFold(n_splits=n_splits, random_state=rng_seed, shuffle=shuffle) kf.get_n_splits(trial_ids) idxs = [None for _ in range(n_splits)] for i, t0 in enumerate(kf.split(trial_ids)): idxs[i] = {'train': t0[0], 'test': t0[1]} return idxs
[docs] def split_timepoints(trial_ids, idxs_trial): """ Assign each time point to testing or training fold :param trial_ids: trial id for each timepoint :type trial_ids: array-like :param idxs_trial: list of dicts that define which trials are in `train` or `test` folds :type idxs_trial: list :return: list of dicts that define which time points are in `train` and `test` folds """ idxs_time = [None for _ in range(len(idxs_trial))] for i, idxs in enumerate(idxs_trial): idxs_time[i] = { dtype: np.where(np.isin(trial_ids, idxs[dtype]))[0] for dtype in idxs.keys()} return idxs_time
[docs] def fit_cca(data_0, data_1, n_cca_dims=10): """ Initialize and fit CCA sklearn object :param data_0: shape (n_samples, n_features_0) :type data_0: array-like :param data_1: shape (n_samples, n_features_1) :type data_1: array-like :param n_cca_dims: number of CCA dimensions to fit :type n_cca_dims: int :return: sklearn cca object """ from sklearn.cross_decomposition import CCA cca = CCA(n_components=n_cca_dims, max_iter=1000) cca.fit(data_0, data_1) return cca
[docs] def get_cca_projection(cca, data_0, data_1): """ Project data into CCA dimensions :param cca: :param data_0: :param data_1: :return: tuple; (data_0 projection, data_1 projection) """ x_scores, y_scores = cca.transform(data_0, data_1) return x_scores, y_scores
[docs] def get_correlations(cca, data_0, data_1): """ :param cca: :param data_0: :param data_1: :return: """ x_scores, y_scores = get_cca_projection(cca, data_0, data_1) corrs_tmp = np.corrcoef(x_scores.T, y_scores.T) corrs = np.diagonal(corrs_tmp, offset=data_0.shape[1]) return corrs
[docs] def shuffle_analysis(data_0, data_1, n_shuffles=100, **cca_kwargs): """ Perform CCA on shuffled data :param data_0: :param data_1: :param n_shuffles: :return: """ # TODO pass
[docs] def plot_correlations(corrs, errors=None, ax=None, **plot_kwargs): """ Correlation vs CCA dimension :param corrs: correlation values for the CCA dimensions :type corrs: 1-D vector :param errors: error values :type shuffled: 1-D array of size len(corrs) :param ax: axis to plot on (default None) :type ax: matplotlib axis object :return: axis if specified, or plot if axis = None """ # evaluate if np.arrays are passed assert type(corrs) is np.ndarray, "'corrs' is not a numpy array." if errors is not None: assert type(errors) is np.ndarray, "'errors' is not a numpy array." # create axis if no axis is passed if ax is None: ax = plt.gca() # get the data for the x and y axis y_data = corrs x_data = range(1, (len(corrs) + 1)) # create the plot object ax.plot(x_data, y_data, **plot_kwargs) if errors is not None: ax.fill_between(x_data, y_data - errors, y_data + errors, **plot_kwargs, alpha=0.2) # change y and x labels and ticks ax.set_xticks(x_data) ax.set_ylabel("Correlation") ax.set_xlabel("CCA dimension") return ax
[docs] def plot_pairwise_correlations(means, stderrs=None, n_dims=None, region_strs=None, **kwargs): """ Plot CCA correlations for multiple pairs of regions :param means: list of lists; means[i][j] contains the mean corrs between regions i, j :param stderrs: list of lists; stderrs[i][j] contains std errors of corrs between regions i, j :param n_dims: number of CCA dimensions to plot :param region_strs: list of strings identifying each region :param kwargs: keyword arguments for plot :return: matplotlib figure handle """ n_regions = len(means) fig, axes = plt.subplots(n_regions - 1, n_regions - 1, figsize=(12, 12)) for r in range(n_regions - 1): for c in range(n_regions - 1): axes[r, c].axis('off') # get max correlation to standardize y axes max_val = 0 for r in range(1, n_regions): for c in range(r): tmp = means[r][c] if tmp is not None: max_val = np.max([max_val, np.max(tmp)]) for r in range(1, n_regions): for c in range(r): ax = axes[r - 1, c] ax.axis('on') ax = plot_correlations(means[r][c][:n_dims], stderrs[r][c][:n_dims], ax=ax, **kwargs) ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k') if region_strs is not None: ax.text( x=0.95, y=0.95, s=str('%s-%s' % (region_strs[c], region_strs[r])), horizontalalignment='right', verticalalignment='top', transform=ax.transAxes) ax.set_ylim([-0.05, max_val + 0.05]) if not ax.is_first_col(): ax.set_ylabel('') ax.set_yticks([]) if not ax.is_last_row(): ax.set_xlabel('') ax.set_xticks([]) plt.tight_layout() plt.show() return fig
[docs] def plot_pairwise_correlations_mult( means, stderrs, colvec, n_dims=None, region_strs=None, **kwargs): """ Plot CCA correlations for multiple pairs of regions, for multiple behavioural events :param means: list of lists; means[k][i][j] contains the mean corrs between regions i, j for behavioral event k :param stderrs: list of lists; stderrs[k][i][j] contains std errors of corrs between regions i, j for behavioral event k :param colvec: color vector [must be a better way for this] :param n_dims: number of CCA dimensions to plot :param region_strs: list of strings identifying each region :param kwargs: keyword arguments for plot :return: matplotlib figure handle """ n_regions = len(means[0]) fig, axes = plt.subplots(n_regions - 1, n_regions - 1, figsize=(12, 12)) for r in range(n_regions - 1): for c in range(n_regions - 1): axes[r, c].axis('off') # get max correlation to standardize y axes max_val = 0 for b in range(len(means)): for r in range(1, n_regions): for c in range(r): tmp = means[b][r][c] if tmp is not None: max_val = np.max([max_val, np.max(tmp)]) for r in range(1, n_regions): for c in range(r): ax = axes[r - 1, c] ax.axis('on') for b in range(len(means)): plot_correlations(means[b][r][c][:n_dims], stderrs[b][r][c][:n_dims], ax=ax, color=colvec[b], **kwargs) ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k') if region_strs is not None: ax.text( x=0.95, y=0.95, s=str('%s-%s' % (region_strs[c], region_strs[r])), horizontalalignment='right', verticalalignment='top', transform=ax.transAxes) ax.set_ylim([-0.05, max_val + 0.05]) if not ax.is_first_col(): ax.set_ylabel('') ax.set_yticks([]) if not ax.is_last_row(): ax.set_xlabel('') ax.set_xticks([]) plt.tight_layout() plt.show() return fig
[docs] def bin_spikes_trials(spikes, trials, bin_size=0.01): """ Binarizes the spike times into a raster and assigns a trial number to each bin :param spikes: spikes object :type spikes: Bunch :param trials: trials object :type trials: Bunch :param bin_size: size, in s, of the bins :type bin_size: float :return: a matrix (bins, SpikeCounts), and a vector of bins size with trial ID, and a vector bins size with the time that the bins start """ binned_spikes, bin_times, _ = bincount2D(spikes['times'], spikes['clusters'], bin_size) trial_start_times = trials['intervals'][:, 0] binned_trialIDs = np.digitize(bin_times, trial_start_times) # correct, as index 0 is whatever happens before the first trial binned_trialIDs_corrected = binned_trialIDs - 1 return binned_spikes.T, binned_trialIDs_corrected, bin_times
[docs] def split_by_area(binned_spikes, cl_brainAcronyms, active_clusters, brain_areas): """ This function converts a matrix of binned spikes into a list of matrices, with the clusters grouped by brain areas :param binned_spikes: binned spike data of shape (n_bins, n_lusters) :type binned_spikes: numpy.ndarray :param cl_brainAcronyms: brain region for each cluster :type cl_brainAcronyms: pandas.core.frame.DataFrame :param brain_areas: list of brain areas to select :type brain_areas: numpy.ndarray :param active_clusters: list of clusterIDs :type active_clusters: numpy.ndarray :return: list of numpy.ndarrays of size brain_areas """ # TODO: check that this is doing what it is suppossed to!!! # TODO: check that input is as expected # # initialize list listof_bs = [] for b_area in brain_areas: # get the ids of clusters in the area cl_in_area = cl_brainAcronyms.loc[cl_brainAcronyms['brainAcronyms'] == b_area].index # get the indexes of the clusters that are in that area cl_idx_in_area = np.isin(active_clusters, cl_in_area) bs_in_area = binned_spikes[:, cl_idx_in_area] listof_bs.append(bs_in_area) return listof_bs
[docs] def get_event_bin_indexes(event_times, bin_times, window): """ Get the indexes of the bins corresponding to a specific behavioral event within a window :param event_times: time series of an event :type event_times: numpy.array :param bin_times: time series pf starting point of bins :type bin_times: numpy.array :param window: list of size 2 specifying the window in seconds [-time before, time after] :type window: numpy.array :return: array of indexes """ # TODO: check that this is doing what it is supposed to (coded during codecamp in a rush) # find bin size bin_size = bin_times[1] - bin_times[0] # find window size in bin units bin_window = int(np.ceil((window[1] - window[0]) / bin_size)) # correct event_times to the start of the window event_times_corrected = event_times - window[0] # get the indexes of the bins that are containing each event and add the window after idx_array = np.empty(shape=0) for etc in event_times_corrected: start_idx = (np.abs(bin_times - etc)).argmin() # add the window arr_to_append = np.array(range(start_idx, start_idx + bin_window)) idx_array = np.concatenate((idx_array, arr_to_append), axis=None) # remove the non-existing bins if any return idx_array.astype(int)
if __name__ == '__main__': from pathlib import Path from oneibl.one import ONE import alf.io as ioalf BIN_SIZE = 0.025 # seconds SMOOTH_SIZE = 0.025 # seconds; standard deviation of gaussian kernel PCA_DIMS = 20 CCA_DIMS = PCA_DIMS N_SPLITS = 5 RNG_SEED = 0 # get the data from flatiron subject = 'KS005' date = '2019-08-30' number = 1 one = ONE() eid = one.search(subject=subject, date=date, number=number) D = one.load(eid[0], download_only=True) session_path = Path(D.local_path[0]).parent spikes = ioalf.load_object(session_path, 'spikes') clusters = ioalf.load_object(session_path, 'clusters') # channels = ioalf.load_object(session_path, 'channels') trials = ioalf.load_object(session_path, 'trials') # bin spikes and get trial IDs associated with them binned_spikes, binned_trialIDs, _ = bin_spikes_trials(spikes, trials, bin_size=0.01) # define areas brain_areas = np.unique(clusters.brainAcronyms) brain_areas = brain_areas[1:4] # [take subset for testing] # split data by brain area # (bin_spikes_trials does not return info for innactive clusters) active_clusters = np.unique(spikes['clusters']) split_binned_spikes = split_by_area( binned_spikes, clusters.brainAcronyms, active_clusters, brain_areas) # preprocess data for i, pop in enumerate(split_binned_spikes): split_binned_spikes[i] = preprocess(pop, n_pcs=PCA_DIMS, smoothing_sd=SMOOTH_SIZE) # split trials idxs_trial = split_trials(np.unique(binned_trialIDs), n_splits=N_SPLITS, rng_seed=RNG_SEED) # get train/test indices into spike arrays idxs_time = split_timepoints(binned_trialIDs, idxs_trial) # Create empty "matrix" to store cca objects n_regions = len(brain_areas) cca_mat = [[None for _ in range(n_regions)] for _ in range(n_regions)] means_list = [[None for _ in range(n_regions)] for _ in range(n_regions)] serrs_list = [[None for _ in range(n_regions)] for _ in range(n_regions)] # For each pair of populations: for i in range(len(brain_areas)): pop_0 = split_binned_spikes[i] for j in range(len(brain_areas)): if j < i: # print progress print('Fitting CCA on regions {} / {}'.format(i, j)) pop_1 = split_binned_spikes[j] ccas = [None for _ in range(N_SPLITS)] corrs = [None for _ in range(N_SPLITS)] # for each xv fold for k, idxs in enumerate(idxs_time): ccas[k] = fit_cca( pop_0[idxs['train'], :], pop_1[idxs['train'], :], n_cca_dims=CCA_DIMS) corrs[k] = get_correlations( ccas[k], pop_0[idxs['test'], :], pop_1[idxs['test'], :]) cca_mat[i][j] = ccas[k] vals = np.stack(corrs, axis=1) means_list[i][j] = np.mean(vals, axis=1) serrs_list[i][j] = np.std(vals, axis=1) / np.sqrt(N_SPLITS) # plot matrix of correlations fig = plot_pairwise_correlations(means_list, serrs_list, n_dims=10, region_strs=brain_areas)