Source code for ibllib.tests.extractors.test_extractors_base

import json
import unittest
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock

from ibllib.io.extractors import base


[docs] class TestExtractorMaps(unittest.TestCase): """Tests for functions that return Bpod extractor classes."""
[docs] def setUp(self): # Store original __import__ self.orig_import = __import__ tmp = tempfile.TemporaryDirectory() self.addCleanup(tmp.cleanup) self.custom_extractors_path = Path(tmp.name).joinpath('task_extractor_map.json') self.custom_extractors = {'fooChoiceWorld': 'Bar'} self.projects = MagicMock() self.projects.base.__file__ = str(self.custom_extractors_path.with_name('__init__.py')) with open(self.custom_extractors_path, 'w') as fp: json.dump(self.custom_extractors, fp)
[docs] def import_mock(self, name, *args): """Return mock for project_extraction imports.""" if name == 'projects' or name == 'projects.base': return self.projects return self.orig_import(name, *args)
[docs] def test_get_task_extractor_map(self): """Test ibllib.io.extractors.base._get_task_extractor_map function.""" # Check the custom map is loaded with patch('builtins.__import__', side_effect=self.import_mock): extractors = base._get_task_extractor_map() self.assertTrue(self.custom_extractors.items() < extractors.items()) # Test handles case where module not installed with patch('builtins.__import__', side_effect=ModuleNotFoundError): extractors = base._get_task_extractor_map() self.assertFalse(set(self.custom_extractors.items()).issubset(set(extractors.items()))) # Remove the file and check exception is caught self.custom_extractors_path.unlink() extractors = base._get_task_extractor_map() self.assertFalse(set(self.custom_extractors.items()).issubset(set(extractors.items())))
[docs] def test_get_bpod_extractor_class(self): """Test ibllib.io.extractors.base.get_bpod_extractor_class function.""" # installe # alf_path = self.custom_extractors_path.parent.joinpath('subject', '2020-01-01', '001', 'raw_task_data_00') # alf_path.mkdir(parents=True) settings_file = Path(__file__).parent.joinpath( 'data', 'session_biased_ge5', 'raw_behavior_data', '_iblrig_taskSettings.raw.json' ) # shutil.copy(settings_file, alf_path) session_path = settings_file.parents[1] self.assertEqual('BiasedTrials', base.get_bpod_extractor_class(session_path)) session_path = str(session_path).replace('session_biased_ge5', 'session_training_ge5') self.assertEqual('TrainingTrials', base.get_bpod_extractor_class(session_path)) session_path = str(session_path).replace('session_training_ge5', 'foobar') self.assertRaises(ValueError, base.get_bpod_extractor_class, session_path)
[docs] def test_protocol2extractor(self): """Test ibllib.io.extractors.base.protocol2extractor function.""" # Test fuzzy match (proc, expected), = self.custom_extractors.items() with patch('builtins.__import__', side_effect=self.import_mock): extractor = base.protocol2extractor('_mw_' + proc) self.assertEqual(expected, extractor) # Test unknown protocol self.assertRaises(ValueError, base.protocol2extractor, proc)
if __name__ == '__main__': unittest.main()