def setUp(self):
    """
    Load the trials fixture used by the tests.

    Test data contains training data from 10 consecutive sessions from subject SWC_054. It is a
    dict of trials objects, with each key indicating a session date. By using data combinations
    from different dates, one can test each of the different training criteria a subject goes
    through in the IBL training pipeline.

    Only the '2020-08-26' session is used here. ``goCue_times`` is set to the same object as
    ``stimOn_times`` (an alias, not a copy), so reaction times can be computed relative to it.

    The test is skipped, with an explicit message, if the fixture file is not present.
    """
    pickle_file = Path(__file__).parent.joinpath('fixtures', 'trials_test.pickle')
    if not pickle_file.exists():
        # Without the fixture there is nothing to test; fail loudly-but-cleanly rather than
        # crashing later on a ``None`` lookup.
        self.skipTest(f'Test fixture not found: {pickle_file}')
    # NOTE: pickle.load can execute arbitrary code; acceptable only because this is a
    # trusted, repository-local test fixture.
    with open(pickle_file, 'rb') as f:
        trial_data = pickle.load(f)
    self.trials = trial_data['2020-08-26']
    self.trials['goCue_times'] = self.trials['stimOn_times']
def test_find_trial_ids(self):
    """Check selection, sorting, ordering and contrast filtering done by ``find_trial_ids``.

    Each expected id array is rebuilt directly from the trials fields. Comparisons use
    ``np.testing.assert_array_equal`` so that a length mismatch fails explicitly, rather than
    broadcasting or passing vacuously as ``assertTrue(np.all(a == b))`` can.
    """
    assert_array_equal = np.testing.assert_array_equal
    trials = self.trials

    # Boolean masks reused by the expectations below
    is_left = ~np.isnan(trials['contrastLeft'])
    is_right = ~np.isnan(trials['contrastRight'])
    is_correct = trials['feedbackType'] == 1
    is_incorrect = trials['feedbackType'] == -1

    # Test that default returns all trials
    ids, _ = find_trial_ids(trials)
    assert_array_equal(ids, np.arange(len(trials['probabilityLeft'])))

    # Test filtering by correct / incorrect
    ids, _ = find_trial_ids(trials, choice='correct')
    assert_array_equal(ids, np.where(is_correct)[0])
    ids, _ = find_trial_ids(trials, choice='incorrect')
    assert_array_equal(ids, np.where(is_incorrect)[0])

    # Test filtering by left / right
    ids, _ = find_trial_ids(trials, side='left')
    assert_array_equal(ids, np.where(is_left)[0])
    ids, _ = find_trial_ids(trials, side='right')
    assert_array_equal(ids, np.where(is_right)[0])

    # Test filtering by choice and side
    right_correct = np.where(is_right & is_correct)[0]
    ids, _ = find_trial_ids(trials, side='right', choice='correct')
    assert_array_equal(ids, right_correct)

    right_incorrect = np.where(is_right & is_incorrect)[0]
    ids, _ = find_trial_ids(trials, side='right', choice='incorrect')
    assert_array_equal(ids, right_incorrect)

    left_correct = np.where(is_left & is_correct)[0]
    ids, _ = find_trial_ids(trials, side='left', choice='correct')
    assert_array_equal(ids, left_correct)

    left_incorrect = np.where(is_left & is_incorrect)[0]
    ids, _ = find_trial_ids(trials, side='left', choice='incorrect')
    assert_array_equal(ids, left_incorrect)

    # Test sorting
    ids, _ = find_trial_ids(trials, sort='choice and side')
    assert_array_equal(
        ids, np.r_[left_correct, left_incorrect, right_correct, right_incorrect])

    left_by_choice = np.r_[left_correct, left_incorrect]
    ids, _ = find_trial_ids(trials, side='left', sort='choice')
    assert_array_equal(ids, left_by_choice)
    ids, _ = find_trial_ids(trials, side='left', sort='choice and side')
    assert_array_equal(ids, left_by_choice)

    ids, _ = find_trial_ids(trials, side='left', sort='side')
    assert_array_equal(ids, np.where(is_left)[0])

    # Test ordering by reaction time
    reaction_time = trials['response_times'] - trials['goCue_times']
    ids, _ = find_trial_ids(trials, order='reaction time')
    assert_array_equal(ids, np.argsort(reaction_time))

    ids, _ = find_trial_ids(trials, side='left', choice='correct', order='reaction time')
    assert_array_equal(ids, left_correct[np.argsort(reaction_time[left_correct])])

    # Test contrasts
    ids, _ = find_trial_ids(trials, contrast=[1])
    full_contrast = np.sort(np.r_[np.where(trials['contrastLeft'] == 1)[0],
                                  np.where(trials['contrastRight'] == 1)[0]])
    assert_array_equal(ids, full_contrast)

    ids, _ = find_trial_ids(trials, contrast=[0.0625, 0], side='left')
    assert_array_equal(ids, np.where(trials['contrastLeft'] <= 0.0625)[0])
def test_get_event_aligned_rasters(self):
    """Check ``get_event_aligned_raster`` with events inside/outside the spike time range.

    The raster is expected to always keep one row per event. Rows for events whose window is
    not covered by the spike times, or whose event time is NaN, are expected to be all NaN.
    """
    ts = 1 / 3000
    # Boolean indexing returns a copy, so the NaN assignments further down do not touch
    # self.trials.
    events = self.trials['stimOn_times'][self.trials['stimOn_times'] < 100]
    n_events = len(events)

    # Test for normal case where trials are within spike times
    spikes = np.arange(0, 100, ts)
    raster, _ = get_event_aligned_raster(spikes, events)
    self.assertEqual(raster.shape[0], n_events)
    self.assertFalse(np.isnan(raster).any())

    # Test for the case where first trial/s is before first spike time
    spikes = np.arange(int(events[0] + 1), 100, ts)
    raster, _ = get_event_aligned_raster(spikes, events)
    self.assertEqual(raster.shape[0], n_events)
    self.assertTrue(np.isnan(raster[0, :]).all())
    self.assertFalse(np.isnan(raster[1, :]).any())

    spikes = np.arange(int(events[4] + 1), 100, ts)
    raster, _ = get_event_aligned_raster(spikes, events)
    self.assertEqual(raster.shape[0], n_events)
    self.assertTrue(np.isnan(raster[0:5, :]).all())
    self.assertFalse(np.isnan(raster[6, :]).any())

    # Test for case where last trial/s is after last spike time
    spikes = np.arange(0, int(events[-1] - 1), ts)
    raster, _ = get_event_aligned_raster(spikes, events)
    self.assertEqual(raster.shape[0], n_events)
    self.assertTrue(np.isnan(raster[-1, :]).all())
    self.assertFalse(np.isnan(raster[-2, :]).any())

    spikes = np.arange(0, int(events[-5] - 1), ts)
    raster, _ = get_event_aligned_raster(spikes, events)
    self.assertEqual(raster.shape[0], n_events)
    self.assertTrue(np.isnan(raster[-5:, :]).all())
    self.assertFalse(np.isnan(raster[-6, :]).any())

    # Test for both before and after
    spikes = np.arange(int(events[4] + 1), int(events[-5] - 1), ts)
    raster, _ = get_event_aligned_raster(spikes, events)
    self.assertEqual(raster.shape[0], n_events)
    self.assertTrue(np.isnan(raster[0:5, :]).all())
    self.assertTrue(np.isnan(raster[-5:, :]).all())

    # Test events with NaN times: their rows are kept (the raster still has one row per
    # event) but filled with NaN. Reuses the spikes from the "before and after" case.
    events[10:12] = np.nan
    raster, _ = get_event_aligned_raster(spikes, events)
    self.assertEqual(raster.shape[0], n_events)
    self.assertTrue(np.isnan(raster[10:12, :]).all())
    self.assertFalse(np.isnan(raster[12:15, :]).any())

    events[0:2] = np.nan
    raster, _ = get_event_aligned_raster(spikes, events)
    self.assertEqual(raster.shape[0], n_events)
    self.assertTrue(np.isnan(raster[0:2, :]).all())
    self.assertTrue(np.isnan(raster[-5:, :]).all())
    # Rows NaN-ed in the previous step must remain NaN
    self.assertTrue(np.isnan(raster[10:12, :]).all())