def setUp(self):
    """Load the pickled ephys fixture into ``self.test_data``.

    The fixture lives at ``fixtures/ephys_test.p`` next to this file. It is a
    dictionary of spike times, spike clusters, event times and event groups.
    If the file does not exist, ``self.test_data`` is set to ``None`` so that
    test methods can detect the missing fixture.
    """
    fixture_path = Path(__file__).parent / 'fixtures' / 'ephys_test.p'
    if fixture_path.exists():
        with fixture_path.open('rb') as handle:
            self.test_data = pickle.load(handle)
        return
    # No fixture on disk: signal its absence to the test methods.
    self.test_data = None
def test_regress(self):
    """Check ``regress`` output shapes for every regularization mode.

    Each of ``None``, ``'L1'`` and ``'L2'`` is run twice:

    * without cross-validation (``return_training=False``), where a single
      prediction array is returned;
    * with 2-fold ``KFold`` cross-validation (``return_training=True``),
      where a ``(pred, pred_training)`` pair is returned.

    Every returned array must have the same shape as ``event_groups``. The
    test is reported as skipped, rather than silently passing, when the
    pickled fixture could not be loaded.
    """
    if self.test_data is None:
        self.skipTest('ephys_test.p fixture not available')

    spike_times = self.test_data['spike_times']
    spike_clusters = self.test_data['spike_clusters']
    event_times = self.test_data['event_times']
    event_groups = self.test_data['event_groups']

    # One window per event, from event_time - 0.5 to event_time + 0.5
    # (presumably seconds -- confirm against the fixture's units).
    windows = np.column_stack((event_times - 0.5, event_times + 0.5))
    counts, _ = get_spike_counts_in_bins(spike_times, spike_clusters, windows)
    # NOTE(review): presumably turns (clusters x events) into
    # (events x clusters) so regress sees one row per event -- confirm.
    counts = counts.T

    cv = KFold(n_splits=2)
    regularizations = (None, 'L1', 'L2')

    # Without cross-validation: regress returns one prediction array.
    for regularization in regularizations:
        with self.subTest(cross_validation=None, regularization=regularization):
            pred = regress(counts, event_groups, cross_validation=None,
                           return_training=False,
                           regularization=regularization)
            self.assertEqual(pred.shape, event_groups.shape)

    # With cross-validation: regress also returns training-set predictions.
    for regularization in regularizations:
        with self.subTest(cross_validation='KFold',
                          regularization=regularization):
            pred, pred_training = regress(counts, event_groups,
                                          cross_validation=cv,
                                          return_training=True,
                                          regularization=regularization)
            self.assertEqual(pred.shape, event_groups.shape)
            self.assertEqual(pred_training.shape, event_groups.shape)