import unittest
from unittest.mock import patch
import tempfile
from pathlib import Path
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
from one.api import ONE
from one.registration import RegistrationClient
from ibllib.oneibl.data_handlers import ExpectedDataset
from ibllib.pipes import base_tasks
from ibllib.pipes.behavior_tasks import ChoiceWorldTrialsBpod
from ibllib.tests import TEST_DB
[docs]
class TestRegisterRawDataTask(unittest.TestCase):
tmpdir = None
one = None
session_path = None
eid = None
[docs]
@classmethod
def setUpClass(cls) -> None:
cls.tmpdir = tempfile.TemporaryDirectory()
cls.one = ONE(**TEST_DB, cache_rest=None)
ses_dict = {
'subject': 'algernon',
'start_time': RegistrationClient.ensure_ISO8601(None),
'number': 1,
'users': ['test_user']}
ses = cls.one.alyx.rest('sessions', 'create', data=ses_dict)
cls.session_path = Path(cls.tmpdir.name).joinpath(
ses['subject'], ses['start_time'][:10], str(ses['number']).zfill(3))
cls.eid = ses['url'][-36:]
# Add a couple of images
cls.session_path.joinpath('snapshots').mkdir(parents=True)
for i, ext in enumerate(('.PNG', '.tif')):
plt.imshow(np.random.random((7, 7)))
plt.savefig(cls.session_path.joinpath('snapshots', f'foo_{i}').with_suffix(ext))
plt.close()
[docs]
def test_register_snapshots(self):
"""Test ibllib.pipes.base_tasks.RegisterRawDataTask.register_snapshots.
A more thorough test for this exists in ibllib.tests.test_pipes.TestRegisterRawDataTask.
This test does not mock REST (and therefore requires a test database), while the other does.
This test also works on actual image data, testing the conversion from tif to png.
"""
task = base_tasks.RegisterRawDataTask(self.session_path, one=self.one)
notes = task.register_snapshots()
self.assertEqual(2, len(notes))
self.assertTrue(self.session_path.joinpath('snapshots').exists())
task.register_snapshots(unlink=True)
self.assertFalse(self.session_path.joinpath('snapshots').exists())
self.assertTrue(all(n['image'].lower().endswith('.png') for n in notes), 'failed to convert tif to png')
[docs]
def test_rename_files(self):
collection = 'raw_sync_data'
task = base_tasks.RegisterRawDataTask(self.session_path, one=self.one)
task.input_files = task.output_files = []
task.rename_files() # Returns without raising
I = ExpectedDataset.input # noqa
task.input_files = [I('foo.*', collection, True), ]
task.output_files = [I('_ns_DAQdata.raw.bar', collection, True), ]
self.session_path.joinpath(collection).mkdir()
self.session_path.joinpath(collection, 'foo.bar').touch()
task.rename_files()
self.assertTrue(self.session_path.joinpath(collection, '_ns_DAQdata.raw.bar').exists())
self.assertFalse(self.session_path.joinpath(collection, 'foo.bar').exists())
with self.assertRaises(FileNotFoundError):
task.rename_files()
# Check asserts number of inputs == number of outputs
task.output_files.append(I('_ns_DAQdata.baz.bar', collection, True),)
with self.assertRaises(AssertionError):
task.rename_files()
[docs]
@classmethod
def tearDownClass(cls) -> None:
if cls.tmpdir:
cls.tmpdir.cleanup()
if cls.one and cls.eid:
cls.one.alyx.rest('sessions', 'delete', id=cls.eid)
[docs]
class TestBehaviourTask(unittest.TestCase):
[docs]
def test_spacer_support(self) -> None:
"""Test for BehaviourTask._spacer_support method."""
to_test = [('100.0.0', False), ('8.0.0', False), ('7.1.0', True), ('8.0.1', True), ('7.2.0', True)]
settings = {}
spacer_support = partial(base_tasks.BehaviourTask._spacer_support, settings)
for version, expected in to_test:
settings['IBLRIG_VERSION'] = version
with self.subTest(version):
self.assertIs(spacer_support(), expected)
[docs]
def test_get_task_collection(self) -> None:
"""Test for BehaviourTask.get_task_collection method."""
params = {'tasks': [{'fooChoiceWorld': {'collection': 'raw_task_data_00'}}]}
task = ChoiceWorldTrialsBpod('')
self.assertIsNone(task.get_task_collection())
task.session_params = params
self.assertEqual('raw_task_data_00', task.get_task_collection())
params['tasks'].append({'barChoiceWorld': {'collection': 'raw_task_data_01'}})
self.assertRaises(AssertionError, task.get_task_collection)
self.assertEqual('raw_task_data_02', task.get_task_collection('raw_task_data_02'))
[docs]
def test_get_protocol(self) -> None:
"""Test for BehaviourTask.get_protocol method."""
task = ChoiceWorldTrialsBpod('')
self.assertIsNone(task.get_protocol())
self.assertEqual('foobar', task.get_protocol(protocol='foobar'))
task.session_params = {'tasks': [{'fooChoiceWorld': {'collection': 'raw_task_data_00'}}]}
self.assertEqual('fooChoiceWorld', task.get_protocol())
task.session_params['tasks'].append({'barChoiceWorld': {'collection': 'raw_task_data_01'}})
self.assertRaises(ValueError, task.get_protocol)
self.assertEqual('barChoiceWorld', task.get_protocol(task_collection='raw_task_data_01'))
self.assertIsNone(task.get_protocol(task_collection='raw_behavior_data'))
[docs]
def test_get_protocol_number(self) -> None:
"""Test for BehaviourTask.get_protocol_number method."""
params = {'tasks': [
{'fooChoiceWorld': {'collection': 'raw_task_data_00', 'protocol_number': 0}},
{'barChoiceWorld': {'collection': 'raw_task_data_01', 'protocol_number': 1}}
]}
task = ChoiceWorldTrialsBpod('')
self.assertIsNone(task.get_protocol_number())
self.assertRaises(ValueError, task.get_protocol_number, number='foo')
self.assertEqual(1, task.get_protocol_number(number=1))
task.session_params = params
self.assertRaises(AssertionError, task.get_protocol_number)
for i, proc in enumerate(('fooChoiceWorld', 'barChoiceWorld')):
self.assertEqual(i, task.get_protocol_number(task_protocol=proc))
[docs]
def test_assert_trials_data(self):
"""Test for BehaviourTask._assert_trials_data method."""
task = ChoiceWorldTrialsBpod('')
trials_data = {'foo': [1, 2, 3]}
def _set(**_):
task.extractor = True # set extractor attribute
return trials_data, None
with patch.object(task, 'extract_behaviour', side_effect=_set) as mock:
# Trials data but no extractor
self.assertEqual(trials_data, task._assert_trials_data(trials_data))
mock.assert_called_with(save=False)
with patch.object(task, 'extract_behaviour', return_value=(trials_data, None)) as mock:
# Extractor but no trials data
self.assertEqual(trials_data, task._assert_trials_data(None))
mock.assert_called_with(save=False)
# Returns no trials
mock.return_value = (None, None)
self.assertRaises(ValueError, task._assert_trials_data)
with patch.object(task, 'extract_behaviour', return_value=(trials_data, None)) as mock:
# Both extractor and trials data
self.assertEqual(trials_data, task._assert_trials_data(trials_data))
mock.assert_not_called()
if __name__ == '__main__':
unittest.main()