"""Decorators and small standalone functions for api module."""
import logging
import urllib.parse
from functools import wraps
from typing import Sequence, Union, Iterable, Optional, List
from collections.abc import Mapping
import fnmatch
from datetime import datetime
import pandas as pd
from iblutil.io import parquet
import numpy as np
from packaging import version
import one.alf.exceptions as alferr
from one.alf.files import rel_path_parts, get_session_path, get_alf_path, remove_uuid_string
from one.alf.spec import QC, FILE_SPEC, regex as alf_regex
logger = logging.getLogger(__name__)
QC_TYPE = pd.CategoricalDtype(categories=[e.name for e in sorted(QC)], ordered=True)
"""pandas.api.types.CategoricalDtype: The cache table QC column data type."""
[docs]
def Listable(t):
"""Return a typing.Union if the input and sequence of input."""
return Union[t, Sequence[t]]
[docs]
def ses2records(ses: dict):
"""Extract session cache record and datasets cache from a remote session data record.
Parameters
----------
ses : dict
Session dictionary from Alyx REST endpoint.
Returns
-------
pd.Series
Session record.
pd.DataFrame
Datasets frame.
"""
# Extract session record
eid = ses['url'][-36:]
session_keys = ('subject', 'start_time', 'lab', 'number', 'task_protocol', 'projects')
session_data = {k: v for k, v in ses.items() if k in session_keys}
session = (
pd.Series(data=session_data, name=eid).rename({'start_time': 'date'})
)
session['projects'] = ','.join(session.pop('projects'))
session['date'] = datetime.fromisoformat(session['date']).date()
# Extract datasets table
def _to_record(d):
rec = dict(file_size=d['file_size'], hash=d['hash'], exists=True, id=d['id'])
rec['eid'] = session.name
file_path = urllib.parse.urlsplit(d['data_url'], allow_fragments=False).path.strip('/')
file_path = get_alf_path(remove_uuid_string(file_path))
rec['session_path'] = get_session_path(file_path).as_posix()
rec['rel_path'] = file_path[len(rec['session_path']):].strip('/')
rec['default_revision'] = d['default_revision'] == 'True'
rec['qc'] = d.get('qc', 'NOT_SET')
return rec
if not ses.get('data_dataset_session_related'):
return session, pd.DataFrame()
records = map(_to_record, ses['data_dataset_session_related'])
index = ['eid', 'id']
datasets = pd.DataFrame(records).set_index(index).sort_index().astype({'qc': QC_TYPE})
return session, datasets
[docs]
def datasets2records(datasets, additional=None) -> pd.DataFrame:
"""Extract datasets DataFrame from one or more Alyx dataset records.
Parameters
----------
datasets : dict, list
One or more records from the Alyx 'datasets' endpoint.
additional : list of str
A set of optional fields to extract from dataset records.
Returns
-------
pd.DataFrame
Datasets frame.
Examples
--------
>>> datasets = ONE().alyx.rest('datasets', 'list', subject='foobar')
>>> df = datasets2records(datasets)
"""
records = []
for d in ensure_list(datasets):
file_record = next((x for x in d['file_records'] if x['data_url'] and x['exists']), None)
if not file_record:
continue # Ignore files that are not accessible
rec = dict(file_size=d['file_size'], hash=d['hash'], exists=True)
rec['id'] = d['url'][-36:]
rec['eid'] = (d['session'] or '')[-36:]
data_url = urllib.parse.urlsplit(file_record['data_url'], allow_fragments=False)
file_path = get_alf_path(data_url.path.strip('/'))
file_path = remove_uuid_string(file_path).as_posix()
rec['session_path'] = get_session_path(file_path) or ''
if rec['session_path']:
rec['session_path'] = rec['session_path'].as_posix()
rec['rel_path'] = file_path[len(rec['session_path']):].strip('/')
rec['default_revision'] = d['default_dataset']
rec['qc'] = d.get('qc')
for field in additional or []:
rec[field] = d.get(field)
records.append(rec)
index = ['eid', 'id']
if not records:
keys = (*index, 'file_size', 'hash', 'session_path', 'rel_path', 'default_revision', 'qc')
return pd.DataFrame(columns=keys).set_index(index)
return pd.DataFrame(records).set_index(index).sort_index().astype({'qc': QC_TYPE})
[docs]
def parse_id(method):
"""
Ensures the input experiment identifier is an experiment UUID string.
Parameters
----------
method : function
An ONE method whose second arg is an experiment ID.
Returns
-------
function
A wrapper function that parses the ID to the expected string.
Raises
------
ValueError
Unable to convert input to a valid experiment ID.
"""
@wraps(method)
def wrapper(self, id, *args, **kwargs):
eid = self.to_eid(id)
if eid is None:
raise ValueError(f'Cannot parse session ID "{id}" (session may not exist)')
return method(self, eid, *args, **kwargs)
return wrapper
[docs]
def refresh(method):
"""Refresh cache depending on query_type kwarg."""
@wraps(method)
def wrapper(self, *args, **kwargs):
mode = kwargs.get('query_type', None)
if not mode or mode == 'auto':
mode = self.mode
self.refresh_cache(mode=mode)
return method(self, *args, **kwargs)
return wrapper
[docs]
def validate_date_range(date_range) -> (pd.Timestamp, pd.Timestamp):
"""
Validates and arrange date range in a 2 elements list.
Parameters
----------
date_range : str, datetime.date, datetime.datetime, pd.Timestamp, np.datetime64, list, None
A single date or tuple/list of two dates. None represents no bound.
Returns
-------
tuple of pd.Timestamp
The start and end timestamps.
Examples
--------
>>> validate_date_range('2020-01-01') # On this day
>>> validate_date_range(datetime.date(2020, 1, 1))
>>> validate_date_range(np.array(['2022-01-30', '2022-01-30'], dtype='datetime64[D]'))
>>> validate_date_range(pd.Timestamp(2020, 1, 1))
>>> validate_date_range(np.datetime64(2021, 3, 11))
>>> validate_date_range(['2020-01-01']) # from date
>>> validate_date_range(['2020-01-01', None]) # from date
>>> validate_date_range([None, '2020-01-01']) # up to date
Raises
------
ValueError
Size of date range tuple must be 1 or 2.
"""
if date_range is None:
return
# Ensure we have exactly two values
if isinstance(date_range, str) or not isinstance(date_range, Iterable):
# date_range = (date_range, pd.Timestamp(date_range) + pd.Timedelta(days=1))
dt = pd.Timedelta(days=1) - pd.Timedelta(milliseconds=1)
date_range = (date_range, pd.Timestamp(date_range) + dt)
elif len(date_range) == 1:
date_range = [date_range[0], pd.Timestamp.max]
elif len(date_range) != 2:
raise ValueError
# For comparisons, ensure both values are pd.Timestamp (datetime, date and datetime64
# objects will be converted)
start, end = date_range
start = start or pd.Timestamp.min # Convert None to lowest possible date
end = end or pd.Timestamp.max # Convert None to highest possible date
# Convert to timestamp
if not isinstance(start, pd.Timestamp):
start = pd.Timestamp(start)
if not isinstance(end, pd.Timestamp):
end = pd.Timestamp(end)
return start, end
def _collection_spec(collection=None, revision=None) -> str:
"""
Return a template string for a collection/revision regular expression. Because both are
optional in the ALF spec, None will match any (including absent), while an empty string will
match absent.
Parameters
----------
collection : None, str
An optional collection regular expression.
revision : None, str
An optional revision regular expression.
Returns
-------
str
A string format for matching the collection/revision.
"""
spec = ''
for value, default in zip((collection, revision), ('{collection}/', '#{revision}#/')):
if not value:
default = f'({default})?' if value is None else ''
spec += default
return spec
def _file_spec(**kwargs):
"""
Return a template string for a ALF dataset regular expression. Because 'namespace',
'timescale', and 'extra' are optional None will match any (including absent). This function
removes the regex flags from the file spec string that make certain parts optional.
TODO an empty string should only match absent; this could be achieved by removing parts from
spec string
Parameters
----------
namespace : None, str
If namespace is not None, the namespace section of the returned file spec will not be
optional.
timescale : None, str
If timescale is not None, the namespace section of the returned file spec will not be
optional.
extra : None, str
If extra is not None, the namespace section of the returned file spec will not be
optional.
Returns
-------
str
A string format for matching an ALF dataset.
"""
OPTIONAL = {'namespace': '?', 'timescale': '?', 'extra': '*'}
filespec = FILE_SPEC
for k, v in kwargs.items():
if k in OPTIONAL and v is not None:
i = filespec.find(k) + len(k)
i += filespec[i:].find(OPTIONAL[k])
filespec = filespec[:i] + filespec[i:].replace(OPTIONAL[k], '', 1)
return filespec
[docs]
def filter_datasets(
all_datasets, filename=None, collection=None, revision=None, revision_last_before=True,
qc=QC.FAIL, ignore_qc_not_set=False, assert_unique=True, wildcards=False):
"""
Filter the datasets cache table by the relative path (dataset name, collection and revision).
When None is passed, all values will match. To match on empty parts, use an empty string.
When revision_last_before is true, None means return latest revision.
Parameters
----------
all_datasets : pandas.DataFrame
A datasets cache table.
filename : str, dict, None
A filename str or a dict of alf parts. Regular expressions permitted.
collection : str, None
A collection string. Regular expressions permitted.
revision : str, None
A revision string to match. If revision_last_before is true, regular expressions are
not permitted.
revision_last_before : bool
When true and no exact match exists, the (lexicographically) previous revision is used
instead. When false the revision string is matched like collection and filename,
with regular expressions permitted.
qc : str, int, one.alf.spec.QC
Returns datasets at or below this QC level. Integer values should correspond to the QC
enumeration NOT the qc category column codes in the pandas table.
ignore_qc_not_set : bool
When true, do not return datasets for which QC is NOT_SET.
assert_unique : bool
When true an error is raised if multiple collections or datasets are found.
wildcards : bool
If true, use unix shell style matching instead of regular expressions.
Returns
-------
pd.DataFrame
A slice of all_datasets that match the filters.
Examples
--------
Filter by dataset name and collection
>>> datasets = filter_datasets(all_datasets, '.*spikes.times.*', 'alf/probe00')
Filter datasets not in a collection
>>> datasets = filter_datasets(all_datasets, collection='')
Filter by matching revision
>>> datasets = filter_datasets(all_datasets, 'spikes.times.npy',
... revision='2020-01-12', revision_last_before=False)
Filter by filename parts
>>> datasets = filter_datasets(all_datasets, dict(object='spikes', attribute='times'))
Filter by QC outcome - datasets with WARNING or better
>>> datasets filter_datasets(all_datasets, qc='WARNING')
Filter by QC outcome and ignore datasets with unset QC - datasets with PASS only
>>> datasets filter_datasets(all_datasets, qc='PASS', ignore_qc_not_set=True)
Notes
-----
- It is not possible to match datasets that are in a given collection OR NOT in ANY collection.
e.g. filter_datasets(dsets, collection=['alf', '']) will not match the latter. For this you
must use two separate queries.
"""
# Create a regular expression string to match relative path against
filename = filename or {}
regex_args = {'collection': collection}
spec_str = _collection_spec(collection, None if revision_last_before else revision)
if isinstance(filename, dict):
spec_str += _file_spec(**filename)
regex_args.update(**filename)
else:
# Convert to regex is necessary and assert end of string
filename = [fnmatch.translate(x) if wildcards else x + '$' for x in ensure_list(filename)]
spec_str += '|'.join(filename)
# If matching revision name, add to regex string
if not revision_last_before:
regex_args.update(revision=revision)
for k, v in regex_args.items():
if v is None:
continue
if wildcards:
# Convert to regex, remove \\Z which asserts end of string
v = (fnmatch.translate(x).replace('\\Z', '') for x in ensure_list(v))
if not isinstance(v, str):
regex_args[k] = '|'.join(v) # logical OR
# Build regex string
pattern = alf_regex('^' + spec_str, **regex_args)
path_match = all_datasets['rel_path'].str.match(pattern)
# Test on QC outcome
qc = QC.validate(qc)
qc_match = all_datasets['qc'].le(qc.name)
if ignore_qc_not_set:
qc_match &= all_datasets['qc'].ne('NOT_SET')
# Filter datasets on path and QC
match = all_datasets[path_match & qc_match]
if len(match) == 0 or not (revision_last_before or assert_unique):
return match
revisions = [rel_path_parts(x)[1] or '' for x in match.rel_path.values]
if assert_unique:
collections = set(rel_path_parts(x)[0] or '' for x in match.rel_path.values)
if len(collections) > 1:
_list = '"' + '", "'.join(collections) + '"'
raise alferr.ALFMultipleCollectionsFound(_list)
if not revision_last_before:
if filename and len(match) > 1:
_list = '"' + '", "'.join(match['rel_path']) + '"'
raise alferr.ALFMultipleObjectsFound(_list)
if len(set(revisions)) > 1:
_list = '"' + '", "'.join(set(revisions)) + '"'
raise alferr.ALFMultipleRevisionsFound(_list)
else:
return match
elif filename and len(set(revisions)) != len(revisions):
_list = '"' + '", "'.join(match['rel_path']) + '"'
raise alferr.ALFMultipleObjectsFound(_list)
return filter_revision_last_before(match, revision, assert_unique=assert_unique)
[docs]
def filter_revision_last_before(datasets, revision=None, assert_unique=True):
"""
Filter datasets by revision, returning previous revision in ordered list if revision
doesn't exactly match.
Parameters
----------
datasets : pandas.DataFrame
A datasets cache table.
revision : str
A revision string to match (regular expressions not permitted).
assert_unique : bool
When true an alferr.ALFMultipleRevisionsFound exception is raised when multiple
default revisions are found; an alferr.ALFError when no default revision is found.
Returns
-------
pd.DataFrame
A datasets DataFrame with 0 or 1 row per unique dataset.
"""
def _last_before(df):
"""Takes a DataFrame with only one dataset and multiple revisions, returns matching row"""
if revision is None and 'default_revision' in df.columns:
if assert_unique and sum(df.default_revision) > 1:
revisions = df['revision'][df.default_revision.values]
rev_list = '"' + '", "'.join(revisions) + '"'
raise alferr.ALFMultipleRevisionsFound(rev_list)
if sum(df.default_revision) == 1:
return df[df.default_revision]
if len(df) == 1: # This may be the case when called from load_datasets
return df # It's not the default be there's only one available revision
# default_revision column all False; default isn't copied to remote repository
dset_name = df['rel_path'].iloc[0]
if assert_unique:
raise alferr.ALFError(f'No default revision for dataset {dset_name}')
else:
logger.warning(f'No default revision for dataset {dset_name}; using most recent')
# Compare revisions lexicographically
if assert_unique and len(df['revision'].unique()) > 1:
rev_list = '"' + '", "'.join(df['revision'].unique()) + '"'
raise alferr.ALFMultipleRevisionsFound(rev_list)
# Square brackets forces 1 row DataFrame returned instead of Series
idx = index_last_before(df['revision'].tolist(), revision)
# return df.iloc[slice(0, 0) if idx is None else [idx], :]
return df.iloc[slice(0, 0) if idx is None else [idx], :]
with pd.option_context('mode.chained_assignment', None): # FIXME Explicitly copy?
datasets['revision'] = [rel_path_parts(x)[1] or '' for x in datasets.rel_path]
groups = datasets.rel_path.str.replace('#.*#/', '', regex=True).values
grouped = datasets.groupby(groups, group_keys=False)
return grouped.apply(_last_before)
[docs]
def index_last_before(revisions: List[str], revision: Optional[str]) -> Optional[int]:
"""
Returns the index of string that occurs directly before the provided revision string when
lexicographic sorted. If revision is None, the index of the most recent revision is returned.
Parameters
----------
revisions : list of strings
A list of revision strings.
revision : None, str
The revision string to match on.
Returns
-------
int, None
Index of revision before matching string in sorted list or None.
Examples
--------
>>> idx = index_last_before([], '2020-08-01')
"""
if len(revisions) == 0:
return # No revisions, just return
revisions_sorted = sorted(revisions, reverse=True)
if revision is None: # Return most recent revision
return revisions.index(revisions_sorted[0])
lt = np.array(revisions_sorted) <= revision
return revisions.index(revisions_sorted[lt.argmax()]) if any(lt) else None
[docs]
def autocomplete(term, search_terms) -> str:
"""
Validate search term and return complete name, e.g. autocomplete('subj') == 'subject'.
"""
term = term.lower()
# Check if term already complete
if term in search_terms:
return term
full_key = (x for x in search_terms if x.lower().startswith(term))
key_ = next(full_key, None)
if not key_:
raise ValueError(f'Invalid search term "{term}", see `one.search_terms()`')
elif next(full_key, None):
raise ValueError(f'Ambiguous search term "{term}"')
return key_
[docs]
def ensure_list(value):
"""Ensure input is a list."""
return [value] if isinstance(value, (str, dict)) or not isinstance(value, Iterable) else value
[docs]
class LazyId(Mapping):
"""
Using a paginated response object or list of session records, extracts eid string when required
"""
def __init__(self, pg, func=None):
self._pg = pg
self.func = func or self.ses2eid
def __getitem__(self, item):
return self.func(self._pg.__getitem__(item))
def __len__(self):
return self._pg.__len__()
def __iter__(self):
return map(self.func, self._pg.__iter__())
[docs]
@staticmethod
def ses2eid(ses):
"""Given one or more session dictionaries, extract and return the session UUID.
Parameters
----------
ses : one.webclient._PaginatedResponse, dict, list
A collection of Alyx REST sessions endpoint records.
Returns
-------
str, list
One or more experiment ID strings.
"""
if isinstance(ses, list):
return [LazyId.ses2eid(x) for x in ses]
else:
return ses.get('id', None) or ses['url'].split('/').pop()
[docs]
def cache_int2str(table: pd.DataFrame) -> pd.DataFrame:
"""Convert int ids to str ids for cache table.
Parameters
----------
table : pd.DataFrame
A cache table (from One._cache).
"""
# Convert integer uuids to str uuids
if table.index.nlevels < 2 or not any(x.endswith('_0') for x in table.index.names):
return table
table = table.reset_index()
int_cols = table.filter(regex=r'_\d{1}$').columns.sort_values()
assert not len(int_cols) % 2, 'expected even number of columns ending in _0 or _1'
names = sorted(set(c.rsplit('_', 1)[0] for c in int_cols.values))
for i, name in zip(range(0, len(int_cols), 2), names):
table[name] = parquet.np2str(table[int_cols[i:i + 2]])
table = table.drop(int_cols, axis=1).set_index(names)
return table
[docs]
def patch_cache(table: pd.DataFrame, min_api_version=None, name=None) -> pd.DataFrame:
"""Reformat older cache tables to comply with this version of ONE.
Currently this function will 1. convert integer UUIDs to string UUIDs; 2. rename the 'project'
column to 'projects'.
Parameters
----------
table : pd.DataFrame
A cache table (from One._cache).
min_api_version : str
The minimum API version supported by this cache table.
name : {'dataset', 'session'} str
The name of the table.
"""
min_version = version.parse(min_api_version or '0.0.0')
table = cache_int2str(table)
# Rename project column
if min_version < version.Version('1.13.0') and 'project' in table.columns:
table.rename(columns={'project': 'projects'}, inplace=True)
if name == 'datasets' and min_version < version.Version('2.7.0') and 'qc' not in table.columns:
qc = pd.Categorical.from_codes(np.zeros(len(table.index), dtype=int), dtype=QC_TYPE)
table = table.assign(qc=qc)
return table