# Source code for ibllib.tests.extractors.test_extractors_base
import json
import unittest
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
from ibllib.io.extractors import base
class TestExtractorMaps(unittest.TestCase):
    """Tests for functions that return Bpod extractor classes.

    The ``projects`` package (project_extraction) is replaced, via a patched
    ``builtins.__import__``, by a MagicMock whose ``base.__file__`` points into a
    temporary directory holding a custom ``task_extractor_map.json``.
    """

    def setUp(self):
        """Write a custom task extractor map to a temp dir and build the ``projects`` mock."""
        # Keep the real import function so import_mock can delegate every import
        # it does not intercept (it is captured before any patch is applied).
        self.orig_import = __import__
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.custom_extractors_path = Path(tmp.name).joinpath('task_extractor_map.json')
        self.custom_extractors = {'fooChoiceWorld': 'Bar'}
        # Make the mocked module appear to live in the temporary directory; the tests
        # assume _get_task_extractor_map looks for the JSON map next to projects.base.
        self.projects = MagicMock()
        self.projects.base.__file__ = str(self.custom_extractors_path.with_name('__init__.py'))
        with open(self.custom_extractors_path, 'w', encoding='utf-8') as fp:
            json.dump(self.custom_extractors, fp)

    def import_mock(self, name, *args, **kwargs):
        """Return mock for project_extraction imports.

        Intended as the ``side_effect`` of a patched ``builtins.__import__``.

        Parameters
        ----------
        name : str
            The module name passed to ``__import__``.
        *args, **kwargs
            Remaining ``__import__`` arguments (globals, locals, fromlist, level),
            forwarded unchanged to the real ``__import__``.

        Returns
        -------
        The ``projects`` mock for ``projects`` / ``projects.base``; otherwise the
        result of the real ``__import__``.
        """
        if name in ('projects', 'projects.base'):
            return self.projects
        return self.orig_import(name, *args, **kwargs)

    def test_get_task_extractor_map(self):
        """Test ibllib.io.extractors.base._get_task_extractor_map function."""
        # Check the custom map entries are included in the returned map
        with patch('builtins.__import__', side_effect=self.import_mock):
            extractors = base._get_task_extractor_map()
        self.assertTrue(self.custom_extractors.items() < extractors.items())

        # Test handles case where module not installed
        with patch('builtins.__import__', side_effect=ModuleNotFoundError):
            extractors = base._get_task_extractor_map()
        self.assertFalse(self.custom_extractors.items() <= extractors.items())

        # Remove the file and check the exception is caught. The mocked projects
        # module must still be importable here; without the patch this would only
        # repeat the module-not-installed case instead of the missing-file case.
        self.custom_extractors_path.unlink()
        with patch('builtins.__import__', side_effect=self.import_mock):
            extractors = base._get_task_extractor_map()
        self.assertFalse(self.custom_extractors.items() <= extractors.items())

    def test_get_bpod_extractor_class(self):
        """Test ibllib.io.extractors.base.get_bpod_extractor_class function."""
        settings_file = Path(__file__).parent.joinpath(
            'data', 'session_biased_ge5', 'raw_behavior_data', '_iblrig_taskSettings.raw.json'
        )
        # parents[1] is the session folder that contains raw_behavior_data
        session_path = settings_file.parents[1]
        self.assertEqual('BiasedTrials', base.get_bpod_extractor_class(session_path))
        # Sibling fixture session; passed as str to also cover non-Path input
        training_path = session_path.with_name('session_training_ge5')
        self.assertEqual('TrainingTrials', base.get_bpod_extractor_class(str(training_path)))
        # 'foobar' is presumably not a fixture folder, so no extractor can be determined
        unknown_path = session_path.with_name('foobar')
        self.assertRaises(ValueError, base.get_bpod_extractor_class, str(unknown_path))

    def test_protocol2extractor(self):
        """Test ibllib.io.extractors.base.protocol2extractor function."""
        # Test fuzzy match: a protocol name containing the custom key resolves to its class
        (proc, expected), = self.custom_extractors.items()
        with patch('builtins.__import__', side_effect=self.import_mock):
            extractor = base.protocol2extractor('_mw_' + proc)
        self.assertEqual(expected, extractor)
        # Test unknown protocol: outside the mock the custom map is not loaded
        # (assumes any installed projects package does not define this key)
        self.assertRaises(ValueError, base.protocol2extractor, proc)
if __name__ == '__main__':  # only when executed as a script, never on import
    # Delegate to unittest's command-line runner for the tests defined above.
    unittest.main()