Source code for brainbox.io.one

"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
from dataclasses import dataclass, field
import gc
import logging
import re
import os
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt

from one.api import ONE, One
from one.alf.path import get_alf_path, full_path_parts
from one.alf.exceptions import ALFObjectNotFound, ALFMultipleCollectionsFound
from one.alf import cache
import one.alf.io as alfio
from neuropixel import TIP_SIZE_UM, trace_header
import spikeglx

import ibldsp.voltage
from ibldsp.waveform_extraction import WaveformsLoader
from iblutil.util import Bunch
from iblatlas.atlas import AllenAtlas, BrainRegions
from iblatlas import atlas
from ibllib.io.extractors.training_wheel import extract_wheel_moves, extract_first_movement_times
from ibllib.pipes import histology
from ibllib.pipes.ephys_alignment import EphysAlignment
from ibllib.plots import vertical_lines, Density

import brainbox.plot
from brainbox.io.spikeglx import Streamer
from brainbox.ephys_plots import plot_brain_regions
from brainbox.metrics.single_units import quick_unit_metrics
from brainbox.behavior.wheel import interpolate_position, velocity_filtered
from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter

_logger = logging.getLogger('ibllib')


SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
WAVEFORMS_ATTRIBUTES = ['templates']


[docs] def load_lfp(eid, one=None, dataset_types=None, **kwargs): """ TODO Verify works From an eid, hits the Alyx database and downloads the standard set of datasets needed for LFP :param eid: :param dataset_types: additional dataset types to add to the list :param open: if True, spikeglx readers are opened :return: spikeglx.Reader """ if dataset_types is None: dataset_types = [] dtypes = dataset_types + ['*ephysData.raw.lf*', '*ephysData.raw.meta*', '*ephysData.raw.ch*'] [one.load_dataset(eid, dset, download_only=True) for dset in dtypes] session_path = one.eid2path(eid) efiles = [ef for ef in spikeglx.glob_ephys_files(session_path, bin_exists=False) if ef.get('lf', None)] return [spikeglx.Reader(ef['lf'], **kwargs) for ef in efiles]
def _collection_filter_from_args(probe, spike_sorter=None): collection = f'alf/{probe}/{spike_sorter}' collection = collection.replace('None', '*') collection = collection.replace('/*', '*') collection = collection[:-1] if collection.endswith('/') else collection return collection def _get_spike_sorting_collection(collections, pname): """ Filters a list or array of collections to get the relevant spike sorting dataset if there is a pykilosort, load it """ # collection = next(filter(lambda c: c == f'alf/{pname}/pykilosort', collections), None) # otherwise, prefers the shortest collection = collection or next(iter(sorted(filter(lambda c: f'alf/{pname}' in c, collections), key=len)), None) _logger.debug(f"selecting: {collection} to load amongst candidates: {collections}") return collection def _channels_alyx2bunch(chans): channels = Bunch({ 'atlas_id': np.array([ch['brain_region'] for ch in chans]), 'x': np.array([ch['x'] for ch in chans]) / 1e6, 'y': np.array([ch['y'] for ch in chans]) / 1e6, 'z': np.array([ch['z'] for ch in chans]) / 1e6, 'axial_um': np.array([ch['axial'] for ch in chans]), 'lateral_um': np.array([ch['lateral'] for ch in chans]) }) return channels def _channels_traj2bunch(xyz_chans, brain_atlas): brain_regions = brain_atlas.regions.get(brain_atlas.get_labels(xyz_chans)) channels = { 'x': xyz_chans[:, 0], 'y': xyz_chans[:, 1], 'z': xyz_chans[:, 2], 'acronym': brain_regions['acronym'], 'atlas_id': brain_regions['id'] } return channels def _channels_bunch2alf(channels): channels_ = { 'mlapdv': np.c_[channels['x'], channels['y'], channels['z']] * 1e6, 'brainLocationIds_ccf_2017': channels['atlas_id'], 'localCoordinates': np.c_[channels['lateral_um'], channels['axial_um']]} return channels_ def _channels_alf2bunch(channels, brain_regions=None): # reformat the dictionary according to the standard that comes out of Alyx channels_ = { 'x': channels['mlapdv'][:, 0].astype(np.float64) / 1e6, 'y': channels['mlapdv'][:, 1].astype(np.float64) / 1e6, 'z': channels['mlapdv'][:, 2].astype(np.float64) / 1e6, 'acronym': None, 'atlas_id': channels['brainLocationIds_ccf_2017'], 'axial_um': channels['localCoordinates'][:, 1], 'lateral_um': channels['localCoordinates'][:, 0], } # here if we have some extra keys, they will carry over to the next dictionary for k in channels: if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']: channels_[k] = channels[k] if brain_regions: channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym'] return channels_ def _load_spike_sorting(eid, one=None, collection=None, revision=None, return_channels=True, dataset_types=None, brain_regions=None): """ Generic function to load spike sorting according data using ONE. Will try to load one spike sorting for any probe present for the eid matching the collection For each probe it will load a spike sorting: - if there is one version: loads this one - if there are several versions: loads pykilosort, if not found the shortest collection (alf/probeXX) Parameters ---------- eid : [str, UUID, Path, dict] Experiment session identifier; may be a UUID, URL, experiment reference string details dict or Path one : one.api.OneAlyx An instance of ONE (may be in 'local' mode) collection : str collection filter word - accepts wildcards - can be a combination of spike sorter and probe. See `ALF documentation`_ for details. revision : str A particular revision return (defaults to latest revision). See `ALF documentation`_ for details. return_channels : bool Defaults to False otherwise loads channels from disk .. _ALF documentation: https://one.internationalbrainlab.org/alf_intro.html#optional-components Returns ------- spikes : dict of one.alf.io.AlfBunch A dict with probe labels as keys, contains bunch(es) of spike data for the provided session and spike sorter, with keys ('clusters', 'times') clusters : dict of one.alf.io.AlfBunch A dict with probe labels as keys, contains bunch(es) of cluster data, with keys ('channels', 'depths', 'metrics') channels : dict of one.alf.io.AlfBunch A dict with probe labels as keys, contains channel locations with keys ('acronym', 'atlas_id', 'x', 'y', 'z'). Only returned when return_channels is True. Atlas IDs non-lateralized. """ one = one or ONE() # enumerate probes and load according to the name collections = one.list_collections(eid, filename='spikes*', collection=collection, revision=revision) if len(collections) == 0: _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") pnames = list(set(c.split('/')[1] for c in collections)) spikes, clusters, channels = ({} for _ in range(3)) spike_attributes, cluster_attributes = _get_attributes(dataset_types) for pname in pnames: probe_collection = _get_spike_sorting_collection(collections, pname) spikes[pname] = one.load_object(eid, collection=probe_collection, obj='spikes', attribute=spike_attributes) clusters[pname] = one.load_object(eid, collection=probe_collection, obj='clusters', attribute=cluster_attributes) if return_channels: channels = _load_channels_locations_from_disk( eid, collection=collection, one=one, revision=revision, brain_regions=brain_regions) return spikes, clusters, channels else: return spikes, clusters def _get_attributes(dataset_types): if dataset_types is None: return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES else: spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) return spike_attributes, cluster_attributes def _load_channels_locations_from_disk(eid, collection=None, one=None, revision=None, brain_regions=None): _logger.debug('loading spike sorting from disk') channels = Bunch({}) collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision) if len(collections) == 0: _logger.warning(f"eid {eid}: no collection found with collection filter: {collection}, revision: {revision}") probes = list(set([c.split('/')[1] for c in collections])) for probe in probes: probe_collection = _get_spike_sorting_collection(collections, probe) channels[probe] = one.load_object(eid, collection=probe_collection, obj='channels') # if the spike sorter has not aligned data, try and get the alignment available if 'brainLocationIds_ccf_2017' not in channels[probe].keys(): aligned_channel_collections = one.list_collections( eid, filename='channels.brainLocationIds_ccf_2017*', collection=probe_collection, revision=revision) if len(aligned_channel_collections) == 0: _logger.debug(f"no resolved alignment dataset found for {eid}/{probe}") continue _logger.debug(f"looking for a resolved alignment dataset in {aligned_channel_collections}") ac_collection = _get_spike_sorting_collection(aligned_channel_collections, probe) channels_aligned = one.load_object(eid, 'channels', collection=ac_collection) channels[probe] = channel_locations_interpolation(channels_aligned, channels[probe]) # only have to reformat channels if we were able to load coordinates from disk channels[probe] = _channels_alf2bunch(channels[probe], brain_regions=brain_regions) return channels
[docs] def channel_locations_interpolation(channels_aligned, channels=None, brain_regions=None): """ oftentimes the channel map for different spike sorters may be different so interpolate the alignment onto if there is no spike sorting in the base folder, the alignment doesn't have the localCoordinates field so we reconstruct from the Neuropixel map. This only happens for early pykilosort sorts :param channels_aligned: Bunch or dictionary of aligned channels containing at least keys 'localCoordinates', 'mlapdv' and 'brainLocationIds_ccf_2017' OR 'x', 'y', 'z', 'acronym', 'axial_um' those are the guide for the interpolation :param channels: Bunch or dictionary of aligned channels containing at least keys 'localCoordinates' :param brain_regions: None (default) or iblatlas.regions.BrainRegions object if None will return a dict with keys 'localCoordinates', 'mlapdv', 'brainLocationIds_ccf_2017 if a brain region object is provided, outputts a dict with keys 'x', 'y', 'z', 'acronym', 'atlas_id', 'axial_um', 'lateral_um' :return: Bunch or dictionary of channels with brain coordinates keys """ NEUROPIXEL_VERSION = 1 h = trace_header(version=NEUROPIXEL_VERSION) if channels is None: channels = {'localCoordinates': np.c_[h['x'], h['y']]} nch = channels['localCoordinates'].shape[0] if {'x', 'y', 'z'}.issubset(set(channels_aligned.keys())): channels_aligned = _channels_bunch2alf(channels_aligned) if 'localCoordinates' in channels_aligned.keys(): aligned_depths = channels_aligned['localCoordinates'][:, 1] else: # this is a edge case for a few spike sorting sessions assert channels_aligned['mlapdv'].shape[0] == 384 aligned_depths = h['y'] depth_aligned, ind_aligned = np.unique(aligned_depths, return_index=True) depths, ind, iinv = np.unique(channels['localCoordinates'][:, 1], return_index=True, return_inverse=True) channels['mlapdv'] = np.zeros((nch, 3)) for i in np.arange(3): channels['mlapdv'][:, i] = np.interp( depths, depth_aligned, channels_aligned['mlapdv'][ind_aligned, i])[iinv] # the brain locations have to be interpolated by nearest neighbour fcn_interp = interp1d(depth_aligned, channels_aligned['brainLocationIds_ccf_2017'][ind_aligned], kind='nearest') channels['brainLocationIds_ccf_2017'] = fcn_interp(depths)[iinv].astype(np.int32) if brain_regions is not None: return _channels_alf2bunch(channels, brain_regions=brain_regions) else: return channels
def _load_channel_locations_traj(eid, probe=None, one=None, revision=None, aligned=False, brain_atlas=None, return_source=False): if not hasattr(one, 'alyx'): return {}, None _logger.debug(f"trying to load from traj {probe}") channels = Bunch() brain_atlas = brain_atlas or AllenAtlas # need to find the collection bruh insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe)[0] collection = _collection_filter_from_args(probe=probe) collections = one.list_collections(eid, filename='channels*', collection=collection, revision=revision) probe_collection = _get_spike_sorting_collection(collections, probe) chn_coords = one.load_dataset(eid, 'channels.localCoordinates', collection=probe_collection) depths = chn_coords[:, 1] tracing = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ get('tracing_exists', False) resolved = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ get('alignment_resolved', False) counts = insertion.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \ get('alignment_count', 0) if tracing: xyz = np.array(insertion['json']['xyz_picks']) / 1e6 if resolved: _logger.debug(f'Channel locations for {eid}/{probe} have been resolved. ' f'Channel and cluster locations obtained from ephys aligned histology ' f'track.') traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, provenance='Ephys aligned histology track')[0] align_key = insertion['json']['extended_qc']['alignment_stored'] feature = traj['json'][align_key][0] track = traj['json'][align_key][1] ephysalign = EphysAlignment(xyz, depths, track_prev=track, feature_prev=feature, brain_atlas=brain_atlas, speedy=True) chans = ephysalign.get_channel_locations(feature, track) channels[probe] = _channels_traj2bunch(chans, brain_atlas) source = 'resolved' elif counts > 0 and aligned: _logger.debug(f'Channel locations for {eid}/{probe} have not been ' f'resolved. However, alignment flag set to True so channel and cluster' f' locations will be obtained from latest available ephys aligned ' f'histology track.') # get the latest user aligned channels traj = one.alyx.rest('trajectories', 'list', session=eid, probe=probe, provenance='Ephys aligned histology track')[0] align_key = insertion['json']['extended_qc']['alignment_stored'] feature = traj['json'][align_key][0] track = traj['json'][align_key][1] ephysalign = EphysAlignment(xyz, depths, track_prev=track, feature_prev=feature, brain_atlas=brain_atlas, speedy=True) chans = ephysalign.get_channel_locations(feature, track) channels[probe] = _channels_traj2bunch(chans, brain_atlas) source = 'aligned' else: _logger.debug(f'Channel locations for {eid}/{probe} have not been resolved. ' f'Channel and cluster locations obtained from histology track.') # get the channels from histology tracing xyz = xyz[np.argsort(xyz[:, 2]), :] chans = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6) channels[probe] = _channels_traj2bunch(chans, brain_atlas) source = 'traced' channels[probe]['axial_um'] = chn_coords[:, 1] channels[probe]['lateral_um'] = chn_coords[:, 0] else: _logger.warning(f'Histology tracing for {probe} does not exist. No channels for {probe}') source = '' channels = None if return_source: return channels, source else: return channels
[docs] def load_channel_locations(eid, probe=None, one=None, aligned=False, brain_atlas=None): """ Load the brain locations of each channel for a given session/probe Parameters ---------- eid : [str, UUID, Path, dict] Experiment session identifier; may be a UUID, URL, experiment reference string details dict or Path probe : [str, list of str] The probe label(s), e.g. 'probe01' one : one.api.OneAlyx An instance of ONE (shouldn't be in 'local' mode) aligned : bool Whether to get the latest user aligned channel when not resolved or use histology track brain_atlas : iblatlas.BrainAtlas Brain atlas object (default: Allen atlas) Returns ------- dict of one.alf.io.AlfBunch A dict with probe labels as keys, contains channel locations with keys ('acronym', 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized. optional: string 'resolved', 'aligned', 'traced' or '' """ one = one or ONE() brain_atlas = brain_atlas or AllenAtlas() if isinstance(eid, dict): ses = eid eid = ses['url'][-36:] else: eid = one.to_eid(eid) collection = _collection_filter_from_args(probe=probe) channels = _load_channels_locations_from_disk(eid, one=one, collection=collection, brain_regions=brain_atlas.regions) incomplete_probes = [k for k in channels if 'x' not in channels[k]] for iprobe in incomplete_probes: channels_, source = _load_channel_locations_traj(eid, probe=iprobe, one=one, aligned=aligned, brain_atlas=brain_atlas, return_source=True) if channels_ is not None: channels[iprobe] = channels_[iprobe] return channels
[docs] def load_spike_sorting_fast(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None, brain_regions=None, nested=True, collection=None, return_collection=False): """ From an eid, loads spikes and clusters for all probes The following set of dataset types are loaded: 'clusters.channels', 'clusters.depths', 'clusters.metrics', 'spikes.clusters', 'spikes.times', 'probes.description' :param eid: experiment UUID or pathlib.Path of the local session :param one: an instance of OneAlyx :param probe: name of probe to load in, if not given all probes for session will be loaded :param dataset_types: additional spikes/clusters objects to add to the standard default list :param spike_sorter: name of the spike sorting you want to load (None for default) :param collection: name of the spike sorting collection to load - exclusive with spike sorter name ex: "alf/probe00" :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided :param nested: if a single probe is required, do not output a dictionary with the probe name as key :param return_collection: (False) if True, will return the collection used to load :return: spikes, clusters, channels (dict of bunch, 1 bunch per probe) """ _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting_fast will be removed in future versions.' 'Use brainbox.io.one.SpikeSortingLoader instead') if collection is None: collection = _collection_filter_from_args(probe, spike_sorter) _logger.debug(f"load spike sorting with collection filter {collection}") kwargs = dict(eid=eid, one=one, collection=collection, revision=revision, dataset_types=dataset_types, brain_regions=brain_regions) spikes, clusters, channels = _load_spike_sorting(**kwargs, return_channels=True) clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None) if nested is False and len(spikes.keys()) == 1: k = list(spikes.keys())[0] channels = channels[k] clusters = clusters[k] spikes = spikes[k] if return_collection: return spikes, clusters, channels, collection else: return spikes, clusters, channels
[docs] def load_spike_sorting(eid, one=None, probe=None, dataset_types=None, spike_sorter=None, revision=None, brain_regions=None, return_collection=False): """ From an eid, loads spikes and clusters for all probes The following set of dataset types are loaded: 'clusters.channels', 'clusters.depths', 'clusters.metrics', 'spikes.clusters', 'spikes.times', 'probes.description' :param eid: experiment UUID or pathlib.Path of the local session :param one: an instance of OneAlyx :param probe: name of probe to load in, if not given all probes for session will be loaded :param dataset_types: additional spikes/clusters objects to add to the standard default list :param spike_sorter: name of the spike sorting you want to load (None for default) :param brain_regions: iblatlas.regions.BrainRegions object - will label acronyms if provided :param return_collection:(bool - False) if True, returns the collection for loading the data :return: spikes, clusters (dict of bunch, 1 bunch per probe) """ _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.' 'Use brainbox.io.one.SpikeSortingLoader instead') collection = _collection_filter_from_args(probe, spike_sorter) _logger.debug(f"load spike sorting with collection filter {collection}") spikes, clusters = _load_spike_sorting(eid=eid, one=one, collection=collection, revision=revision, return_channels=False, dataset_types=dataset_types, brain_regions=brain_regions) if return_collection: return spikes, clusters, collection else: return spikes, clusters
[docs] def load_spike_sorting_with_channel(eid, one=None, probe=None, aligned=False, dataset_types=None, spike_sorter=None, brain_atlas=None, nested=True, return_collection=False): """ For a given eid, get spikes, clusters and channels information, and merges clusters and channels information before returning all three variables. Parameters ---------- eid : [str, UUID, Path, dict] Experiment session identifier; may be a UUID, URL, experiment reference string details dict or Path one : one.api.OneAlyx An instance of ONE (shouldn't be in 'local' mode) probe : [str, list of str] The probe label(s), e.g. 'probe01' aligned : bool Whether to get the latest user aligned channel when not resolved or use histology track dataset_types : list of str Optional additional spikes/clusters objects to add to the standard default list spike_sorter : str Name of the spike sorting you want to load (None for default which is pykilosort if it's available otherwise the default MATLAB kilosort) brain_atlas : iblatlas.atlas.BrainAtlas Brain atlas object (default: Allen atlas) return_collection: bool Returns an extra argument with the collection chosen Returns ------- spikes : dict of one.alf.io.AlfBunch A dict with probe labels as keys, contains bunch(es) of spike data for the provided session and spike sorter, with keys ('clusters', 'times') clusters : dict of one.alf.io.AlfBunch A dict with probe labels as keys, contains bunch(es) of cluster data, with keys ('channels', 'depths', 'metrics') channels : dict of one.alf.io.AlfBunch A dict with probe labels as keys, contains channel locations with keys ('acronym', 'atlas_id', 'x', 'y', 'z'). Atlas IDs non-lateralized. """ # --- Get spikes and clusters data _logger.warning('Deprecation warning: brainbox.io.one.load_spike_sorting will be removed in future versions.' 'Use brainbox.io.one.SpikeSortingLoader instead') one = one or ONE() brain_atlas = brain_atlas or AllenAtlas() spikes, clusters, collection = load_spike_sorting( eid, one=one, probe=probe, dataset_types=dataset_types, spike_sorter=spike_sorter, return_collection=True) # -- Get brain regions and assign to clusters channels = load_channel_locations(eid, one=one, probe=probe, aligned=aligned, brain_atlas=brain_atlas) clusters = merge_clusters_channels(clusters, channels, keys_to_add_extra=None) if nested is False and len(spikes.keys()) == 1: k = list(spikes.keys())[0] channels = channels[k] clusters = clusters[k] spikes = spikes[k] if return_collection: return spikes, clusters, channels, collection else: return spikes, clusters, channels
[docs] def load_ephys_session(eid, one=None): """ From an eid, hits the Alyx database and downloads a standard default set of dataset types From a local session Path (pathlib.Path), loads a standard default set of dataset types to perform analysis: 'clusters.channels', 'clusters.depths', 'clusters.metrics', 'spikes.clusters', 'spikes.times', 'probes.description' Parameters ---------- eid : [str, UUID, Path, dict] Experiment session identifier; may be a UUID, URL, experiment reference string details dict or Path one : oneibl.one.OneAlyx, optional ONE object to use for loading. Will generate internal one if not used, by default None Returns ------- spikes : dict of one.alf.io.AlfBunch A dict with probe labels as keys, contains bunch(es) of spike data for the provided session and spike sorter, with keys ('clusters', 'times') clusters : dict of one.alf.io.AlfBunch A dict with probe labels as keys, contains bunch(es) of cluster data, with keys ('channels', 'depths', 'metrics') trials : one.alf.io.AlfBunch of numpy.ndarray The session trials data """ assert one spikes, clusters = load_spike_sorting(eid, one=one) trials = one.load_object(eid, 'trials') return spikes, clusters, trials
def _remove_old_clusters(session_path, probe): # gets clusters and spikes from a local session folder probe_path = session_path.joinpath('alf', probe) # look for clusters.metrics.csv file, if it exists delete as we now have .pqt file instead cluster_file = probe_path.joinpath('clusters.metrics.csv') if cluster_file.exists(): os.remove(cluster_file) _logger.info('Deleting old clusters.metrics.csv file')
[docs] def merge_clusters_channels(dic_clus, channels, keys_to_add_extra=None): """ Takes (default and any extra) values in given keys from channels and assign them to clusters. If channels does not contain any data, the new keys are added to clusters but left empty. Parameters ---------- dic_clus : dict of one.alf.io.AlfBunch 1 bunch per probe, containing cluster information channels : dict of one.alf.io.AlfBunch 1 bunch per probe, containing channels bunch with keys ('acronym', 'atlas_id', 'x', 'y', z', 'localCoordinates') keys_to_add_extra : list of str Any extra keys to load into channels bunches Returns ------- dict of one.alf.io.AlfBunch clusters (1 bunch per probe) with new keys values. """ probe_labels = list(channels.keys()) # Convert dict_keys into list keys_to_add_default = ['acronym', 'atlas_id', 'x', 'y', 'z', 'axial_um', 'lateral_um'] if keys_to_add_extra is None: keys_to_add = keys_to_add_default else: # Append extra optional keys keys_to_add = list(set(keys_to_add_extra + keys_to_add_default)) for label in probe_labels: clu_ch = dic_clus[label]['channels'] for key in keys_to_add: try: assert key in channels[label].keys() # Check key is in channels ch_key = channels[label][key] nch_key = len(ch_key) if ch_key is not None else 0 if max(clu_ch) < nch_key: # Check length as will use clu_ch as index dic_clus[label][key] = ch_key[clu_ch] else: _logger.warning( f'Probe {label}: merging channels and clusters for key "{key}" has {nch_key} on channels' f' but expected {max(clu_ch)}. Data in new cluster key "{key}" is returned empty.') dic_clus[label][key] = [] except AssertionError: _logger.warning(f'Either clusters or channels does not have key {key}, could not merge') continue return dic_clus
[docs] def load_passive_rfmap(eid, one=None): """ For a given eid load in the passive receptive field mapping protocol data Parameters ---------- eid : [str, UUID, Path, dict] Experiment session identifier; may be a UUID, URL, experiment reference string details dict or Path one : oneibl.one.OneAlyx, optional An instance of ONE (may be in 'local' - offline - mode) Returns ------- one.alf.io.AlfBunch Passive receptive field mapping data """ one = one or ONE() # Load in the receptive field mapping data rf_map = one.load_object(eid, obj='passiveRFM', collection='alf') frames = np.fromfile(one.load_dataset(eid, '_iblrig_RFMapStim.raw.bin', collection='raw_passive_data'), dtype="uint8") y_pix, x_pix = 15, 15 frames = np.transpose(np.reshape(frames, [y_pix, x_pix, -1], order="F"), [2, 1, 0]) rf_map['frames'] = frames return rf_map
[docs] def load_wheel_reaction_times(eid, one=None): """ Return the calculated reaction times for session. Reaction times are defined as the time between the go cue (onset tone) and the onset of the first substantial wheel movement. A movement is considered sufficiently large if its peak amplitude is at least 1/3rd of the distance to threshold (~0.1 radians). Negative times mean the onset of the movement occurred before the go cue. Nans may occur if there was no detected movement withing the period, or when the goCue_times or feedback_times are nan. Parameters ---------- eid : [str, UUID, Path, dict] Experiment session identifier; may be a UUID, URL, experiment reference string details dict or Path one : one.api.OneAlyx, optional one object to use for loading. Will generate internal one if not used, by default None Returns ---------- array-like reaction times """ if one is None: one = ONE() trials = one.load_object(eid, 'trials') # If already extracted, load and return if trials and 'firstMovement_times' in trials: return trials['firstMovement_times'] - trials['goCue_times'] # Otherwise load wheelMoves object and calculate moves = one.load_object(eid, 'wheelMoves') # Re-extract wheel moves if necessary if not moves or 'peakAmplitude' not in moves: wheel = one.load_object(eid, 'wheel') moves = extract_wheel_moves(wheel['timestamps'], wheel['position']) assert trials and moves, 'unable to load trials and wheelMoves data' firstMove_times, is_final_movement, ids = extract_first_movement_times(moves, trials) return firstMove_times - trials['goCue_times']
[docs] def load_iti(trials): """ The inter-trial interval (ITI) time for each trial, defined as the period of open-loop grey screen commencing at stimulus off and lasting until the quiescent period at the start of the following trial. Note that the ITI for the first trial is the time between the first trial and the next, therefore the last value is NaN. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'intervals', 'stimOff_times'}. Returns ------- np.array An array of inter-trial intervals, the last value being NaN. """ if not {'intervals', 'stimOff_times'} <= set(trials.keys()): raise ValueError('trials must contain keys {"intervals", "stimOff_times"}') return np.r_[(np.roll(trials['intervals'][:, 0], -1) - trials['stimOff_times'])[:-1], np.nan]
[docs] def load_channels_from_insertion(ins, depths=None, one=None, ba=None): PROV_2_VAL = { 'Resolved': 90, 'Ephys aligned histology track': 70, 'Histology track': 50, 'Micro-manipulator': 30, 'Planned': 10} one = one or ONE() ba = ba or atlas.AllenAtlas() traj = one.alyx.rest('trajectories', 'list', probe_insertion=ins['id']) val = [PROV_2_VAL[tr['provenance']] for tr in traj] idx = np.argmax(val) traj = traj[idx] if depths is None: depths = trace_header(version=1)[:, 1] if traj['provenance'] == 'Planned' or traj['provenance'] == 'Micro-manipulator': ins = atlas.Insertion.from_dict(traj) # Deepest coordinate first xyz = np.c_[ins.tip, ins.entry].T xyz_channels = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6) else: xyz = np.array(ins['json']['xyz_picks']) / 1e6 if traj['provenance'] == 'Histology track': xyz = xyz[np.argsort(xyz[:, 2]), :] xyz_channels = histology.interpolate_along_track(xyz, (depths + TIP_SIZE_UM) / 1e6) else: align_key = ins['json']['extended_qc']['alignment_stored'] feature = traj['json'][align_key][0] track = traj['json'][align_key][1] ephysalign = EphysAlignment(xyz, depths, track_prev=track, feature_prev=feature, brain_atlas=ba, speedy=True) xyz_channels = ephysalign.get_channel_locations(feature, track) return xyz_channels
[docs] @dataclass class SpikeSortingLoader: """ Object that will load spike sorting data for a given probe insertion. This class can be instantiated in several manners - With Alyx database probe id: SpikeSortingLoader(pid=pid, one=one) - With Alyx database eic and probe name: SpikeSortingLoader(eid=eid, pname='probe00', one=one) - From a local session and probe name: SpikeSortingLoader(session_path=session_path, pname='probe00') NB: When no ONE instance is passed, any datasets that are loaded will not be recorded. """ one: One = None atlas: None = None pid: str = None eid: str = '' pname: str = '' session_path: Path = '' # the following properties are the outcome of the post init function collections: list = None datasets: list = None # list of all datasets belonging to the session # the following properties are the outcome of a reading function files: dict = None raw_data_files: list = None # list of raw ap and lf files corresponding to the recording collection: str = '' histology: str = '' # 'alf', 'resolved', 'aligned' or 'traced' spike_sorter: str = 'pykilosort' spike_sorting_path: Path = None _sync: dict = None def __post_init__(self): # pid gets precedence if self.pid is not None: try: self.eid, self.pname = self.one.pid2eid(self.pid) except NotImplementedError: if self.eid == '' or self.pname == '': raise IOError("Cannot infer session id and probe name from pid. " "You need to pass eid and pname explicitly when instantiating SpikeSortingLoader.") self.session_path = self.one.eid2path(self.eid) # then eid / pname combination elif self.session_path is None or self.session_path == '': self.session_path = self.one.eid2path(self.eid) # fully local providing a session path else: if self.one: self.eid = self.one.to_eid(self.session_path) else: self.one = One(cache_dir=self.session_path.parents[2], mode='local') df_sessions = cache._make_sessions_df(self.session_path) self.one._cache['sessions'] = df_sessions.set_index('id') self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False) self.eid = str(self.session_path.relative_to(self.session_path.parents[2])) # populates default properties self.collections = self.one.list_collections( self.eid, filename='spikes*', collection=f"alf/{self.pname}*") self.datasets = self.one.list_datasets(self.eid) if self.atlas is None: self.atlas = AllenAtlas() self.files = {} self.raw_data_files = [] def _load_object(self, *args, **kwargs): """ This function is a wrapper around alfio.load_object that will remove the UUID in the filename if the object is on SDSC. """ remove_uuids = getattr(self.one, 'uuid_filenames', False) d = alfio.load_object(*args, **kwargs) if remove_uuids: # pops the UUID in the key names keys = list(d.keys()) for k in keys: d[k[:-37]] = d.pop(k) return d @staticmethod def _get_attributes(dataset_types): """returns attributes to load for spikes and clusters objects""" dataset_types = [] if dataset_types is None else dataset_types spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl] waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes)) return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes} def _get_spike_sorting_collection(self, spike_sorter=None): """ Filters a list or array of collections to get the relevant spike sorting dataset if there is a pykilosort, load it """ for sorter in list([spike_sorter, 'iblsorter', 'pykilosort']): if sorter is None: continue if sorter == "": collection = next(filter(lambda c: c == f'alf/{self.pname}', self.collections), None) else: collection = next(filter(lambda c: c == f'alf/{self.pname}/{sorter}', self.collections), None) if collection is not None: return collection # if none is found amongst the defaults, prefers the shortest collection = collection or next(iter(sorted(filter(lambda c: f'alf/{self.pname}' in c, self.collections), key=len)), None) _logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}") return collection
[docs] def load_spike_sorting_object(self, obj, *args, **kwargs): """ Loads an ALF object :param obj: object name, str between 'spikes', 'clusters' or 'channels' :param spike_sorter: (defaults to 'pykilosort') :param dataset_types: list of extra dataset types, for example ['spikes.samples'] :param collection: string specifiying the collection, for example 'alf/probe01/pykilosort' :param kwargs: additional arguments to be passed to one.api.One.load_object :param missing: 'raise' (default) or 'ignore' :return: """ self.download_spike_sorting_object(obj, *args, **kwargs) return self._load_object(self.files[obj])
[docs] def get_version(self, spike_sorter='pykilosort'): collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) dset = self.one.alyx.rest('datasets', 'list', session=self.eid, collection=collection, name='spikes.times.npy') return dset[0]['version'] if len(dset) else 'unknown'
[docs] def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None, attribute=None, missing='raise', **kwargs): """ Downloads an ALF object :param obj: object name, str between 'spikes', 'clusters' or 'channels' :param spike_sorter: (defaults to 'pykilosort') :param dataset_types: list of extra dataset types, for example ['spikes.samples'] :param collection: string specifiying the collection, for example 'alf/probe01/pykilosort' :param kwargs: additional arguments to be passed to one.api.One.load_object :param attribute: list of attributes to load for the object :param missing: 'raise' (default) or 'ignore' :return: """ if len(self.collections) == 0: return {}, {}, {} self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) collection = collection or self.collection _logger.debug(f"loading spike sorting object {obj} from {collection}") attributes = self._get_attributes(dataset_types) try: self.files[obj] = self.one.load_object( self.eid, obj=obj, attribute=attributes.get(obj, None), collection=collection, download_only=True, **kwargs) except ALFObjectNotFound as e: if missing == 'raise': raise e
[docs] def download_spike_sorting(self, objects=None, **kwargs): """ Downloads spikes, clusters and channels :param spike_sorter: (defaults to 'pykilosort') :param dataset_types: list of extra dataset types :param objects: list of objects to download, defaults to ['spikes', 'clusters', 'channels'] :return: """ objects = ['spikes', 'clusters', 'channels'] if objects is None else objects for obj in objects: self.download_spike_sorting_object(obj=obj, **kwargs) self.spike_sorting_path = self.files['clusters'][0].parent
[docs] def download_raw_electrophysiology(self, band='ap'): """ Downloads raw electrophysiology data files on local disk. :param band: "ap" (default) or "lf" for LFP band :return: list of raw data files full paths (ch, meta and cbin files) """ raw_data_files = [] for suffix in [f'*.{band}.ch', f'*.{band}.meta', f'*.{band}.cbin']: try: # FIXME: this will fail if multiple LFP segments are found raw_data_files.append(self.one.load_dataset( self.eid, download_only=True, collection=f'raw_ephys_data/{self.pname}', dataset=suffix, check_hash=False, )) except ALFObjectNotFound: _logger.debug(f"{self.session_path} can't locate raw data collection raw_ephys_data/{self.pname}, file {suffix}") self.raw_data_files = list(set(self.raw_data_files + raw_data_files)) return raw_data_files
[docs] def raw_electrophysiology(self, stream=True, band='ap', **kwargs): """ Returns a reader for the raw electrophysiology data By default it is a streamer object, but if stream is False, it will return a spikeglx.Reader after having downloaded the raw data file if necessary :param stream: :param band: :param kwargs: :return: """ if stream: return Streamer(pid=self.pid, one=self.one, typ=band, **kwargs) else: raw_data_files = self.download_raw_electrophysiology(band=band) cbin_file = next(filter(lambda f: re.match(rf".*\.{band}\..*cbin", f.name), raw_data_files), None) if cbin_file is not None: return spikeglx.Reader(cbin_file)
[docs] def download_raw_waveforms(self, **kwargs): """ Downloads raw waveforms extracted from sorting to local disk. """ _logger.debug(f"loading waveforms from {self.collection}") return self.one.load_object( id=self.eid, obj="waveforms", attribute=["traces", "templates", "table", "channels"], collection=self._get_spike_sorting_collection("pykilosort"), download_only=True, **kwargs )
[docs] def raw_waveforms(self, **kwargs): wf_paths = self.download_raw_waveforms(**kwargs) return WaveformsLoader(wf_paths[0].parent)
[docs] def load_channels(self, **kwargs): """ Loads channels The channel locations can come from several sources, it will load the most advanced version of the histology available, regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): - alf: the final version of channel locations, same as resolved with the difference that data is on file - resolved: channel locations alignments have been agreed upon - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data :param spike_sorter: (defaults to 'pykilosort') :param dataset_types: list of extra dataset types :return: """ # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore') self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs) channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards) if 'electrodeSites' in self.files: # if common dict keys, electrodeSites prevails channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) if 'brainLocationIds_ccf_2017' not in channels: _logger.debug(f"loading channels from alyx for {self.files['channels']}") _channels, self.histology = _load_channel_locations_traj( self.eid, probe=self.pname, one=self.one, brain_atlas=self.atlas, return_source=True, aligned=True) if _channels: channels = _channels[self.pname] else: channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions) self.histology = 'alf' return Bunch(channels)
[docs] def load_spike_sorting(self, spike_sorter='iblsorter', revision=None, enforce_version=False, good_units=False, **kwargs): """ Loads spikes, clusters and channels There could be several spike sorting collections, by default the loader will get the pykilosort collection The channel locations can come from several sources, it will load the most advanced version of the histology available, regardless of the spike sorting version loaded. The steps are (from most advanced to fresh out of the imaging): - alf: the final version of channel locations, same as resolved with the difference that data is on file - resolved: channel locations alignments have been agreed upon - aligned: channel locations have been aligned, but review or other alignments are pending, potentially not accurate - traced: the histology track has been recovered from microscopy, however the depths may not match, inaccurate data :param spike_sorter: (defaults to 'pykilosort') :param revision: for example "2024-05-06", (defaults to None): :param enforce_version: if True, will raise an error if the spike sorting version and revision is not the expected one :param dataset_types: list of extra dataset types, for example: ['spikes.samples', 'spikes.templates'] :param good_units: False, if True will load only the good units, possibly by downloading a smaller spikes table :param kwargs: additional arguments to be passed to one.api.One.load_object :return: """ if len(self.collections) == 0: return {}, {}, {} self.files = {} self.spike_sorter = spike_sorter self.revision = revision objects = ['passingSpikes', 'clusters', 'channels'] if good_units else None self.download_spike_sorting(spike_sorter=spike_sorter, revision=revision, objects=objects, **kwargs) channels = self.load_channels(spike_sorter=spike_sorter, revision=revision, **kwargs) clusters = self._load_object(self.files['clusters'], wildcards=self.one.wildcards) if good_units: spikes = self._load_object(self.files['passingSpikes'], wildcards=self.one.wildcards) else: spikes = self._load_object(self.files['spikes'], wildcards=self.one.wildcards) if enforce_version: self._assert_version_consistency() return spikes, clusters, channels
def _assert_version_consistency(self): """ Makes sure the state of the spike sorting object matches the files downloaded :return: None """ for k in ['spikes', 'clusters', 'channels', 'passingSpikes']: for fn in self.files.get(k, []): if self.spike_sorter: assert fn.relative_to(self.session_path).parts[2] == self.spike_sorter, \ f"You required strict version {self.spike_sorter}, {fn} does not match" if self.revision: assert full_path_parts(fn)[5] == self.revision, \ f"You required strict revision {self.revision}, {fn} does not match"
[docs] @staticmethod def compute_metrics(spikes, clusters=None): nc = clusters['channels'].size if clusters else np.unique(spikes['clusters']).size metrics = pd.DataFrame(quick_unit_metrics( spikes['clusters'], spikes['times'], spikes['amps'], spikes['depths'], cluster_ids=np.arange(nc))) return metrics
[docs] @staticmethod def merge_clusters(spikes, clusters, channels, cache_dir=None, compute_metrics=False): """ Merge the metrics and the channel information into the clusters dictionary :param spikes: :param clusters: :param channels: :param cache_dir: if specified, will look for a cached parquet file to speed up. This is to be used for clusters or analysis applications (defaults to None). :param compute_metrics: if True, will explicitly recompute metrics (defaults to false) :return: cluster dictionary containing metrics and histology """ if spikes == {}: return nc = clusters['channels'].size # recompute metrics if they are not available metrics = None if 'metrics' in clusters: metrics = clusters.pop('metrics') if metrics.shape[0] != nc: metrics = None if metrics is None or compute_metrics is True: _logger.debug("recompute clusters metrics") metrics = SpikeSortingLoader.compute_metrics(spikes, clusters) if isinstance(cache_dir, Path): metrics.to_parquet(Path(cache_dir).joinpath('clusters.metrics.pqt')) for k in metrics.keys(): clusters[k] = metrics[k].to_numpy() for k in channels.keys(): clusters[k] = channels[k][clusters['channels']] if cache_dir is not None: _logger.debug(f'caching clusters metrics in {cache_dir}') pd.DataFrame(clusters).to_parquet(Path(cache_dir).joinpath('clusters.pqt')) return clusters
@property def url(self): """Gets flatiron URL for the session""" webclient = getattr(self.one, '_web_client', None) return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None def _get_probe_info(self): if self._sync is None: timestamps = self.one.load_dataset( self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}') _ = self.one.load_dataset( # this is not used here but we want to trigger the download for potential tasks self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}') try: ap_meta = spikeglx.read_meta_data(self.one.load_dataset( self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}')) fs = spikeglx._get_fs_from_meta(ap_meta) except ALFObjectNotFound: ap_meta = None fs = 30_000 self._sync = { 'timestamps': timestamps, 'forward': interp1d(timestamps[:, 0], timestamps[:, 1], fill_value='extrapolate'), 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'), 'ap_meta': ap_meta, 'fs': fs, }
[docs] def timesprobe2times(self, values, direction='forward'): self._get_probe_info() if direction == 'forward': return self._sync['forward'](values * self._sync['fs']) elif direction == 'reverse': return self._sync['reverse'](values) / self._sync['fs']
[docs] def samples2times(self, values, direction='forward'): """ Converts ephys sample values to session main clock seconds :param values: numpy array of times in seconds or samples to resync :param direction: 'forward' (samples probe time to seconds main time) or 'reverse' (seconds main time to samples probe time) :return: """ self._get_probe_info() return self._sync[direction](values)
@property def pid2ref(self): return f"{self.one.eid2ref(self.eid, as_dict=False)}_{self.pname}" def _default_plot_title(self, spikes): title = f"{self.pid2ref}, {self.pid} \n" \ f"{spikes['clusters'].size:_} spikes, {np.unique(spikes['clusters']).size:_} clusters" return title
[docs] def raster(self, spikes, channels, save_dir=None, br=None, label='raster', time_series=None, drift=None, title=None, **kwargs): """ :param spikes: spikes dictionary or Bunch :param channels: channels dictionary or Bunch. :param save_dir: if specified save to this directory as "{pid}_{probe}_{label}.png". Otherwise, plot. :param br: brain regions object (optional) :param label: label for saved image (optional, default="raster") :param time_series: timeseries dictionary for behavioral event times (optional) :param **kwargs: kwargs passed to `driftmap()` (optional) :return: """ br = br or BrainRegions() time_series = time_series or {} fig, axs = plt.subplots(2, 2, gridspec_kw={ 'width_ratios': [.95, .05], 'height_ratios': [.1, .9]}, figsize=(16, 9), sharex='col') axs[0, 1].set_axis_off() # axs[0, 0].set_xticks([]) if kwargs is None: # set default raster plot parameters kwargs = {"t_bin": 0.007, "d_bin": 10, "vmax": 0.5} brainbox.plot.driftmap(spikes['times'], spikes['depths'], ax=axs[1, 0], **kwargs) if title is None: title = self._default_plot_title(spikes) axs[0, 0].title.set_text(title) for k, ts in time_series.items(): vertical_lines(ts, ymin=0, ymax=3800, ax=axs[1, 0]) if 'atlas_id' in channels: plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], brain_regions=br, display=True, ax=axs[1, 1], title=self.histology) axs[1, 0].set_ylim(0, 3800) axs[1, 0].set_xlim(spikes['times'][0], spikes['times'][-1]) fig.tight_layout() if drift is None: self.download_spike_sorting_object('drift', self.spike_sorter, missing='ignore') if 'drift' in self.files: drift = self._load_object(self.files['drift'], wildcards=self.one.wildcards) if isinstance(drift, dict): axs[0, 0].plot(drift['times'], drift['um'], 'k', alpha=.5) axs[0, 0].set(ylim=[-15, 15]) if save_dir is not None: png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) fig.savefig(png_file) plt.close(fig) gc.collect() else: return fig, axs
[docs] def plot_rawdata_snippet(self, sr, spikes, clusters, t0, channels=None, br: BrainRegions = None, save_dir=None, label='raster', gain=-93, title=None): # compute the raw data offset and destripe, we take 400ms around t0 first_sample, last_sample = (int((t0 - 0.2) * sr.fs), int((t0 + 0.2) * sr.fs)) raw = sr[first_sample:last_sample, :-sr.nsync].T channel_labels = channels['labels'] if (channels is not None) and ('labels' in channels) else True destriped = ibldsp.voltage.destripe(raw, sr.fs, channel_labels=channel_labels) # filter out the spikes according to good/bad clusters and to the time slice spike_sel = slice(*np.searchsorted(spikes['samples'], [first_sample, last_sample])) ss = spikes['samples'][spike_sel] sc = clusters['channels'][spikes['clusters'][spike_sel]] sok = clusters['label'][spikes['clusters'][spike_sel]] == 1 if title is None: title = self._default_plot_title(spikes) # display the raw data snippet with spikes overlaid fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9), sharex='col') Density(destriped, fs=sr.fs, taxis=1, gain=gain, ax=axs[0], t0=t0 - 0.2, unit='s') axs[0].scatter(ss[sok] / sr.fs, sc[sok], color="green", alpha=0.5) axs[0].scatter(ss[~sok] / sr.fs, sc[~sok], color="red", alpha=0.5) axs[0].set(title=title, xlim=[t0 - 0.035, t0 + 0.035]) # adds the channel locations if available if (channels is not None) and ('atlas_id' in channels): br = br or BrainRegions() plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'], brain_regions=br, display=True, ax=axs[1], title=self.histology) axs[1].get_yaxis().set_visible(False) fig.tight_layout() if save_dir is not None: png_file = save_dir.joinpath(f"{self.pid}_{self.pid2ref}_{label}.png") if Path(save_dir).is_dir() else Path(save_dir) fig.savefig(png_file) plt.close(fig) gc.collect() else: return fig, axs
[docs] @dataclass class SessionLoader: """ Object to load session data for a give session in the recommended way. Parameters ---------- one: one.api.ONE instance Can be in remote or local mode (required) session_path: string or pathlib.Path The absolute path to the session (one of session_path or eid is required) eid: string database UUID of the session (one of session_path or eid is required) If both are provided, session_path takes precedence over eid. Examples -------- 1) Load all available session data for one session: >>> from one.api import ONE >>> from brainbox.io.one import SessionLoader >>> one = ONE() >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/') # Object is initiated, but no data is loaded as you can see in the data_info attribute >>> sess_loader.data_info name is_loaded 0 trials False 1 wheel False 2 pose False 3 motion_energy False 4 pupil False # Loading all available session data, the data_info attribute now shows which data has been loaded >>> sess_loader.load_session_data() >>> sess_loader.data_info name is_loaded 0 trials True 1 wheel True 2 pose True 3 motion_energy True 4 pupil False # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g. >>> type(sess_loader.trials) pandas.core.frame.DataFrame >>> sess_loader.trials.shape (626, 18) # Each data comes with its own timestamps in a column called 'times' >>> sess_loader.wheel['times'] 0 0.134286 1 0.135286 2 0.136286 3 0.137286 4 0.138286 ... # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera. # The dataframes of all cameras are collected in a dictionary >>> type(sess_loader.pose) dict >>> sess_loader.pose.keys() dict_keys(['leftCamera', 'rightCamera', 'bodyCamera']) >>> sess_loader.pose['bodyCamera'].columns Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object') # In order to control the loading of specific data by e.g. specifying parameters, use the individual loading functions: >>> sess_loader.load_wheel(sampling_rate=100) """ one: One = None session_path: Path = '' eid: str = '' revision: str = '' data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) pose: dict = field(default_factory=dict, repr=False) motion_energy: dict = field(default_factory=dict, repr=False) pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) def __post_init__(self): """ Function that runs automatically after initiation of the dataclass attributes. Checks for required inputs, sets session_path and eid, creates data_info table. """ if self.one is None: raise ValueError("An input to one is required. If not connection to a database is desired, it can be " "a fully local instance of One.") # If session path is given, takes precedence over eid if self.session_path is not None and self.session_path != '': self.eid = self.one.to_eid(self.session_path) self.session_path = Path(self.session_path) # Providing no session path, try to infer from eid else: if self.eid is not None and self.eid != '': self.session_path = self.one.eid2path(self.eid) else: raise ValueError("If no session path is given, eid is required.") data_names = [ 'trials', 'wheel', 'pose', 'motion_energy', 'pupil' ] self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names)))
[docs] def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False): """ Function to load available session data into the SessionLoader object. Input parameters allow to control which data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored in SessionLoader.data_info Parameters ---------- trials: boolean Whether to load all trials data into SessionLoader.trials, default is True wheel: boolean Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True pose: boolean Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose, default is True motion_energy: boolean Whether to load motion energy data (whisker pad for left/right camera, body for body camera) into SessionLoader.motion_energy, default is True pupil: boolean Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil, default is True reload: boolean Whether to reload data that has already been loaded into this SessionLoader object, default is False """ load_df = self.data_info.copy() load_df['to_load'] = [ trials, wheel, pose, motion_energy, pupil ] load_df['load_func'] = [ self.load_trials, self.load_wheel, self.load_pose, self.load_motion_energy, self.load_pupil ] for idx, row in load_df.iterrows(): if row['to_load'] is False: _logger.debug(f"Not loading {row['name']} data, set to False.") elif row['is_loaded'] is True and reload is False: _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") else: try: _logger.info(f"Loading {row['name']} data") row['load_func']() self.data_info.loc[idx, 'is_loaded'] = True except BaseException as e: _logger.warning(f"Could not load {row['name']} data.") _logger.debug(e)
def _find_behaviour_collection(self, obj): """ Function to find the trial or wheel collection Parameters ---------- obj: str Alf object to load, either 'trials' or 'wheel' """ dataset = '_ibl_trials.table.pqt' if obj == 'trials' else '_ibl_wheel.position.npy' dsets = self.one.list_datasets(self.eid, dataset) if len(dsets) == 0: return 'alf' else: collections = [full_path_parts(self.session_path.joinpath(d), as_dict=True)['collection'] for d in dsets] if len(set(collections)) == 1: return collections[0] else: _logger.error(f'Multiple collections found {collections}. Specify collection when loading, ' f'e.g sl.load_{obj}(collection="{collections[0]}")') raise ALFMultipleCollectionsFound
[docs] def load_trials(self, collection=None): """ Function to load trials data into SessionLoader.trials Parameters ---------- collection: str Alf collection of trials data """ if not collection: collection = self._find_behaviour_collection('trials') # itiDuration frequently has a mismatched dimension, and we don't need it, exclude using regex self.one.wildcards = False self.trials = self.one.load_object( self.eid, 'trials', collection=collection, attribute=r'(?!itiDuration).*', revision=self.revision or None).to_df() self.one.wildcards = True self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True
[docs] def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None): """ Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which a Butterworth low-pass filter is applied. Parameters ---------- fs: int, float Sampling frequency for the wheel position, default is 1000 Hz corner_frequency: int, float Corner frequency of Butterworth low-pass filter, default is 20 order: int, float Order of Butterworth low_pass filter, default is 8 collection: str Alf collection of wheel data """ if not collection: collection = self._find_behaviour_collection('wheel') wheel_raw = self.one.load_object(self.eid, 'wheel', collection=collection, revision=self.revision or None) if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]: raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps") # resample the wheel position and compute velocity, acceleration self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) self.wheel['position'], self.wheel['times'] = interpolate_position( wheel_raw['timestamps'], wheel_raw['position'], freq=fs) self.wheel['velocity'], self.wheel['acceleration'] = velocity_filtered( self.wheel['position'], fs=fs, corner_frequency=corner_frequency, order=order) self.wheel = self.wheel.apply(np.float32) self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True
[docs] def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']): """ Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas Dataframes with the timestamps and pose data, one row for each body part tracked for that camera. Parameters ---------- likelihood_thr: float The position of each tracked body part come with a likelihood of that estimate for each time point. Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set likelihood_thr=1. Default is 0.9 views: list List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} """ # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger self.pose = {} for view in views: pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'], revision=self.revision or None) # Double check if video timestamps are correct length or can be fixed times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc']) self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) self.pose[f'{view}Camera'].insert(0, 'times', times_fixed) self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True
[docs] def load_motion_energy(self, views=['left', 'right', 'body']): """ Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are pandas Dataframes with the timestamps and motion energy data. The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the body (bodyMotionEnergy). Parameters ---------- views: list List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} """ names = {'left': 'whiskerMotionEnergy', 'right': 'whiskerMotionEnergy', 'body': 'bodyMotionEnergy'} # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger self.motion_energy = {} for view in views: me_raw = self.one.load_object( self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'], revision=self.revision or None) # Double check if video timestamps are correct length or can be fixed times_fixed, motion_energy = self._check_video_timestamps( view, me_raw['times'], me_raw['ROIMotionEnergy']) self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True
[docs] def load_licks(self): """ Not yet implemented """ pass
[docs] def load_pupil(self, snr_thresh=5.): """ Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil. Parameters ---------- snr_thresh: float An SNR is calculated from the raw and smoothed pupil diameter. If this snr < snr_thresh the data will be considered unusable and will be discarded. """ # Try to load from features feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'], revision=self.revision or None) if 'features' in feat_raw.keys(): times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features']) self.pupil = feats.copy() self.pupil.insert(0, 'times', times_fixed) # If unavailable compute on the fly else: _logger.info('Pupil diameter not available, trying to compute on the fly.') if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] and 'leftCamera' in self.pose.keys()): # If pose data is already loaded, we don't know if it was threshold at 0.9, so we need a little stunt copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place else: self.load_pose(views=['left'], likelihood_thr=0.9) dlc_thr = self.pose['leftCamera'].copy() self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr) try: self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') except BaseException as e: _logger.error("Loaded raw pupil diameter but computing smooth pupil diameter failed. " "Saving all NaNs for pupilDiameter_smooth.") _logger.debug(e) self.pupil['pupilDiameter_smooth'] = np.nan if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): good_idxs = np.where( ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0] snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs]))) if snr < snr_thresh: self.pupil = pd.DataFrame() raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.')
def _check_video_timestamps(self, view, video_timestamps, video_data): """ Helper function to check for the length of the video frames vs video timestamps and fix in case timestamps are longer than video frames. """ # If camera times are shorter than video data, or empty, no current fix if video_timestamps.shape[0] < video_data.shape[0]: if video_timestamps.shape[0] == 0: msg = f'Camera times empty for {view}Camera.' else: msg = f'Camera times are shorter than video data for {view}Camera.' _logger.warning(msg) raise ValueError(msg) # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video. # This is because the first few frames are sometimes not recorded. We can remove the first few # timestamps in this case elif video_timestamps.shape[0] > video_data.shape[0]: video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] return video_timestamps_fixed, video_data else: return video_timestamps, video_data
[docs] class EphysSessionLoader(SessionLoader): """ Spike sorting enhanced version of SessionLoader Loads spike sorting data for all probes in the session, in the self.ephys dict >>> EphysSessionLoader(eid=eid, one=one) To select for a specific probe >>> EphysSessionLoader(eid=eid, one=one, pid=pid) """ def __init__(self, *args, pname=None, pid=None, **kwargs): """ Needs an active connection in order to get the list of insertions in the session :param args: :param kwargs: """ super().__init__(*args, **kwargs) # if necessary, restrict the query qargs = {} if pname is None else {'name': pname} qargs = qargs or ({} if pid is None else {'id': pid}) insertions = self.one.alyx.rest('insertions', 'list', session=self.eid, **qargs) self.ephys = {} for ins in insertions: self.ephys[ins['name']] = {} self.ephys[ins['name']]['ssl'] = SpikeSortingLoader(pid=ins['id'], one=self.one)
[docs] def load_session_data(self, *args, **kwargs): super().load_session_data(*args, **kwargs) self.load_spike_sorting()
[docs] def load_spike_sorting(self, pnames=None): pnames = pnames or list(self.ephys.keys()) for pname in pnames: spikes, clusters, channels = self.ephys[pname]['ssl'].load_spike_sorting() self.ephys[pname]['spikes'] = spikes self.ephys[pname]['clusters'] = clusters self.ephys[pname]['channels'] = channels
@property def probes(self): return {k: self.ephys[k]['ssl'].pid for k in self.ephys}