def setUp(self):
    """Load the ephys and trials fixtures shared by the tests in this case."""
    fixtures_dir = Path(__file__).parent / 'fixtures'

    # Ephys fixture: a pickled dictionary of spike times, clusters,
    # event times and groups.
    ephys_path = fixtures_dir / 'ephys_test.p'
    assert ephys_path.exists()
    with open(ephys_path, 'rb') as handle:
        self.test_data = pickle.load(handle)

    # Trials fixture: a CSV read into a pandas DataFrame of trials.
    self.test_trials = pd.read_csv(fixtures_dir / 'trials_df_test.csv')
def test_get_impostor_target(self):
    """Check shape, label validation and seeding of task.get_impostor_target.

    Builds 1000 targets labelled with the strings '3'..'14'. Each target is an
    array of shape (2, 3, int(label)) filled with int(label). An impostor
    target requested for a label is expected to keep the first two dimensions
    and to have that label's length in its last dimension.
    """
    # Labels between 3 and 14 inclusive. A local, fixed RandomState keeps the
    # fixture reproducible without touching NumPy's global random state.
    rng = np.random.RandomState(0)
    labels = np.array([str(rng.randint(12) + 3) for _ in range(1000)])
    # Targets with the same label are equal.
    targets = [np.ones((2, 3, int(label))) * int(label) for label in labels]

    for label, length in (('3', 3), ('14', 14)):
        with self.subTest(label=label):
            impostor_target = task.get_impostor_target(targets, labels, label)
            self.assertEqual(impostor_target.shape[0], 2)
            self.assertEqual(impostor_target.shape[1], 3)
            self.assertEqual(impostor_target.shape[-1], length)

    # '2' is not a valid label, so an AssertionError must be raised.
    # assertRaises fails the test if nothing is raised; a try/except
    # AssertionError around assertTrue(False) would swallow its own failure.
    with self.assertRaises(AssertionError):
        task.get_impostor_target(targets, labels, '2')

    # The same seed_idx must make the output deterministic.
    for seed_idx in range(10):
        with self.subTest(seed_idx=seed_idx):
            first = task.get_impostor_target(targets, labels, seed_idx=seed_idx)
            second = task.get_impostor_target(targets, labels, seed_idx=seed_idx)
            np.testing.assert_array_equal(first, second)