import numpy as np
import matplotlib.pylab as plt
from scipy.signal.windows import gaussian
from scipy.signal import convolve
from sklearn.decomposition import PCA
from iblutil.numerical import bincount2D
def _smooth(data, sd):
n_bins = data.shape[0]
w = n_bins - 1 if n_bins % 2 == 0 else n_bins
window = gaussian(w, std=sd)
for j in range(data.shape[1]):
data[:, j] = convolve(data[:, j], window, mode='same', method='auto')
return data
def _pca(data, n_pcs):
    """
    Project `data` onto its leading principal components

    :param data: array of shape (n_samples, n_features)
    :param n_pcs: number of principal components to keep
    :return: array of shape (n_samples, n_pcs) with the PCA scores of `data`
    """
    # sklearn's fit() returns the fitted estimator, so fit and transform can be chained
    return PCA(n_components=n_pcs).fit(data).transform(data)
[docs]
def preprocess(data, smoothing_sd=25, n_pcs=20):
    """
    Preprocess neural data for cca analysis with smoothing and pca

    :param data: array of shape (n_samples, n_features)
    :type data: array-like
    :param smoothing_sd: standard deviation of the gaussian smoothing kernel, in samples
        (rows of `data`, i.e. time bins), not in ms; smoothing is skipped unless > 0
    :type smoothing_sd: float
    :param n_pcs: number of pca dimensions to retain; pca is skipped unless > 0
    :type n_pcs: int
    :return: preprocessed neural data
    :rtype: array-like, shape (n_samples, n_pcs) if pca is applied, else
        (n_samples, n_features)
    """
    if smoothing_sd > 0:
        # smooth a float copy so integer counts are not truncated and `data` is not modified
        data = _smooth(np.array(data, dtype=float), sd=smoothing_sd)
    if n_pcs > 0:
        data = _pca(data, n_pcs=n_pcs)
    return data
[docs]
def split_trials(trial_ids, n_splits=5, rng_seed=0):
    """
    Assign each trial to testing or training fold

    :param trial_ids: trials to split (only their number and order are used by the split)
    :type trial_ids: array-like
    :param n_splits: one split used for testing; remaining splits used for training
    :type n_splits: int
    :param rng_seed: set random state for shuffling trials; None disables shuffling
    :type rng_seed: int or None
    :return: list of `n_splits` dicts with keys `train` and `test`; the values are positional
        indices into `trial_ids`, not the trial ids themselves
    """
    from sklearn.model_selection import KFold
    # trials are shuffled only when a seed is given; rng_seed=None gives contiguous folds
    kf = KFold(n_splits=n_splits, random_state=rng_seed, shuffle=rng_seed is not None)
    return [{'train': train, 'test': test} for train, test in kf.split(trial_ids)]
[docs]
def split_timepoints(trial_ids, idxs_trial):
    """
    Assign each time point to testing or training fold

    A time point goes into a fold when its trial id is one of the ids listed for that fold.

    :param trial_ids: trial id for each timepoint
    :type trial_ids: array-like
    :param idxs_trial: list of dicts that define which trials are in `train` or `test` folds;
        values are compared against `trial_ids`
    :type idxs_trial: list
    :return: list of dicts (same keys as the input dicts) holding the indices of the time
        points that fall in each fold
    """
    def _fold_to_timepoints(fold):
        # map every key of one fold to the positions of its trials' time points
        return {
            key: np.nonzero(np.isin(trial_ids, fold_trials))[0]
            for key, fold_trials in fold.items()}

    return [_fold_to_timepoints(fold) for fold in idxs_trial]
[docs]
def fit_cca(data_0, data_1, n_cca_dims=10):
    """
    Build a sklearn CCA model and fit it to a pair of datasets

    :param data_0: shape (n_samples, n_features_0)
    :type data_0: array-like
    :param data_1: shape (n_samples, n_features_1)
    :type data_1: array-like
    :param n_cca_dims: number of CCA dimensions to fit
    :type n_cca_dims: int
    :return: the fitted sklearn cca object
    """
    from sklearn.cross_decomposition import CCA
    model = CCA(n_components=n_cca_dims, max_iter=1000)
    # sklearn's fit() returns the estimator itself
    return model.fit(data_0, data_1)
[docs]
def get_cca_projection(cca, data_0, data_1):
    """
    Project a pair of datasets into the CCA space of a fitted model

    :param cca: fitted CCA object; only its `transform(data_0, data_1)` method is used
    :param data_0: samples for the first dataset
    :param data_1: samples for the second dataset
    :return: tuple; (data_0 projection, data_1 projection)
    """
    projected_0, projected_1 = cca.transform(data_0, data_1)
    return projected_0, projected_1
[docs]
def get_correlations(cca, data_0, data_1):
    """
    Correlation between the paired CCA projections of two datasets, per CCA dimension

    :param cca: fitted CCA object; only its `transform(data_0, data_1)` method is used
    :param data_0: samples for the first dataset, shape (n_samples, n_features_0)
    :param data_1: samples for the second dataset, shape (n_samples, n_features_1)
    :return: 1-D array of length n_cca_dims; entry k is the correlation between the k-th
        projection of data_0 and the k-th projection of data_1
    """
    x_scores, y_scores = cca.transform(data_0, data_1)
    # np.corrcoef stacks the x projections (rows 0..n_dims-1) on top of the y projections
    # (rows n_dims..2*n_dims-1), so corr(x_k, y_k) lies on the diagonal offset by the number
    # of CCA components -- which in general differs from the number of input features
    n_dims = x_scores.shape[1]
    corr_mat = np.corrcoef(x_scores.T, y_scores.T)
    return np.diagonal(corr_mat, offset=n_dims)
[docs]
def shuffle_analysis(data_0, data_1, n_shuffles=100, **cca_kwargs):
    """
    Perform CCA on shuffled data (not implemented yet)

    :param data_0: first dataset (currently unused)
    :param data_1: second dataset (currently unused)
    :param n_shuffles: number of shuffles (currently unused)
    :param cca_kwargs: keyword arguments for the CCA fit (currently unused)
    :return: None; this is a placeholder
    """
    # TODO: implement the shuffle control; the function currently does nothing
    return None
[docs]
def plot_correlations(corrs, errors=None, ax=None, **plot_kwargs):
    """
    Correlation vs CCA dimension, with an optional shaded error band

    :param corrs: correlation values for the CCA dimensions
    :type corrs: 1-D array-like
    :param errors: error values; the band spans corrs - errors to corrs + errors
    :type errors: 1-D array-like of size len(corrs), or None
    :param ax: axis to plot on (default None: the current axis, `plt.gca()`)
    :type ax: matplotlib axis object
    :param plot_kwargs: keyword arguments passed to `ax.plot` for the correlation line
    :return: the axis that was drawn on
    """
    # accept any array-like (the old `assert type(...) is np.ndarray` rejected lists and
    # disappears under `python -O`)
    corrs = np.asarray(corrs)
    if errors is not None:
        errors = np.asarray(errors)
    # create axis if no axis is passed
    if ax is None:
        ax = plt.gca()
    x_data = np.arange(1, len(corrs) + 1)
    line, = ax.plot(x_data, corrs, **plot_kwargs)
    if errors is not None:
        # shade with the line's own colour; forwarding all plot kwargs here breaks on
        # line-only properties (e.g. `marker`) and on a user-supplied `alpha`
        ax.fill_between(
            x_data, corrs - errors, corrs + errors, color=line.get_color(), alpha=0.2)
    # change y and x labels and ticks
    ax.set_xticks(x_data)
    ax.set_ylabel("Correlation")
    ax.set_xlabel("CCA dimension")
    return ax
[docs]
def plot_pairwise_correlations(means, stderrs=None, n_dims=None, region_strs=None, **kwargs):
    """
    Plot CCA correlations for multiple pairs of regions

    Only the lower triangle (means[r][c] with c < r) is plotted, on an
    (n_regions - 1) x (n_regions - 1) grid; pairs whose means are None are left blank.

    :param means: list of lists; means[i][j] contains the mean corrs between regions i, j
    :param stderrs: list of lists; stderrs[i][j] contains std errors of corrs between regions
        i, j; None (or a None entry) plots the means without an error band
    :param n_dims: number of CCA dimensions to plot (None: all)
    :param region_strs: list of strings identifying each region
    :param kwargs: keyword arguments for plot
    :return: matplotlib figure handle
    """
    n_regions = len(means)
    # squeeze=False keeps `axes` 2-D even when there is a single pair of regions
    fig, axes = plt.subplots(
        n_regions - 1, n_regions - 1, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    # lower-triangle region pairs that actually have data
    pairs = [
        (r, c) for r in range(1, n_regions) for c in range(r) if means[r][c] is not None]
    # get max correlation to standardize y axes
    max_val = 0
    for r, c in pairs:
        max_val = np.max([max_val, np.max(means[r][c])])
    last_row = n_regions - 1
    for r, c in pairs:
        ax = axes[r - 1, c]
        ax.axis('on')
        errs = None if stderrs is None else stderrs[r][c]
        if errs is not None:
            errs = errs[:n_dims]
        ax = plot_correlations(means[r][c][:n_dims], errs, ax=ax, **kwargs)
        ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k')
        if region_strs is not None:
            ax.text(
                x=0.95, y=0.95, s='%s-%s' % (region_strs[c], region_strs[r]),
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes)
        ax.set_ylim([-0.05, max_val + 0.05])
        # use the grid position: Axes.is_first_col()/is_last_row() were deprecated in
        # matplotlib 3.4 and later removed
        if c != 0:
            ax.set_ylabel('')
            ax.set_yticks([])
        if r != last_row:
            ax.set_xlabel('')
            ax.set_xticks([])
    plt.tight_layout()
    plt.show()
    return fig
[docs]
def plot_pairwise_correlations_mult(
        means, stderrs, colvec, n_dims=None, region_strs=None, **kwargs):
    """
    Plot CCA correlations for multiple pairs of regions, for multiple behavioural events

    Only the lower triangle (pairs r, c with c < r) is plotted, on an
    (n_regions - 1) x (n_regions - 1) grid; one curve per behavioural event.

    :param means: list of lists; means[k][i][j] contains the mean corrs between regions i, j for
        behavioral event k; None entries are skipped
    :param stderrs: list of lists; stderrs[k][i][j] contains std errors of corrs between
        regions i, j for behavioral event k; a None entry plots that curve without error band
    :param colvec: color for each behavioral event (colvec[k] is used for event k)
    :param n_dims: number of CCA dimensions to plot (None: all)
    :param region_strs: list of strings identifying each region
    :param kwargs: keyword arguments for plot
    :return: matplotlib figure handle
    """
    n_regions = len(means[0])
    # squeeze=False keeps `axes` 2-D even when there is a single pair of regions
    fig, axes = plt.subplots(
        n_regions - 1, n_regions - 1, figsize=(12, 12), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    # lower-triangle region pairs
    pairs = [(r, c) for r in range(1, n_regions) for c in range(r)]
    # get max correlation over all events to standardize y axes
    max_val = 0
    for event_means in means:
        for r, c in pairs:
            if event_means[r][c] is not None:
                max_val = np.max([max_val, np.max(event_means[r][c])])
    last_row = n_regions - 1
    for r, c in pairs:
        ax = axes[r - 1, c]
        ax.axis('on')
        for b, event_means in enumerate(means):
            if event_means[r][c] is None:
                continue
            errs = stderrs[b][r][c]
            if errs is not None:
                errs = errs[:n_dims]
            plot_correlations(event_means[r][c][:n_dims], errs,
                              ax=ax, color=colvec[b], **kwargs)
        ax.axhline(y=0, xmin=0.05, xmax=0.95, linestyle='--', color='k')
        if region_strs is not None:
            ax.text(
                x=0.95, y=0.95, s='%s-%s' % (region_strs[c], region_strs[r]),
                horizontalalignment='right',
                verticalalignment='top',
                transform=ax.transAxes)
        ax.set_ylim([-0.05, max_val + 0.05])
        # use the grid position: Axes.is_first_col()/is_last_row() were deprecated in
        # matplotlib 3.4 and later removed
        if c != 0:
            ax.set_ylabel('')
            ax.set_yticks([])
        if r != last_row:
            ax.set_xlabel('')
            ax.set_xticks([])
    plt.tight_layout()
    plt.show()
    return fig
[docs]
def bin_spikes_trials(spikes, trials, bin_size=0.01):
    """
    Count spikes per time bin and cluster, and assign a trial number to each bin

    :param spikes: spikes object with `times` and `clusters` entries
    :type spikes: Bunch
    :param trials: trials object with an `intervals` entry; column 0 holds trial start times
    :type trials: Bunch
    :param bin_size: size, in s, of the bins
    :type bin_size: float
    :return: a matrix (bins, SpikeCounts), a vector of bins size with the trial ID of each bin,
        and a vector of bins size with the time that the bins start. Bins before the first
        trial start get trial ID -1; any other bin gets the 0-based index of the latest trial
        start at or before it.
    """
    # NOTE(review): assumes bincount2D returns counts as (clusters, bins) -- hence the transpose
    counts, bin_starts, _ = bincount2D(spikes['times'], spikes['clusters'], bin_size)
    # np.digitize returns 0 for bins before the first trial start, so subtracting one makes
    # "before any trial" -1 and turns the rest into 0-based trial numbers
    # (assumes trial start times are sorted, as np.digitize requires monotonic bins)
    trial_of_bin = np.digitize(bin_starts, trials['intervals'][:, 0]) - 1
    return counts.T, trial_of_bin, bin_starts
[docs]
def split_by_area(binned_spikes, cl_brainAcronyms, active_clusters, brain_areas):
    """
    Split a matrix of binned spikes into one matrix per brain area

    Column k of `binned_spikes` is taken to belong to cluster `active_clusters[k]`; a column is
    assigned to an area when that cluster's index label in `cl_brainAcronyms` has the area's
    acronym in the `brainAcronyms` column.

    :param binned_spikes: binned spike data of shape (n_bins, n_clusters)
    :type binned_spikes: numpy.ndarray
    :param cl_brainAcronyms: brain region for each cluster, indexed by cluster id
    :type cl_brainAcronyms: pandas.core.frame.DataFrame
    :param active_clusters: list of clusterIDs, one per column of `binned_spikes`
    :type active_clusters: numpy.ndarray
    :param brain_areas: list of brain areas to select
    :type brain_areas: numpy.ndarray
    :return: list with one numpy.ndarray of shape (n_bins, n_clusters_in_area) per brain area
    """
    # TODO: check that input is as expected
    def _columns_in(area):
        # cluster ids (index labels) whose acronym matches this area
        ids_in_area = cl_brainAcronyms.loc[cl_brainAcronyms['brainAcronyms'] == area].index
        return binned_spikes[:, np.isin(active_clusters, ids_in_area)]

    return [_columns_in(area) for area in brain_areas]
[docs]
def get_event_bin_indexes(event_times, bin_times, window):
    """
    Get the indexes of the bins covering a window around each behavioral event

    :param event_times: time series of an event (s)
    :type event_times: numpy.array
    :param bin_times: time series of starting point of bins (s); at least two bins, and the
        spacing of the first two is taken as the bin size for all of them
    :type bin_times: numpy.array
    :param window: list of size 2 specifying the window in seconds relative to each event,
        [-time before, time after], e.g. [-0.2, 0.5]
    :type window: numpy.array
    :return: 1-D int array of bin indexes, concatenated over events in event order (windows
        of nearby events may overlap and repeat indexes); indexes past the last bin are dropped
    """
    bin_times = np.asarray(bin_times)
    # find bin size
    bin_size = bin_times[1] - bin_times[0]
    # find window size in bin units
    bin_window = int(np.ceil((window[1] - window[0]) / bin_size))
    # start of each event's window; window[0] is negative for time before the event, so it is
    # added (subtracting it would start the window after the event)
    window_starts = np.asarray(event_times, dtype=float) + window[0]
    # for each window start, take the nearest bin and the bin_window bins from there on;
    # collect chunks and concatenate once instead of growing an array inside the loop
    idx_chunks = []
    for t_start in window_starts:
        start_idx = np.abs(bin_times - t_start).argmin()
        idx_chunks.append(np.arange(start_idx, start_idx + bin_window))
    if not idx_chunks:
        return np.empty(0, dtype=int)
    idx_array = np.concatenate(idx_chunks).astype(int)
    # remove the non-existing bins (windows running past the last bin)
    return idx_array[idx_array < bin_times.size]
if __name__ == '__main__':
    from pathlib import Path
    from oneibl.one import ONE
    import alf.io as ioalf

    BIN_SIZE = 0.025  # seconds
    SMOOTH_SIZE = 0.025  # seconds; standard deviation of gaussian kernel
    PCA_DIMS = 20
    CCA_DIMS = PCA_DIMS
    N_SPLITS = 5
    RNG_SEED = 0

    # get the data from flatiron
    subject = 'KS005'
    date = '2019-08-30'
    number = 1

    one = ONE()
    eid = one.search(subject=subject, date=date, number=number)
    D = one.load(eid[0], download_only=True)
    session_path = Path(D.local_path[0]).parent

    spikes = ioalf.load_object(session_path, 'spikes')
    clusters = ioalf.load_object(session_path, 'clusters')
    trials = ioalf.load_object(session_path, 'trials')

    # bin spikes (with BIN_SIZE) and get trial IDs associated with them
    binned_spikes, binned_trialIDs, _ = bin_spikes_trials(spikes, trials, bin_size=BIN_SIZE)

    # define areas
    brain_areas = np.unique(clusters.brainAcronyms)
    brain_areas = brain_areas[1:4]  # [take subset for testing]

    # split data by brain area
    # (bin_spikes_trials does not return info for inactive clusters)
    active_clusters = np.unique(spikes['clusters'])
    split_binned_spikes = split_by_area(
        binned_spikes, clusters.brainAcronyms, active_clusters, brain_areas)

    # preprocess data; the smoothing kernel std is applied in bins, so convert from seconds
    for i, pop in enumerate(split_binned_spikes):
        split_binned_spikes[i] = preprocess(
            pop, n_pcs=PCA_DIMS, smoothing_sd=SMOOTH_SIZE / BIN_SIZE)

    # split trials; bins before the first trial carry trial ID -1 and belong to no trial
    trial_ids = np.unique(binned_trialIDs)
    trial_ids = trial_ids[trial_ids >= 0]
    # split_trials returns positional indices into `trial_ids`; convert them to trial IDs,
    # which is what split_timepoints matches the per-bin trial IDs against
    idxs_trial = [
        {fold: trial_ids[idx] for fold, idx in split.items()}
        for split in split_trials(trial_ids, n_splits=N_SPLITS, rng_seed=RNG_SEED)]

    # get train/test indices into spike arrays
    idxs_time = split_timepoints(binned_trialIDs, idxs_trial)

    # Create empty "matrix" to store cca objects
    n_regions = len(brain_areas)
    cca_mat = [[None] * n_regions for _ in range(n_regions)]
    means_list = [[None] * n_regions for _ in range(n_regions)]
    serrs_list = [[None] * n_regions for _ in range(n_regions)]

    # For each pair of populations (lower triangle, j < i):
    for i in range(n_regions):
        pop_0 = split_binned_spikes[i]
        for j in range(i):
            # print progress
            print('Fitting CCA on regions {} / {}'.format(i, j))
            pop_1 = split_binned_spikes[j]
            ccas = []
            corrs = []
            # for each xv fold: fit on train bins, evaluate correlations on test bins
            for idxs in idxs_time:
                train, test = idxs['train'], idxs['test']
                cca = fit_cca(pop_0[train, :], pop_1[train, :], n_cca_dims=CCA_DIMS)
                ccas.append(cca)
                corrs.append(get_correlations(cca, pop_0[test, :], pop_1[test, :]))
            # keep the model from the last fold
            cca_mat[i][j] = ccas[-1]
            vals = np.stack(corrs, axis=1)
            means_list[i][j] = np.mean(vals, axis=1)
            serrs_list[i][j] = np.std(vals, axis=1) / np.sqrt(N_SPLITS)

    # plot matrix of correlations
    fig = plot_pairwise_correlations(means_list, serrs_list, n_dims=10, region_strs=brain_areas)