Source code for brainbox.tests.test_trials

import unittest
import pickle
from pathlib import Path
from brainbox.task.trials import find_trial_ids, get_event_aligned_raster
import numpy as np


[docs] class TestTrials(unittest.TestCase):
[docs] def setUp(self): """ Test data contains training data from 10 consecutive sessions from subject SWC_054. It is a dict of trials objects with each key indication a session date. By using data combinations from different dates can test each of the different training criterion a subject goes through in the IBL training pipeline """ pickle_file = Path(__file__).parent.joinpath('fixtures', 'trials_test.pickle') if not pickle_file.exists(): trial_data = None else: with open(pickle_file, 'rb') as f: trial_data = pickle.load(f) self.trials = trial_data['2020-08-26'] self.trials['goCue_times'] = self.trials['stimOn_times']
[docs] def test_find_trial_ids(self): # Test that default returns all trials ids, div = find_trial_ids(self.trials) expected_ids = np.arange(len(self.trials['probabilityLeft'])) self.assertTrue(np.all(ids == expected_ids)) # Test filtering by correct ids, div = find_trial_ids(self.trials, choice='correct') expected_ids = np.where(self.trials['feedbackType'] == 1)[0] self.assertTrue(np.all(ids == expected_ids)) # Test filtering by incorrect ids, div = find_trial_ids(self.trials, choice='incorrect') expected_ids = np.where(self.trials['feedbackType'] == -1)[0] self.assertTrue(np.all(ids == expected_ids)) # Test filtering by left ids, div = find_trial_ids(self.trials, side='left') expected_ids = np.where(~np.isnan(self.trials['contrastLeft']))[0] self.assertTrue(np.all(ids == expected_ids)) # Test filtering by right ids, div = find_trial_ids(self.trials, side='right') expected_ids = np.where(~np.isnan(self.trials['contrastRight']))[0] self.assertTrue(np.all(ids == expected_ids)) # Test filtering by choice and side # right correct ids, div = find_trial_ids(self.trials, side='right', choice='correct') righ_corr_expected_ids = np.where(np.bitwise_and(~np.isnan(self.trials['contrastRight']), self.trials['feedbackType'] == 1))[0] self.assertTrue(np.all(ids == righ_corr_expected_ids)) # right incorrect ids, div = find_trial_ids(self.trials, side='right', choice='incorrect') righ_incor_expected_ids = np.where(np.bitwise_and(~np.isnan(self.trials['contrastRight']), self.trials['feedbackType'] == -1))[0] self.assertTrue(np.all(ids == righ_incor_expected_ids)) # left correct ids, div = find_trial_ids(self.trials, side='left', choice='correct') left_corr_expected_ids = np.where(np.bitwise_and(~np.isnan(self.trials['contrastLeft']), self.trials['feedbackType'] == 1))[0] self.assertTrue(np.all(ids == left_corr_expected_ids)) # left incorrect ids, div = find_trial_ids(self.trials, side='left', choice='incorrect') left_incorr_expected_ids = np.where(np.bitwise_and(~np.isnan(self.trials['contrastLeft']), self.trials['feedbackType'] == -1))[0] self.assertTrue(np.all(ids == left_incorr_expected_ids)) # Test sorting ids, div = find_trial_ids(self.trials, sort='choice and side') expected_ids = np.r_[left_corr_expected_ids, left_incorr_expected_ids, righ_corr_expected_ids, righ_incor_expected_ids] self.assertTrue(np.all(ids == expected_ids)) ids, div = find_trial_ids(self.trials, side='left', sort='choice') expected_ids = np.r_[left_corr_expected_ids, left_incorr_expected_ids] self.assertTrue(np.all(ids == expected_ids)) ids, div = find_trial_ids(self.trials, side='left', sort='choice and side') self.assertTrue(np.all(ids == expected_ids)) ids, div = find_trial_ids(self.trials, side='left', sort='side') expected_ids = np.where(~np.isnan(self.trials['contrastLeft']))[0] self.assertTrue(np.all(ids == expected_ids)) # Test ordering by reaction time reaction_time = self.trials['response_times'] - self.trials['goCue_times'] expected_ids = np.argsort(reaction_time) ids, div = find_trial_ids(self.trials, order='reaction time') self.assertTrue(np.all(ids == expected_ids)) ids, div = find_trial_ids(self.trials, side='left', choice='correct', order='reaction time') expected_ids = left_corr_expected_ids[np.argsort(reaction_time[left_corr_expected_ids])] self.assertTrue(np.all(ids == expected_ids)) # Test contrasts ids, div = find_trial_ids(self.trials, contrast=[1]) expected_ids = np.sort(np.r_[np.where(self.trials['contrastLeft'] == 1)[0], np.where(self.trials['contrastRight'] == 1)[0]]) self.assertTrue(np.all(ids == expected_ids)) ids, div = find_trial_ids(self.trials, contrast=[0.0625, 0], side='left') expected_ids = np.where(self.trials['contrastLeft'] <= 0.0625)[0] self.assertTrue(np.all(ids == expected_ids))
[docs] def test_get_event_aligned_rasters(self): ts = 1 / 3000 spikes = np.arange(0, 100, ts) use_trials = self.trials['stimOn_times'][self.trials['stimOn_times'] < 100] # Test for normal case where trials are within spike times raster, t = get_event_aligned_raster(spikes, use_trials) self.assertEqual(raster.shape[0], len(use_trials)) self.assertTrue(np.sum(np.isnan(raster)) == 0) # Test for the case where first trial/s is before first spike time spikes = np.arange(int(use_trials[0] + 1), 100, ts) raster, t = get_event_aligned_raster(spikes, use_trials) self.assertEqual(raster.shape[0], len(use_trials)) self.assertTrue(np.all(np.isnan(raster[0, :]))) self.assertTrue(np.all(~np.isnan(raster[1, :]).ravel())) spikes = np.arange(int(use_trials[4] + 1), 100, ts) raster, t = get_event_aligned_raster(spikes, use_trials) self.assertEqual(raster.shape[0], len(use_trials)) self.assertTrue(np.all(np.isnan(raster[0:5, :]).ravel())) self.assertTrue(np.all(~np.isnan(raster[6, :]).ravel())) # Test for case where last trial/s is after last spike time spikes = np.arange(0, int(use_trials[-1] - 1), ts) raster, t = get_event_aligned_raster(spikes, use_trials) self.assertEqual(raster.shape[0], len(use_trials)) self.assertTrue(np.all(np.isnan(raster[-1, :]))) self.assertTrue(np.all(~np.isnan(raster[-2, :]))) spikes = np.arange(0, int(use_trials[-5] - 1), ts) raster, t = get_event_aligned_raster(spikes, use_trials) self.assertEqual(raster.shape[0], len(use_trials)) self.assertTrue(np.all(np.isnan(raster[-5:, :]).ravel())) self.assertTrue(np.all(~np.isnan(raster[-6, :]).ravel())) # Test for both before and after spikes = np.arange(int(use_trials[4] + 1), int(use_trials[-5] - 1), ts) raster, t = get_event_aligned_raster(spikes, use_trials) self.assertEqual(raster.shape[0], len(use_trials)) self.assertTrue(np.all(np.isnan(raster[0:5, :]).ravel())) self.assertTrue(np.all(np.isnan(raster[-5:, :]).ravel())) # Test when nans have trials - these are removed from the raster use_trials[10:12] = np.nan raster, t = get_event_aligned_raster(spikes, use_trials) self.assertEqual(raster.shape[0], len(use_trials)) self.assertTrue(np.all(np.isnan(raster[10:12, :]).ravel())) self.assertTrue(np.all(~np.isnan(raster[12:15, :]).ravel())) use_trials[0:2] = np.nan raster, t = get_event_aligned_raster(spikes, use_trials) self.assertEqual(raster.shape[0], len(use_trials)) self.assertTrue(np.all(np.isnan(raster[0:2, :]).ravel())) self.assertTrue(np.all(np.isnan(raster[-5:, :]).ravel()))