import unittest
import pickle
from pathlib import Path
from brainbox.task.trials import find_trial_ids, get_event_aligned_raster
import numpy as np
class TestTrials(unittest.TestCase):
    """Tests for trial selection (find_trial_ids) and event-aligned rasters
    (get_event_aligned_raster)."""

    def setUp(self):
        """
        Load one session of trials data from the pickled test fixture.

        The fixture ``fixtures/trials_test.pickle`` (next to this file) contains training data
        from 10 consecutive sessions of subject SWC_054. It is a dict of trials objects keyed by
        session date; combining data from different dates allows testing each of the training
        criteria a subject goes through in the IBL training pipeline. Only the 2020-08-26
        session is used here.

        If the fixture file is missing, the test is skipped instead of failing with an
        unrelated error.
        """
        pickle_file = Path(__file__).parent.joinpath('fixtures', 'trials_test.pickle')
        if not pickle_file.exists():
            # Skip explicitly rather than indexing a missing dataset further down
            self.skipTest(f'Test fixture not found: {pickle_file}')
        # The fixture is repository-local test data, so unpickling it is acceptable here
        with open(pickle_file, 'rb') as f:
            trial_data = pickle.load(f)
        self.trials = trial_data['2020-08-26']
        # Reaction-time ordering in test_find_trial_ids uses goCue_times; alias it to
        # stimOn_times (presumably the fixture lacks usable goCue_times - TODO confirm)
        self.trials['goCue_times'] = self.trials['stimOn_times']

    def test_find_trial_ids(self):
        """
        Check that find_trial_ids filters, sorts and orders trials as expected.

        Expected ids are recomputed directly from the trials fields. Comparisons use
        np.testing.assert_array_equal, so a length mismatch fails instead of being hidden
        by broadcasting.
        """
        trials = self.trials
        is_correct = trials['feedbackType'] == 1
        is_incorrect = trials['feedbackType'] == -1
        # Trials on a given side are those whose contrast for that side is not NaN
        is_left = ~np.isnan(trials['contrastLeft'])
        is_right = ~np.isnan(trials['contrastRight'])

        # Default returns all trials
        ids, _ = find_trial_ids(trials)
        np.testing.assert_array_equal(ids, np.arange(len(trials['probabilityLeft'])))

        # Filtering by choice
        ids, _ = find_trial_ids(trials, choice='correct')
        np.testing.assert_array_equal(ids, np.where(is_correct)[0])
        ids, _ = find_trial_ids(trials, choice='incorrect')
        np.testing.assert_array_equal(ids, np.where(is_incorrect)[0])

        # Filtering by side
        ids, _ = find_trial_ids(trials, side='left')
        np.testing.assert_array_equal(ids, np.where(is_left)[0])
        ids, _ = find_trial_ids(trials, side='right')
        np.testing.assert_array_equal(ids, np.where(is_right)[0])

        # Filtering by side and choice together
        right_correct = np.where(is_right & is_correct)[0]
        right_incorrect = np.where(is_right & is_incorrect)[0]
        left_correct = np.where(is_left & is_correct)[0]
        left_incorrect = np.where(is_left & is_incorrect)[0]
        ids, _ = find_trial_ids(trials, side='right', choice='correct')
        np.testing.assert_array_equal(ids, right_correct)
        ids, _ = find_trial_ids(trials, side='right', choice='incorrect')
        np.testing.assert_array_equal(ids, right_incorrect)
        ids, _ = find_trial_ids(trials, side='left', choice='correct')
        np.testing.assert_array_equal(ids, left_correct)
        ids, _ = find_trial_ids(trials, side='left', choice='incorrect')
        np.testing.assert_array_equal(ids, left_incorrect)

        # Sorting: left before right, and correct before incorrect within each side
        ids, _ = find_trial_ids(trials, sort='choice and side')
        np.testing.assert_array_equal(
            ids, np.r_[left_correct, left_incorrect, right_correct, right_incorrect])
        left_by_choice = np.r_[left_correct, left_incorrect]
        ids, _ = find_trial_ids(trials, side='left', sort='choice')
        np.testing.assert_array_equal(ids, left_by_choice)
        ids, _ = find_trial_ids(trials, side='left', sort='choice and side')
        np.testing.assert_array_equal(ids, left_by_choice)
        ids, _ = find_trial_ids(trials, side='left', sort='side')
        np.testing.assert_array_equal(ids, np.where(is_left)[0])

        # Ordering by reaction time (response time relative to go cue)
        reaction_time = trials['response_times'] - trials['goCue_times']
        ids, _ = find_trial_ids(trials, order='reaction time')
        np.testing.assert_array_equal(ids, np.argsort(reaction_time))
        ids, _ = find_trial_ids(trials, side='left', choice='correct', order='reaction time')
        np.testing.assert_array_equal(
            ids, left_correct[np.argsort(reaction_time[left_correct])])

        # Filtering by contrast on either side
        ids, _ = find_trial_ids(trials, contrast=[1])
        np.testing.assert_array_equal(
            ids, np.where((trials['contrastLeft'] == 1) | (trials['contrastRight'] == 1))[0])
        # Match exactly the requested contrasts; a `<= 0.0625` threshold would also accept
        # any other contrast below 0.0625
        ids, _ = find_trial_ids(trials, contrast=[0.0625, 0], side='left')
        np.testing.assert_array_equal(
            ids, np.where(np.isin(trials['contrastLeft'], [0.0625, 0]))[0])

    def test_get_event_aligned_rasters(self):
        """
        Check get_event_aligned_raster when events lie outside the spike times or are NaN.

        Spike trains are synthetic and evenly spaced every 1/3000 s. In every case the raster
        is expected to keep one row per event. Rows for events not covered by the spikes, and
        rows for NaN events, are expected to be all NaN.
        """
        ts = 1 / 3000
        # Boolean indexing returns a copy, so writing NaNs below does not modify self.trials
        use_trials = self.trials['stimOn_times'][self.trials['stimOn_times'] < 100]

        # Normal case: all events lie within the spike times
        spikes = np.arange(0, 100, ts)
        raster, _ = get_event_aligned_raster(spikes, use_trials)
        self.assertEqual(raster.shape[0], len(use_trials))
        self.assertFalse(np.isnan(raster).any())

        # First event(s) before the first spike time
        spikes = np.arange(int(use_trials[0] + 1), 100, ts)
        raster, _ = get_event_aligned_raster(spikes, use_trials)
        self.assertEqual(raster.shape[0], len(use_trials))
        self.assertTrue(np.isnan(raster[0, :]).all())
        self.assertFalse(np.isnan(raster[1, :]).any())
        spikes = np.arange(int(use_trials[4] + 1), 100, ts)
        raster, _ = get_event_aligned_raster(spikes, use_trials)
        self.assertEqual(raster.shape[0], len(use_trials))
        self.assertTrue(np.isnan(raster[0:5, :]).all())
        self.assertFalse(np.isnan(raster[6, :]).any())

        # Last event(s) after the last spike time
        spikes = np.arange(0, int(use_trials[-1] - 1), ts)
        raster, _ = get_event_aligned_raster(spikes, use_trials)
        self.assertEqual(raster.shape[0], len(use_trials))
        self.assertTrue(np.isnan(raster[-1, :]).all())
        self.assertFalse(np.isnan(raster[-2, :]).any())
        spikes = np.arange(0, int(use_trials[-5] - 1), ts)
        raster, _ = get_event_aligned_raster(spikes, use_trials)
        self.assertEqual(raster.shape[0], len(use_trials))
        self.assertTrue(np.isnan(raster[-5:, :]).all())
        self.assertFalse(np.isnan(raster[-6, :]).any())

        # Events both before the first and after the last spike time
        spikes = np.arange(int(use_trials[4] + 1), int(use_trials[-5] - 1), ts)
        raster, _ = get_event_aligned_raster(spikes, use_trials)
        self.assertEqual(raster.shape[0], len(use_trials))
        self.assertTrue(np.isnan(raster[0:5, :]).all())
        self.assertTrue(np.isnan(raster[-5:, :]).all())

        # NaN events: their rows are kept in the raster (row count unchanged) and are all NaN
        use_trials[10:12] = np.nan
        raster, _ = get_event_aligned_raster(spikes, use_trials)
        self.assertEqual(raster.shape[0], len(use_trials))
        self.assertTrue(np.isnan(raster[10:12, :]).all())
        self.assertFalse(np.isnan(raster[12:15, :]).any())
        use_trials[0:2] = np.nan
        raster, _ = get_event_aligned_raster(spikes, use_trials)
        self.assertEqual(raster.shape[0], len(use_trials))
        self.assertTrue(np.isnan(raster[0:2, :]).all())
        self.assertTrue(np.isnan(raster[-5:, :]).all())