"""Unit tests for the one.alf.cache module."""
import unittest
import tempfile
from pathlib import Path
import shutil
import datetime
from uuid import uuid4
import pandas as pd
import numpy as np
from pandas.testing import assert_frame_equal
from iblutil.io import parquet
import one.alf.cache as apt
from one.tests.util import revisions_datasets_table


class TestONEParquet(unittest.TestCase):
    """Tests for the make_parquet_db function and its helpers."""
    rel_ses_path = 'mylab/Subjects/mysub/2021-02-28/001/'
    ses_info = {
        'id': 'mylab/mysub/2021-02-28/001',
        'lab': 'mylab',
        'subject': 'mysub',
        'date': datetime.date.fromisoformat('2021-02-28'),
        'number': int('001'),
        'projects': '',
        'task_protocol': '',
    }
    rel_ses_files = [Path('alf/spikes.clusters.npy'), Path('alf/spikes.times.npy')]

    def setUp(self) -> None:
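        """Set up a temporary cache directory containing two test sessions."""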
        pd.set_option('display.max_columns', 12)
        # root path:
        self.tmpdir = Path(tempfile.gettempdir()) / 'pqttest'
        self.tmpdir.mkdir(exist_ok=True)
        # full session path:
        self.full_ses_path = self.tmpdir / self.rel_ses_path
        (self.full_ses_path / 'alf').mkdir(exist_ok=True, parents=True)
        self.file_path = self.full_ses_path / 'alf/spikes.times.npy'
        self.file_path.write_text('mock')
        sc = self.full_ses_path / 'alf/spikes.clusters.npy'
        sc.write_text('mock2')
        # Create a second session containing an invalid dataset
        second_session = self.tmpdir.joinpath(self.rel_ses_path).parent.joinpath('002')
        second_session.mkdir()
        second_session.joinpath('trials.intervals.npy').touch()
        second_session.joinpath('.invalid').touch()

    def test_parse(self):
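        """Test one.alf.cache._get_session_info function."""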
        self.assertEqual(apt._get_session_info(self.rel_ses_path), tuple(self.ses_info.values()))
        self.assertTrue(
            self.full_ses_path.as_posix().endswith(self.rel_ses_path[:-1]))

    def test_parquet(self):
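        """Test saving and loading of a parquet table with metadata."""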
        # Test data
        columns = ('colA', 'colB')
        rows = [('a1', 'b1'), ('a2', 'b2')]
        metadata = apt._metadata('dbname')
        filename = self.tmpdir.resolve() / 'mypqt.pqt'
        # Save parquet file.
        df = pd.DataFrame(rows, columns=columns)
        parquet.save(filename, df, metadata=metadata)
        # Load parquet file
        df2, metadata2 = parquet.load(filename)
        assert_frame_equal(df, df2)
        self.assertEqual(metadata, metadata2)

    def test_sessions_df(self):
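        """Test one.alf.cache._make_sessions_df function."""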
        df = apt._make_sessions_df(self.tmpdir)
        print('Sessions dataframe')
        print(df)
        self.assertEqual(df.loc[0].to_dict(), self.ses_info)

    def test_datasets_df(self):
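        """Test one.alf.cache._make_datasets_df function."""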
        df = apt._make_datasets_df(self.tmpdir)
        print('Datasets dataframe')
        print(df)
        dset_info = df.loc[0].to_dict()
        self.assertEqual(dset_info['rel_path'], self.rel_ses_files[0].as_posix())
        self.assertTrue(dset_info['file_size'] > 0)
        self.assertFalse(df.rel_path.str.contains('invalid').any())

    def tests_db(self):
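        """Test one.alf.cache.make_parquet_db function."""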
        fn_ses, fn_dsets = apt.make_parquet_db(self.tmpdir, hash_ids=False)
        metadata_exp = apt._metadata(self.tmpdir.resolve())
        df_ses, metadata = parquet.load(fn_ses)
        # Check sessions dataframe.
        self.assertEqual(metadata, metadata_exp)
        self.assertEqual(df_ses.loc[0].to_dict(), self.ses_info)
        # Check datasets dataframe.
        df_dsets, metadata2 = parquet.load(fn_dsets)
        self.assertEqual(metadata2, metadata_exp)
        dset_info = df_dsets.loc[0].to_dict()
        self.assertEqual(dset_info['rel_path'], self.rel_ses_files[0].as_posix())
        # Check behaviour when no files found
        with tempfile.TemporaryDirectory() as tdir:
            with self.assertWarns(RuntimeWarning):
                fn_ses, fn_dsets = apt.make_parquet_db(tdir, hash_ids=False)
            self.assertTrue(parquet.load(fn_ses)[0].empty)
            self.assertTrue(parquet.load(fn_dsets)[0].empty)
        # Check lab arg
        with self.assertRaises(AssertionError):
            apt.make_parquet_db(self.tmpdir, hash_ids=False, lab='another')
        # Create some more datasets in a session folder outside of a lab directory
        with tempfile.TemporaryDirectory() as tdir:
            session_path = Path(tdir).joinpath('subject', '1900-01-01', '001')
            _ = revisions_datasets_table(touch_path=session_path)  # create some files
            fn_ses, _ = apt.make_parquet_db(tdir, hash_ids=False, lab='another')
            df_ses, _ = parquet.load(fn_ses)
            self.assertTrue((df_ses['lab'] == 'another').all())

    def test_hash_ids(self):
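        """Test the hash_ids argument of one.alf.cache.make_parquet_db."""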
        # Build and load caches with hashed ids
        (ses, _), (dsets, _) = map(parquet.load, apt.make_parquet_db(self.tmpdir, hash_ids=True))
        # Check ID fields in both dataframes
        self.assertTrue(ses.index.nlevels == 1 and ses.index.name == 'id')
        self.assertTrue(dsets.index.nlevels == 2 and tuple(dsets.index.names) == ('eid', 'id'))

    def test_remove_missing_datasets(self):
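        """Test one.alf.cache.remove_missing_datasets function."""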
        # Add a session that will only contain missing datasets
        ghost_session = self.tmpdir.joinpath('lab', 'Subjects', 'sub', '2021-01-30', '001')
        ghost_session.mkdir(parents=True)
        tables = {
            'sessions': apt._make_sessions_df(self.tmpdir),
            'datasets': apt._make_datasets_df(self.tmpdir)
        }
        # Touch some files and folders for deletion
        empty_missing_session = self.tmpdir.joinpath(self.rel_ses_path).parent.joinpath('003')
        empty_missing_session.mkdir()
        missing_dataset = self.tmpdir.joinpath(self.rel_ses_path).joinpath('foo.bar.npy')
        missing_dataset.touch()
        ghost_dataset = ghost_session.joinpath('foo.bar.npy')
        ghost_dataset.touch()
        # Test dry
        to_remove = apt.remove_missing_datasets(
            self.tmpdir, tables=tables, dry=True, remove_empty_sessions=False
        )
        self.assertTrue(all(map(Path.exists, to_remove)), 'Removed files during dry run!')
        self.assertTrue(all(map(Path.is_file, to_remove)), 'Failed to ignore empty folders')
        self.assertNotIn(empty_missing_session, to_remove, 'Failed to ignore empty folders')
        self.assertNotIn(next(self.tmpdir.rglob('.invalid')), to_remove, 'Removed non-ALF file')
        # Test removal of files and folders
        removed = apt.remove_missing_datasets(
            self.tmpdir, tables=tables, dry=False, remove_empty_sessions=True
        )
        self.assertTrue(sum(map(Path.exists, to_remove)) == 0, 'Failed to remove all files')
        self.assertIn(empty_missing_session, removed, 'Failed to remove empty session folder')
        self.assertIn(missing_dataset, removed, 'Failed to remove missing dataset')
        self.assertIn(ghost_dataset, removed, 'Failed to remove missing dataset')
        self.assertNotIn(ghost_session, removed, 'Removed empty session that was in session table')
        # Check without tables input
        apt.make_parquet_db(self.tmpdir, hash_ids=False)
        removed = apt.remove_missing_datasets(self.tmpdir, dry=False)
        self.assertTrue(len(removed) == 0)

    def tearDown(self) -> None:
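        """Remove the temporary cache directory."""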
        shutil.rmtree(self.tmpdir)


class TestONETables(unittest.TestCase):
    """Tests for the cache table functions."""
[docs]
    def test_merge_cache_tables(self):
        """Test merge_cache_tables function."""
        fixture = Path(__file__).parents[1].joinpath('fixtures')
        caches = apt.load_tables(fixture)
        sessions_types = caches.sessions.reset_index().dtypes.to_dict()
        datasets_types = caches.datasets.reset_index().dtypes.to_dict()
        # Update with single record (pandas.Series), one exists, one doesn't
        session = caches.sessions.iloc[0].squeeze()
        session.name = uuid4()  # New record
        dataset = caches.datasets.iloc[0].squeeze()
        dataset['exists'] = not dataset['exists']
        apt.merge_tables(caches, sessions=session, datasets=dataset)
        self.assertTrue(session.name in caches.sessions.index)
        updated = caches.datasets.loc[dataset.name, 'exists']
        self.assertEqual(dataset['exists'], updated)
        # Check that the updated data frame has kept its original dtypes
        types = caches.sessions.reset_index().dtypes.to_dict()
        self.assertDictEqual(sessions_types, types)
        types = caches.datasets.reset_index().dtypes.to_dict()
        self.assertDictEqual(datasets_types, types)
        # Update a number of records
        datasets = caches.datasets.iloc[:3].copy()
        datasets.loc[:, 'exists'] = ~datasets.loc[:, 'exists']
        # Make one of the datasets a new record
        idx = datasets.index.values
        idx[-1] = (idx[-1][0], uuid4())
        datasets.index = pd.MultiIndex.from_tuples(idx, names=('eid', 'id'))
        apt.merge_tables(caches, datasets=datasets)
        self.assertTrue(idx[-1] in caches.datasets.index)
        verifiable = caches.datasets.loc[datasets.index.values, 'exists']
        self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
        # Check that the updated data frame has kept its original dtypes
        types = caches.datasets.reset_index().dtypes.to_dict()
        self.assertDictEqual(datasets_types, types)
        # Check behaviour when columns don't match
        datasets.loc[:, 'exists'] = ~datasets.loc[:, 'exists']
        datasets['extra_column'] = True
        caches.datasets['foo_bar'] = 12  # this column is missing in our new records
        caches.datasets['new_column'] = False
        expected_datasets_types = caches.datasets.reset_index().dtypes.to_dict()
        # The exception is the exists_* columns: the Alyx cache contains exists_aws and
        # exists_flatiron, which should simply be filled with the values of exists, as Alyx
        # won't return datasets that don't exist on FlatIron, and if a dataset doesn't exist
        # on AWS, loading falls back to FlatIron.
        caches.datasets['exists_aws'] = False
        with self.assertRaises(AssertionError):
            apt.merge_tables(caches, datasets=datasets, strict=True)
        apt.merge_tables(caches, datasets=datasets)
        verifiable = caches.datasets.loc[datasets.index.values, 'exists']
        self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
        apt.merge_tables(caches, datasets=datasets)
        verifiable = caches.datasets.loc[datasets.index.values, 'exists_aws']
        self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists']))
        # If the extra column does not start with 'exists' it should be set to NaN
        verifiable = caches.datasets.loc[datasets.index.values, 'foo_bar']
        self.assertTrue(np.isnan(verifiable).all())
        # Check that the missing columns were updated to nullable fields
        expected_datasets_types.update(
            foo_bar=pd.Int64Dtype(), exists_aws=pd.BooleanDtype(), new_column=pd.BooleanDtype())
        types = caches.datasets.reset_index().dtypes.to_dict()
        self.assertDictEqual(expected_datasets_types, types)
        # Check fringe cases
        with self.assertRaises(KeyError):
            apt.merge_tables(caches, unknown=datasets)
        self.assertIsNone(apt.merge_tables(caches, datasets=None))
        # Absent cache table
        caches = apt.load_tables('/foo')
        sessions_types = caches.sessions.reset_index().dtypes.to_dict()
        datasets_types = caches.datasets.reset_index().dtypes.to_dict()
        apt.merge_tables(caches, sessions=session, datasets=dataset)
        self.assertTrue(all(caches.sessions == pd.DataFrame([session])))
        self.assertEqual(1, len(caches.datasets))
        self.assertEqual(caches.datasets.squeeze().name, dataset.name)
        self.assertCountEqual(caches.datasets.squeeze().to_dict(), dataset.to_dict())
        types = caches.datasets.reset_index().dtypes.to_dict()
        self.assertDictEqual(datasets_types, types) 
 
if __name__ == '__main__':
    unittest.main(exit=False)