Source code for one.tests.alf.test_cache

"""Unit tests for the one.alf.cache module."""
import unittest
import tempfile
from pathlib import Path
import shutil
import datetime
from uuid import uuid4

import pandas as pd
import numpy as np
from pandas.testing import assert_frame_equal

from iblutil.io import parquet
import one.alf.cache as apt
from one.tests.util import revisions_datasets_table


class TestONEParquet(unittest.TestCase):
    """Tests for the make_parquet_db function and its helpers."""

    rel_ses_path = 'mylab/Subjects/mysub/2021-02-28/001/'
    ses_info = {
        'id': 'mylab/mysub/2021-02-28/001',
        'lab': 'mylab',
        'subject': 'mysub',
        'date': datetime.date.fromisoformat('2021-02-28'),
        'number': int('001'),
        'projects': '',
        'task_protocol': '',
    }
    rel_ses_files = [Path('alf/spikes.clusters.npy'), Path('alf/spikes.times.npy')]

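    # NB: rel_ses_path above follows the ALF session path convention,
    # <lab>/Subjects/<subject>/<date>/<number>, which _get_session_info is expected
    # to parse into the ses_info fields (see test_parse below).
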
    def setUp(self) -> None:
        pd.set_option('display.max_columns', 12)
        # root path:
        self.tmpdir = Path(tempfile.gettempdir()) / 'pqttest'
        self.tmpdir.mkdir(exist_ok=True)
        # full session path:
        self.full_ses_path = self.tmpdir / self.rel_ses_path
        (self.full_ses_path / 'alf').mkdir(exist_ok=True, parents=True)
        self.file_path = self.full_ses_path / 'alf/spikes.times.npy'
        self.file_path.write_text('mock')
        sc = self.full_ses_path / 'alf/spikes.clusters.npy'
        sc.write_text('mock2')
        # Create a second session containing an invalid dataset
        second_session = self.tmpdir.joinpath(self.rel_ses_path).parent.joinpath('002')
        second_session.mkdir()
        second_session.joinpath('trials.intervals.npy').touch()
        second_session.joinpath('.invalid').touch()

    def test_parse(self):
        self.assertEqual(apt._get_session_info(self.rel_ses_path), tuple(self.ses_info.values()))
        self.assertTrue(
            self.full_ses_path.as_posix().endswith(self.rel_ses_path[:-1]))

    def test_parquet(self):
        # Test data
        columns = ('colA', 'colB')
        rows = [('a1', 'b1'), ('a2', 'b2')]
        metadata = apt._metadata('dbname')
        filename = self.tmpdir.resolve() / 'mypqt.pqt'

        # Save parquet file.
        df = pd.DataFrame(rows, columns=columns)
        parquet.save(filename, df, metadata=metadata)

        # Load parquet file.
        df2, metadata2 = parquet.load(filename)
        assert_frame_equal(df, df2)
        self.assertTrue(metadata == metadata2)

    def test_sessions_df(self):
        df = apt._make_sessions_df(self.tmpdir)
        print('Sessions dataframe')
        print(df)
        self.assertEqual(df.loc[0].to_dict(), self.ses_info)

    def test_datasets_df(self):
        df = apt._make_datasets_df(self.tmpdir)
        print('Datasets dataframe')
        print(df)
        dset_info = df.loc[0].to_dict()
        self.assertEqual(dset_info['rel_path'], self.rel_ses_files[0].as_posix())
        self.assertTrue(dset_info['file_size'] > 0)
        self.assertFalse(df.rel_path.str.contains('invalid').any())

    def tests_db(self):
        fn_ses, fn_dsets = apt.make_parquet_db(self.tmpdir, hash_ids=False)
        metadata_exp = apt._metadata(self.tmpdir.resolve())
        df_ses, metadata = parquet.load(fn_ses)

        # Check sessions dataframe.
        self.assertEqual(metadata, metadata_exp)
        self.assertEqual(df_ses.loc[0].to_dict(), self.ses_info)

        # Check datasets dataframe.
        df_dsets, metadata2 = parquet.load(fn_dsets)
        self.assertEqual(metadata2, metadata_exp)
        dset_info = df_dsets.loc[0].to_dict()
        self.assertEqual(dset_info['rel_path'], self.rel_ses_files[0].as_posix())

        # Check behaviour when no files found
        with tempfile.TemporaryDirectory() as tdir:
            with self.assertWarns(RuntimeWarning):
                fn_ses, fn_dsets = apt.make_parquet_db(tdir, hash_ids=False)
            self.assertTrue(parquet.load(fn_ses)[0].empty)
            self.assertTrue(parquet.load(fn_dsets)[0].empty)

        # Check lab arg
        with self.assertRaises(AssertionError):
            apt.make_parquet_db(self.tmpdir, hash_ids=False, lab='another')

        # Create some more datasets in a session folder outside of a lab directory
        with tempfile.TemporaryDirectory() as tdir:
            session_path = Path(tdir).joinpath('subject', '1900-01-01', '001')
            _ = revisions_datasets_table(touch_path=session_path)  # create some files
            fn_ses, _ = apt.make_parquet_db(tdir, hash_ids=False, lab='another')
            df_ses, _ = parquet.load(fn_ses)
            self.assertTrue((df_ses['lab'] == 'another').all())

    def test_hash_ids(self):
        # Build and load caches with hashed UUIDs
        (ses, _), (dsets, _) = map(parquet.load, apt.make_parquet_db(self.tmpdir, hash_ids=True))
        # Check ID fields in both dataframes
        self.assertTrue(ses.index.nlevels == 1 and ses.index.name == 'id')
        self.assertTrue(dsets.index.nlevels == 2 and tuple(dsets.index.names) == ('eid', 'id'))

    def test_remove_missing_datasets(self):
        # Add a session that will only contain missing datasets
        ghost_session = self.tmpdir.joinpath('lab', 'Subjects', 'sub', '2021-01-30', '001')
        ghost_session.mkdir(parents=True)
        tables = {
            'sessions': apt._make_sessions_df(self.tmpdir),
            'datasets': apt._make_datasets_df(self.tmpdir)
        }

        # Touch some files and folders for deletion
        empty_missing_session = self.tmpdir.joinpath(self.rel_ses_path).parent.joinpath('003')
        empty_missing_session.mkdir()
        missing_dataset = self.tmpdir.joinpath(self.rel_ses_path).joinpath('foo.bar.npy')
        missing_dataset.touch()
        ghost_dataset = ghost_session.joinpath('foo.bar.npy')
        ghost_dataset.touch()

        # Test dry run
        to_remove = apt.remove_missing_datasets(
            self.tmpdir, tables=tables, dry=True, remove_empty_sessions=False
        )
        self.assertTrue(all(map(Path.exists, to_remove)), 'Removed files during dry run!')
        self.assertTrue(all(map(Path.is_file, to_remove)), 'Failed to ignore empty folders')
        self.assertNotIn(empty_missing_session, to_remove, 'Failed to ignore empty folders')
        self.assertNotIn(next(self.tmpdir.rglob('.invalid')), to_remove, 'Removed non-ALF file')

        # Test removal of files and folders
        removed = apt.remove_missing_datasets(
            self.tmpdir, tables=tables, dry=False, remove_empty_sessions=True
        )
        self.assertTrue(sum(map(Path.exists, to_remove)) == 0, 'Failed to remove all files')
        self.assertIn(empty_missing_session, removed, 'Failed to remove empty session folder')
        self.assertIn(missing_dataset, removed, 'Failed to remove missing dataset')
        self.assertIn(ghost_dataset, removed, 'Failed to remove missing dataset')
        self.assertNotIn(ghost_session, removed, 'Removed empty session that was in session table')

        # Check without tables input
        apt.make_parquet_db(self.tmpdir, hash_ids=False)
        removed = apt.remove_missing_datasets(self.tmpdir, dry=False)
        self.assertTrue(len(removed) == 0)

    def tearDown(self) -> None:
        shutil.rmtree(self.tmpdir)


class TestONETables(unittest.TestCase):
    """Tests for the cache table functions."""

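    # The fixture cache tables used below live in one/tests/fixtures (resolved from
    # Path(__file__).parents[1]) and are loaded with apt.load_tables.
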
    def test_merge_cache_tables(self):
        """Test merge_tables function."""
        fixture = Path(__file__).parents[1].joinpath('fixtures')
        caches = apt.load_tables(fixture)
        sessions_types = caches.sessions.reset_index().dtypes.to_dict()
        datasets_types = caches.datasets.reset_index().dtypes.to_dict()

        # Update with single records (pandas.Series): one new, one already in the cache
        session = caches.sessions.iloc[0].squeeze()
        session.name = uuid4()  # New record
        dataset = caches.datasets.iloc[0].squeeze()
        dataset['exists'] = not dataset['exists']
        apt.merge_tables(caches, sessions=session, datasets=dataset)
        self.assertTrue(session.name in caches.sessions.index)
        updated, = dataset['exists'] == caches.datasets.loc[dataset.name, 'exists']
        self.assertTrue(updated)
        # Check that the updated data frame has kept its original dtypes
        types = caches.sessions.reset_index().dtypes.to_dict()
        self.assertDictEqual(sessions_types, types)
        types = caches.datasets.reset_index().dtypes.to_dict()
        self.assertDictEqual(datasets_types, types)

        # Update a number of records
        datasets = caches.datasets.iloc[:3].copy()
        datasets.loc[:, 'exists'] = ~datasets.loc[:, 'exists']
        # Make one of the datasets a new record
        idx = datasets.index.values
        idx[-1] = (idx[-1][0], uuid4())
        datasets.index = pd.MultiIndex.from_tuples(idx, names=('eid', 'id'))
        apt.merge_tables(caches, datasets=datasets)
        self.assertTrue(idx[-1] in caches.datasets.index)
        verifiable = caches.datasets.loc[datasets.index.values, 'exists']
        self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
        # Check that the updated data frame has kept its original dtypes
        types = caches.datasets.reset_index().dtypes.to_dict()
        self.assertDictEqual(datasets_types, types)

        # Check behaviour when columns don't match
        datasets.loc[:, 'exists'] = ~datasets.loc[:, 'exists']
        datasets['extra_column'] = True
        caches.datasets['foo_bar'] = 12  # this column is missing in our new records
        caches.datasets['new_column'] = False
        expected_datasets_types = caches.datasets.reset_index().dtypes.to_dict()
        # An exception is exists_*, as the Alyx cache contains exists_aws and exists_flatiron.
        # These should simply be filled with the values of exists, as Alyx won't return datasets
        # that don't exist on FlatIron and if they don't exist on AWS it falls back to this.
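        # Concretely, as asserted below: after merging, exists_aws at the updated rows
        # should match the incoming 'exists' values, while a non-exists extra column
        # such as foo_bar is left as NA for those rows.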
        caches.datasets['exists_aws'] = False
        with self.assertRaises(AssertionError):
            apt.merge_tables(caches, datasets=datasets, strict=True)
        apt.merge_tables(caches, datasets=datasets)
        verifiable = caches.datasets.loc[datasets.index.values, 'exists']
        self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
        apt.merge_tables(caches, datasets=datasets)
        verifiable = caches.datasets.loc[datasets.index.values, 'exists_aws']
        self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
        # If the extra column does not start with 'exists' it should be set to NaN
        verifiable = caches.datasets.loc[datasets.index.values, 'foo_bar']
        self.assertTrue(np.isnan(verifiable).all())
        # Check that the missing columns were updated to nullable fields
        expected_datasets_types.update(
            foo_bar=pd.Int64Dtype(), exists_aws=pd.BooleanDtype(), new_column=pd.BooleanDtype())
        types = caches.datasets.reset_index().dtypes.to_dict()
        self.assertDictEqual(expected_datasets_types, types)

        # Check fringe cases
        with self.assertRaises(KeyError):
            apt.merge_tables(caches, unknown=datasets)
        self.assertIsNone(apt.merge_tables(caches, datasets=None))
        # Absent cache table
        caches = apt.load_tables('/foo')
        sessions_types = caches.sessions.reset_index().dtypes.to_dict()
        datasets_types = caches.datasets.reset_index().dtypes.to_dict()
        apt.merge_tables(caches, sessions=session, datasets=dataset)
        self.assertTrue(all(caches.sessions == pd.DataFrame([session])))
        self.assertEqual(1, len(caches.datasets))
        self.assertEqual(caches.datasets.squeeze().name, dataset.name)
        self.assertCountEqual(caches.datasets.squeeze().to_dict(), dataset.to_dict())
        types = caches.datasets.reset_index().dtypes.to_dict()
        self.assertDictEqual(datasets_types, types)


if __name__ == '__main__':
    unittest.main(exit=False)