# brainbox.tests.test_population

from pathlib import Path
import pickle
from sklearn.naive_bayes import MultinomialNB
from sklearn.model_selection import KFold
from brainbox.population.decode import (xcorr, classify, regress, get_spike_counts_in_bins,
                                        sigtest_pseudosessions, sigtest_linshift)
import unittest
import numpy as np


def _random_data(max_cluster):
    nspikes = 10000
    spike_times = np.cumsum(np.random.exponential(scale=.025, size=nspikes))
    spike_clusters = np.random.randint(0, max_cluster, nspikes)
    return spike_times, spike_clusters


class TestPopulation(unittest.TestCase):
    def setUp(self):
        # Test data is a dictionary of spike times and clusters and event times and groups
        pickle_file = Path(__file__).parent.joinpath('fixtures', 'ephys_test.p')
        if not pickle_file.exists():
            self.test_data = None
        else:
            with open(pickle_file, 'rb') as f:
                self.test_data = pickle.load(f)
    def test_get_spike_counts_in_bins(self):
        if self.test_data is None:
            return
        spike_times = self.test_data['spike_times']
        spike_clusters = self.test_data['spike_clusters']
        event_times = self.test_data['event_times']
        times = np.column_stack(((event_times - 0.5), (event_times + 0.5)))
        counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times)
        num_clusters = np.size(np.unique(spike_clusters))
        self.assertEqual(counts.shape, (num_clusters, np.size(event_times)))
        self.assertTrue(np.size(cluster_ids) == num_clusters)
    def test_classify(self):
        if self.test_data is None:
            return
        spike_times = self.test_data['spike_times']
        spike_clusters = self.test_data['spike_clusters']
        event_times = self.test_data['event_times']
        event_groups = self.test_data['event_groups']
        clf = MultinomialNB()
        cv = KFold(n_splits=2)
        times = np.column_stack(((event_times - 0.5), (event_times + 0.5)))
        counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times)
        counts = counts.T
        accuracy, pred, prob, acc_training = classify(counts, event_groups, clf,
                                                      cross_validation=cv, return_training=True)
        self.assertTrue(accuracy == 0.2222222222222222)
        self.assertTrue(acc_training == 0.9444444444444444)
        self.assertEqual(pred.shape, event_groups.shape)
        self.assertEqual(prob.shape, event_groups.shape)
    def test_regress(self):
        if self.test_data is None:
            return
        spike_times = self.test_data['spike_times']
        spike_clusters = self.test_data['spike_clusters']
        event_times = self.test_data['event_times']
        event_groups = self.test_data['event_groups']
        cv = KFold(n_splits=2)
        times = np.column_stack(((event_times - 0.5), (event_times + 0.5)))
        counts, cluster_ids = get_spike_counts_in_bins(spike_times, spike_clusters, times)
        counts = counts.T

        # Test all regularization methods WITHOUT cross-validation
        pred = regress(counts, event_groups, cross_validation=None, return_training=False,
                       regularization=None)
        self.assertEqual(pred.shape, event_groups.shape)
        pred = regress(counts, event_groups, cross_validation=None, return_training=False,
                       regularization='L1')
        self.assertEqual(pred.shape, event_groups.shape)
        pred = regress(counts, event_groups, cross_validation=None, return_training=False,
                       regularization='L2')
        self.assertEqual(pred.shape, event_groups.shape)

        # Test all regularization methods WITH cross-validation
        pred, pred_training = regress(counts, event_groups, cross_validation=cv,
                                      return_training=True, regularization=None)
        self.assertEqual(pred.shape, event_groups.shape)
        self.assertEqual(pred_training.shape, event_groups.shape)
        pred, pred_training = regress(counts, event_groups, cross_validation=cv,
                                      return_training=True, regularization='L1')
        self.assertEqual(pred.shape, event_groups.shape)
        self.assertEqual(pred_training.shape, event_groups.shape)
        pred, pred_training = regress(counts, event_groups, cross_validation=cv,
                                      return_training=True, regularization='L2')
        self.assertEqual(pred.shape, event_groups.shape)
        self.assertEqual(pred_training.shape, event_groups.shape)
    def test_xcorr_0(self):
        # Spike times per cluster:
        # 0: 0, 10
        # 1: 10, 20
        spike_times = np.array([0, 10, 10, 20])
        spike_clusters = np.array([0, 1, 0, 1])
        bin_size = 1
        winsize_bins = 2 * 3 + 1

        c_expected = np.zeros((2, 2, 7), dtype=np.int32)
        c_expected[1, 0, 3] = 1
        c_expected[0, 1, 3] = 1

        c = xcorr(spike_times, spike_clusters, bin_size=bin_size, window_size=winsize_bins)
        self.assertTrue(np.allclose(c, c_expected))
    def test_xcorr_1(self):
        # Spike times per cluster:
        # 0: 2, 10, 12, 30
        # 1: 3, 24
        # 2: 20, 40
        spike_times = np.array([2, 3, 10, 12, 20, 24, 30, 40], dtype=np.uint64)
        spike_clusters = np.array([0, 1, 0, 0, 2, 1, 0, 2])
        bin_size = 1
        winsize_bins = 2 * 3 + 1

        c_expected = np.zeros((3, 3, 7), dtype=np.int32)
        c_expected[0, 1, 4] = 1
        c_expected[1, 0, 2] = 1
        c_expected[0, 0, 1] = 1
        c_expected[0, 0, 5] = 1

        c = xcorr(spike_times, spike_clusters, bin_size=bin_size, window_size=winsize_bins)
        self.assertTrue(np.allclose(c, c_expected))
    def test_xcorr_2(self):
        max_cluster = 10
        spike_times, spike_clusters = _random_data(max_cluster)
        bin_size, winsize_bins = .001, .05

        c = xcorr(spike_times, spike_clusters, bin_size=bin_size, window_size=winsize_bins)
        self.assertEqual(c.shape, (max_cluster, max_cluster, 51))
    def test_sigtest_pseudosessions(self):
        X = np.zeros((200, 700))
        y = np.zeros(700)

        def fStatMeas(X, y):
            return np.random.rand()

        def genPseudo():
            return np.zeros(700)

        acount = 0
        for i in range(100):
            if sigtest_pseudosessions(X, y, fStatMeas, genPseudo, npseuds=100)[0] < .1:
                acount += 1
        self.assertTrue(acount <= 50)
    def test_sigtest_linshift(self):
        X = np.zeros((200, 700))
        y = np.zeros(700)

        def fStatMeas(X, y):
            return np.random.rand()

        with self.assertRaises(AssertionError):
            sigtest_linshift(X, y, fStatMeas, D=699)

        acount = 0
        for i in range(100):
            if sigtest_linshift(X, y, fStatMeas, D=600)[0] < .1:
                acount += 1
        self.assertTrue(acount <= 50)

if __name__ == "__main__":
    np.random.seed(0)
    unittest.main(exit=False)