import tempfile
from pathlib import Path
import unittest
from unittest import mock
from itertools import chain
import yaml
import ibllib.tests
import ibllib.pipes.dynamic_pipeline as dyn
import ibllib.pipes.behavior_tasks as btasks
from ibllib.tests.fixtures.utils import populate_task_settings
[docs]
class TestCreateLegacyAcqusitionDescriptions(unittest.TestCase):
[docs]
def test_legacy_biased(self):
session_path = Path(ibllib.tests.__file__).parent.joinpath('extractors', 'data', 'session_biased_ge5')
ad = dyn.acquisition_description_legacy_session(session_path)
protocols = list(chain(*map(dict.keys, ad.get('tasks', []))))
self.assertCountEqual(['biasedChoiceWorld'], protocols)
self.assertEqual(1, len(ad['devices']['cameras']))
[docs]
def test_legacy_ephys(self):
session_path = Path(ibllib.tests.__file__).parent.joinpath('extractors', 'data', 'session_ephys')
ad_ephys = dyn.acquisition_description_legacy_session(session_path)
self.assertEqual(2, len(ad_ephys['devices']['neuropixel']))
self.assertEqual(3, len(ad_ephys['devices']['cameras']))
protocols = list(chain(*map(dict.keys, ad_ephys.get('tasks', []))))
self.assertEqual(protocols, ['ephysChoiceWorld', 'passiveChoiceWorld'])
[docs]
def test_legacy_training(self):
session_path = Path(ibllib.tests.__file__).parent.joinpath('extractors', 'data', 'session_training_ge5')
ad = dyn.acquisition_description_legacy_session(session_path)
protocols = list(chain(*map(dict.keys, ad.get('tasks', []))))
self.assertCountEqual(['trainingChoiceWorld'], protocols)
self.assertEqual(1, len(ad['devices']['cameras']))
[docs]
class TestGetTrialsTasks(unittest.TestCase):
"""Test pipes.dynamic_pipeline.get_trials_tasks function."""
[docs]
def setUp(self):
tmpdir = tempfile.TemporaryDirectory()
self.addCleanup(tmpdir.cleanup)
# The github CI root dir contains an alias/symlink so we must resolve it
self.tempdir = Path(tmpdir.name).resolve()
self.session_path_dynamic = self.tempdir / 'subject' / '2023-01-01' / '001'
self.session_path_dynamic.mkdir(parents=True)
description = {'version': '1.0.0',
'sync': {'nidq': {'collection': 'raw_ephys_data', 'extension': 'bin', 'acquisition_software': 'spikeglx'}},
'tasks': [
{'ephysChoiceWorld': {'task_collection': 'raw_task_data_00'}},
{'passiveChoiceWorld': {'task_collection': 'raw_task_data_01'}},
]}
self.description = description
with open(self.session_path_dynamic / '_ibl_experiment.description.yaml', 'w') as fp:
yaml.safe_dump(description, fp)
self.session_path_legacy = self.session_path_dynamic.with_name('002')
(collection := self.session_path_legacy.joinpath('raw_behavior_data')).mkdir(parents=True)
self.settings = {'IBLRIG_VERSION': '7.2.2', 'PYBPOD_PROTOCOL': '_iblrig_tasks_ephysChoiceWorld'}
self.settings_path = populate_task_settings(collection, self.settings)
[docs]
def test_get_trials_tasks(self):
"""Test pipes.dynamic_pipeline.get_trials_tasks function."""
# A dynamic pipeline session
tasks = dyn.get_trials_tasks(self.session_path_dynamic)
self.assertEqual(2, len(tasks))
self.assertEqual('raw_task_data_00', tasks[0].collection)
# Check behaviour with ONE
one = mock.MagicMock()
one.offline = False
one.alyx = mock.MagicMock()
one.alyx.cache_mode = None # sneaky hack as this is checked by the pipeline somewhere
tasks = dyn.get_trials_tasks(self.session_path_dynamic, one)
self.assertEqual(2, len(tasks))
self.assertIsInstance(tasks[0], btasks.ChoiceWorldTrialsNidq)
one.load_datasets.assert_called() # check that description file is checked on disk
# A session with timeline acquisition
self.description['sync']['nidq']['acquisition_software'] = 'timeline'
with open(self.session_path_dynamic / '_ibl_experiment.description.yaml', 'w') as fp:
yaml.safe_dump(self.description, fp)
tasks = dyn.get_trials_tasks(self.session_path_dynamic, one)
self.assertIsInstance(tasks[0], btasks.ChoiceWorldTrialsTimeline)
# A session with an unknown sync namespace
self.description['sync']['nidq']['acquisition_software'] = 'notepad'
with open(self.session_path_dynamic / '_ibl_experiment.description.yaml', 'w') as fp:
yaml.safe_dump(self.description, fp)
with self.assertLogs(dyn.__name__, 'WARNING') as cm:
self.assertEqual([], dyn.get_trials_tasks(self.session_path_dynamic))
log_message = cm.records[0].getMessage()
self.assertIn('sync namespace "notepad"', log_message)
# An ephys session
tasks = dyn.get_trials_tasks(self.session_path_legacy)
self.assertEqual(2, len(tasks))
self.assertIsInstance(tasks[0], btasks.ChoiceWorldTrialsNidq)
# An ephys session with bpod_only flag
description = self.description
description['tasks'].pop() # remove passiveChoiceWorld task for this test; no Bpod extractor for this protocol
with mock.patch('ibllib.pipes.dynamic_pipeline._load_acquisition_description', return_value=description):
tasks = dyn.get_trials_tasks(self.session_path_dynamic, bpod_only=True)
self.assertEqual(1, len(tasks))
self.assertIsInstance(tasks[0], btasks.ChoiceWorldTrialsBpod)
# A training session
self.settings['PYBPOD_PROTOCOL'] = '_iblrig_tasks_trainingChoiceWorld'
populate_task_settings(self.settings_path, self.settings)
tasks = dyn.get_trials_tasks(self.session_path_legacy, one=one)
self.assertEqual(1, len(tasks))
self.assertIsInstance(tasks[0], btasks.ChoiceWorldTrialsBpod)
self.assertIs(tasks[0].one, one, 'failed to assign ONE instance to task')
# A personal project
self.settings['PYBPOD_PROTOCOL'] = '_misc_foobarChoiceWorld'
populate_task_settings(self.settings_path, self.settings)
m = mock.MagicMock() # Mock the project_extractors repo
m.base.__file__ = str(self.tempdir / 'base.py')
# Create the personal project extractor types map
task_type_map = {'_misc_foobarChoiceWorld': 'foobar'}
extractor_types_path = Path(m.base.__file__).parent.joinpath('extractor_types.json')
populate_task_settings(extractor_types_path, task_type_map)
self.assertRaises(ValueError, dyn.get_trials_tasks, self.session_path_legacy)
[docs]
class TestMisc(unittest.TestCase):
"""Test miscellaneous functions in pipes.dynamic_pipeline."""
[docs]
def test_sync_label(self):
"""Test pipes.dynamic_pipeline._sync_label function."""
self.assertEqual('nidq', dyn._sync_label('nidq'))
self.assertEqual('timeline', dyn._sync_label('nidq', acquisition_software='timeline'))
self.assertEqual('nidq', dyn._sync_label('nidq', acquisition_software='spikeglx'))
self.assertEqual('tdms', dyn._sync_label('tdms'))