Source code for ibllib.io.extractors.ephys_fpga

"""Data extraction from raw FPGA output.

The behaviour extraction happens in the following stages:

    1. The NI DAQ events are extracted into a map of event times and TTL polarities.
    2. The Bpod trial events are extracted from the raw Bpod data, depending on the task protocol.
    3. As protocols may be chained together within a given recording, the period of a given task
       protocol is determined using the 'spacer' DAQ signal (see `get_protocol_period`).
    4. Physical behaviour events such as stim on and reward time are separated out by TTL length or
       sequence within the trial.
    5. The Bpod clock is sync'd with the FPGA using one of the extracted trial events.
    6. The Bpod software events are then converted to FPGA time.

Examples
--------
For simple extraction, use the FPGATrials class:

>>> extractor = FpgaTrials(session_path)
>>> trials, _ = extractor.extract(update=False, save=False)

Notes
-----
Sync extraction in this module only supports FPGA data acquired with an NI DAQ as part of a
Neuropixels recording system, however a sync and channel map extracted from a different DAQ format
can be passed to the FpgaTrials class.

See Also
--------
For dynamic pipeline sessions it is best to call the extractor via the BehaviorTask class.

TODO notes on subclassing various methods of FpgaTrials for custom hardware.
"""
import logging
from itertools import cycle
from pathlib import Path
import uuid
import re
from functools import partial

import matplotlib.pyplot as plt
from matplotlib.colors import TABLEAU_COLORS
import numpy as np
from packaging import version

import spikeglx
import ibldsp.utils
import one.alf.io as alfio
from one.alf.path import filename_parts
from iblutil.util import Bunch
from iblutil.spacer import Spacer

import ibllib.exceptions as err
from ibllib.io import raw_data_loaders as raw, session_params
from ibllib.io.extractors.bpod_trials import get_bpod_extractor
import ibllib.io.extractors.base as extractors_base
from ibllib.io.extractors.training_wheel import extract_wheel_moves
from ibllib import plots
from ibllib.io.extractors.default_channel_maps import DEFAULT_MAPS

_logger = logging.getLogger(__name__)

SYNC_BATCH_SIZE_SECS = 100
"""int: Number of samples to read at once in bin file for sync."""

WHEEL_RADIUS_CM = 1  # stay in radians
"""float: The radius of the wheel used in the task. A value of 1 ensures units remain in radians."""

WHEEL_TICKS = 1024
"""int: The number of encoder pulses per channel for one complete rotation."""

BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150
"""int: Throws an error if Bpod to FPGA clock drift is higher than this value."""

CHMAPS = {'3A':
          {'ap':
           {'left_camera': 2,
            'right_camera': 3,
            'body_camera': 4,
            'bpod': 7,
            'frame2ttl': 12,
            'rotary_encoder_0': 13,
            'rotary_encoder_1': 14,
            'audio': 15
            }
           },
          '3B':
          {'nidq':
           {'left_camera': 0,
            'right_camera': 1,
            'body_camera': 2,
            'imec_sync': 3,
            'frame2ttl': 4,
            'rotary_encoder_0': 5,
            'rotary_encoder_1': 6,
            'audio': 7,
            'bpod': 16,
            'laser': 17,
            'laser_ttl': 18},
           'ap':
           {'imec_sync': 6}
           },
          }
"""dict: The default channel indices corresponding to various devices for different recording systems."""


[docs] def data_for_keys(keys, data): """Check keys exist in 'data' dict and contain values other than None.""" return data is not None and all(k in data and data.get(k, None) is not None for k in keys)
[docs] def get_ibl_sync_map(ef, version): """ Gets default channel map for the version/binary file type combination :param ef: ibllib.io.spikeglx.glob_ephys_file dictionary with field 'ap' or 'nidq' :return: channel map dictionary """ # Determine default channel map if version == '3A': default_chmap = CHMAPS['3A']['ap'] elif version == '3B': if ef.get('nidq', None): default_chmap = CHMAPS['3B']['nidq'] elif ef.get('ap', None): default_chmap = CHMAPS['3B']['ap'] # Try to load channel map from file chmap = spikeglx.get_sync_map(ef['path']) # If chmap provided but not with all keys, fill up with default values if not chmap: return default_chmap else: if data_for_keys(default_chmap.keys(), chmap): return chmap else: _logger.warning("Keys missing from provided channel map, " "setting missing keys from default channel map") return {**default_chmap, **chmap}
def _sync_to_alf(raw_ephys_apfile, output_path=None, save=False, parts=''): """ Extracts sync.times, sync.channels and sync.polarities from binary ephys dataset :param raw_ephys_apfile: bin file containing ephys data or spike :param output_path: output directory :param save: bool write to disk only if True :param parts: string or list of strings that will be appended to the filename before extension :return: """ # handles input argument: support ibllib.io.spikeglx.Reader, str and pathlib.Path if isinstance(raw_ephys_apfile, spikeglx.Reader): sr = raw_ephys_apfile else: raw_ephys_apfile = Path(raw_ephys_apfile) sr = spikeglx.Reader(raw_ephys_apfile) if not (opened := sr.is_open): sr.open() # if no output, need a temp folder to swap for big files if not output_path: output_path = raw_ephys_apfile.parent file_ftcp = Path(output_path).joinpath(f'fronts_times_channel_polarity{uuid.uuid4()}.bin') # loop over chunks of the raw ephys file wg = ibldsp.utils.WindowGenerator(sr.ns, int(SYNC_BATCH_SIZE_SECS * sr.fs), overlap=1) fid_ftcp = open(file_ftcp, 'wb') for sl in wg.slice: ss = sr.read_sync(sl) ind, fronts = ibldsp.utils.fronts(ss, axis=0) # a = sr.read_sync_analog(sl) sav = np.c_[(ind[0, :] + sl.start) / sr.fs, ind[1, :], fronts.astype(np.double)] sav.tofile(fid_ftcp) # close temp file, read from it and delete fid_ftcp.close() tim_chan_pol = np.fromfile(str(file_ftcp)) tim_chan_pol = tim_chan_pol.reshape((int(tim_chan_pol.size / 3), 3)) file_ftcp.unlink() sync = {'times': tim_chan_pol[:, 0], 'channels': tim_chan_pol[:, 1], 'polarities': tim_chan_pol[:, 2]} # If opened Reader was passed into function, leave open if not opened: sr.close() if save: out_files = alfio.save_object_npy(output_path, sync, 'sync', namespace='spikeglx', parts=parts) return Bunch(sync), out_files else: return Bunch(sync) def _rotary_encoder_positions_from_fronts(ta, pa, tb, pb, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4'): """ Extracts the rotary encoder absolute position as function of time from fronts detected on the 2 channels. Outputs in units of radius parameters, by default radians Coding options detailed here: http://www.ni.com/tutorial/7109/pt/ Here output is clockwise from subject perspective :param ta: time of fronts on channel A :param pa: polarity of fronts on channel A :param tb: time of fronts on channel B :param pb: polarity of fronts on channel B :param ticks: number of ticks corresponding to a full revolution (1024 for IBL rotary encoder) :param radius: radius of the wheel. Defaults to 1 for an output in radians :param coding: x1, x2 or x4 coding (IBL default is x4) :return: indices vector (ta) and position vector """ if coding == 'x1': ia = np.searchsorted(tb, ta[pa == 1]) ia = ia[ia < ta.size] ia = ia[pa[ia] == 1] ib = np.searchsorted(ta, tb[pb == 1]) ib = ib[ib < tb.size] ib = ib[pb[ib] == 1] t = np.r_[ta[ia], tb[ib]] p = np.r_[ia * 0 + 1, ib * 0 - 1] ordre = np.argsort(t) t = t[ordre] p = p[ordre] p = np.cumsum(p) / ticks * np.pi * 2 * radius return t, p elif coding == 'x2': p = pb[np.searchsorted(tb, ta) - 1] * pa p = - np.cumsum(p) / ticks * np.pi * 2 * radius / 2 return ta, p elif coding == 'x4': p = np.r_[pb[np.searchsorted(tb, ta) - 1] * pa, -pa[np.searchsorted(ta, tb) - 1] * pb] t = np.r_[ta, tb] ordre = np.argsort(t) t = t[ordre] p = p[ordre] p = - np.cumsum(p) / ticks * np.pi * 2 * radius / 4 return t, p def _assign_events_to_trial(t_trial_start, t_event, take='last', t_trial_end=None): """ Assign events to a trial given trial start times and event times. Trials without an event result in nan value in output time vector. The output has a consistent size with t_trial_start and ready to output to alf. Parameters ---------- t_trial_start : numpy.array An array of start times, used to bin edges for assigning values from `t_event`. t_event : numpy.array An array of event times to assign to trials. take : str {'first', 'last'}, int 'first' takes first event > t_trial_start; 'last' takes last event < the next t_trial_start; an int defines the index to take for events within trial bounds. The index may be negative. t_trial_end : numpy.array Optional array of end times, used to bin edges for assigning values from `t_event`. Returns ------- numpy.array An array the length of `t_trial_start` containing values from `t_event`. Unassigned values are replaced with np.nan. See Also -------- FpgaTrials._assign_events - Assign trial events based on TTL length. """ # make sure the events are sorted try: assert np.all(np.diff(t_trial_start) >= 0) except AssertionError: raise ValueError('Trial starts vector not sorted') try: assert np.all(np.diff(t_event) >= 0) except AssertionError: raise ValueError('Events vector is not sorted') # remove events that happened before the first trial start remove = t_event < t_trial_start[0] if t_trial_end is not None: if not np.all(np.diff(t_trial_end) >= 0): raise ValueError('Trial end vector not sorted') if not np.all(t_trial_end[:-1] < t_trial_start[1:]): raise ValueError('Trial end times must not overlap with trial start times') # remove events between end and next start, and after last end remove |= t_event > t_trial_end[-1] for e, s in zip(t_trial_end[:-1], t_trial_start[1:]): remove |= np.logical_and(s > t_event, t_event >= e) t_event = t_event[~remove] ind = np.searchsorted(t_trial_start, t_event) - 1 t_event_nans = np.zeros_like(t_trial_start) * np.nan # select first or last element matching each trial start if take == 'last': iall, iu = np.unique(np.flip(ind), return_index=True) t_event_nans[iall] = t_event[- (iu - ind.size + 1)] elif take == 'first': iall, iu = np.unique(ind, return_index=True) t_event_nans[iall] = t_event[iu] else: # if the index is arbitrary, needs to be numeric (could be negative if from the end) iall = np.unique(ind) minsize = take + 1 if take >= 0 else - take # for each trial, take the take nth element if there are enough values in trial for iu in iall: match = t_event[iu == ind] if len(match) >= minsize: t_event_nans[iu] = match[take] return t_event_nans
[docs] def get_sync_fronts(sync, channel_nb, tmin=None, tmax=None): """ Return the sync front polarities and times for a given channel. Parameters ---------- sync : dict 'polarities' of fronts detected on sync trace for all 16 channels and their 'times'. channel_nb : int The integer corresponding to the desired sync channel. tmin : float The minimum time from which to extract the sync pulses. tmax : float The maximum time up to which we extract the sync pulses. Returns ------- Bunch Channel times and polarities. """ selection = sync['channels'] == channel_nb selection = np.logical_and(selection, sync['times'] <= tmax) if tmax else selection selection = np.logical_and(selection, sync['times'] >= tmin) if tmin else selection return Bunch({'times': sync['times'][selection], 'polarities': sync['polarities'][selection]})
def _clean_audio(audio, display=False): """ one guy wired the 150 Hz camera output onto the soundcard. The effect is to get 150 Hz periodic square pulses, 2ms up and 4.666 ms down. When this happens we remove all of the intermediate pulses to repair the audio trace Here is some helper code dd = np.diff(audio['times']) 1 / np.median(dd[::2]) # 2ms up 1 / np.median(dd[1::2]) # 4.666 ms down 1 / (np.median(dd[::2]) + np.median(dd[1::2])) # both sum to 150 Hz This only runs on sessions when the bug is detected and leaves others untouched """ DISCARD_THRESHOLD = 0.01 average_150_hz = np.mean(1 / np.diff(audio['times'][audio['polarities'] == 1]) > 140) naudio = audio['times'].size if average_150_hz > 0.7 and naudio > 100: _logger.warning('Soundcard signal on FPGA seems to have been mixed with 150Hz camera') keep_ind = np.r_[np.diff(audio['times']) > DISCARD_THRESHOLD, False] keep_ind = np.logical_and(keep_ind, audio['polarities'] == -1) keep_ind = np.where(keep_ind)[0] keep_ind = np.sort(np.r_[0, keep_ind, keep_ind + 1, naudio - 1]) if display: # pragma: no cover from ibllib.plots import squares squares(audio['times'], audio['polarities'], ax=None, yrange=[-1, 1]) squares(audio['times'][keep_ind], audio['polarities'][keep_ind], yrange=[-1, 1]) audio = {'times': audio['times'][keep_ind], 'polarities': audio['polarities'][keep_ind]} return audio def _clean_frame2ttl(frame2ttl, threshold=0.01, display=False): """ Clean the frame2ttl events. Frame 2ttl calibration can be unstable and the fronts may be flickering at an unrealistic pace. This removes the consecutive frame2ttl pulses happening too fast, below a threshold of F2TTL_THRESH. Parameters ---------- frame2ttl : dict A dictionary of frame2TTL events, with keys {'times', 'polarities'}. threshold : float Consecutive pulses occurring with this many seconds ignored. display : bool If true, plots the input TTLs and the cleaned output. Returns ------- """ dt = np.diff(frame2ttl['times']) iko = np.where(np.logical_and(dt < threshold, frame2ttl['polarities'][:-1] == -1))[0] iko = np.unique(np.r_[iko, iko + 1]) frame2ttl_ = {'times': np.delete(frame2ttl['times'], iko), 'polarities': np.delete(frame2ttl['polarities'], iko)} if iko.size > (0.1 * frame2ttl['times'].size): _logger.warning(f'{iko.size} ({iko.size / frame2ttl["times"].size:.2%}) ' f'frame to TTL polarity switches below {threshold} secs') if display: # pragma: no cover fig, (ax0, ax1) = plt.subplots(2, sharex=True) plots.squares(frame2ttl['times'] * 1000, frame2ttl['polarities'], yrange=[0.1, 0.9], ax=ax0) plots.squares(frame2ttl_['times'] * 1000, frame2ttl_['polarities'], yrange=[1.1, 1.9], ax=ax1) import seaborn as sns sns.displot(dt[dt < 0.05], binwidth=0.0005) return frame2ttl_
[docs] def extract_wheel_sync(sync, chmap=None, tmin=None, tmax=None): """ Extract wheel positions and times from sync fronts dictionary for all 16 channels. Output position is in radians, mathematical convention. Parameters ---------- sync : dict 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' chmap : dict Map of channel names and their corresponding index. Default to constant. tmin : float The minimum time from which to extract the sync pulses. tmax : float The maximum time up to which we extract the sync pulses. Returns ------- numpy.array Wheel timestamps in seconds. numpy.array Wheel positions in radians. """ # Assume two separate edge count channels assert chmap.keys() >= {'rotary_encoder_0', 'rotary_encoder_1'} channela = get_sync_fronts(sync, chmap['rotary_encoder_0'], tmin=tmin, tmax=tmax) channelb = get_sync_fronts(sync, chmap['rotary_encoder_1'], tmin=tmin, tmax=tmax) re_ts, re_pos = _rotary_encoder_positions_from_fronts( channela['times'], channela['polarities'], channelb['times'], channelb['polarities'], ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4') return re_ts, re_pos
[docs] def extract_sync(session_path, overwrite=False, ephys_files=None, namespace='spikeglx'): """ Reads ephys binary file (s) and extract sync within the binary file folder Assumes ephys data is within a `raw_ephys_data` folder :param session_path: '/path/to/subject/yyyy-mm-dd/001' :param overwrite: Bool on re-extraction, forces overwrite instead of loading existing files :return: list of sync dictionaries """ session_path = Path(session_path) if not ephys_files: ephys_files = spikeglx.glob_ephys_files(session_path) syncs = [] outputs = [] for efi in ephys_files: bin_file = efi.get('ap', efi.get('nidq', None)) if not bin_file: continue alfname = dict(object='sync', namespace=namespace) if efi.label: alfname['extra'] = efi.label file_exists = alfio.exists(bin_file.parent, **alfname) if not overwrite and file_exists: _logger.warning(f'Skipping raw sync: SGLX sync found for {efi.label}!') sync = alfio.load_object(bin_file.parent, **alfname) out_files, _ = alfio._ls(bin_file.parent, **alfname) else: sr = spikeglx.Reader(bin_file) sync, out_files = _sync_to_alf(sr, bin_file.parent, save=True, parts=efi.label) sr.close() outputs.extend(out_files) syncs.extend([sync]) return syncs, outputs
def _get_all_probes_sync(session_path, bin_exists=True): # round-up of all bin ephys files in the session, infer revision and get sync map ephys_files = spikeglx.glob_ephys_files(session_path, bin_exists=bin_exists) version = spikeglx.get_neuropixel_version_from_files(ephys_files) # attach the sync information to each binary file found for ef in ephys_files: ef['sync'] = alfio.load_object(ef.path, 'sync', namespace='spikeglx', short_keys=True) ef['sync_map'] = get_ibl_sync_map(ef, version) return ephys_files
[docs] def get_wheel_positions(sync, chmap, tmin=None, tmax=None): """ Gets the wheel position from synchronisation pulses Parameters ---------- sync : dict A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. chmap : dict[str, int] A map of channel names and their corresponding indices. tmin : float The minimum time from which to extract the sync pulses. tmax : float The maximum time up to which we extract the sync pulses. Returns ------- Bunch A dictionary with keys ('timestamps', 'position'), containing the wheel event timestamps and position in radians Bunch A dictionary of detected movement times with keys ('intervals', 'peakAmplitude', 'peakVelocity_times'). """ ts, pos = extract_wheel_sync(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) moves = Bunch(extract_wheel_moves(ts, pos)) wheel = Bunch({'timestamps': ts, 'position': pos}) return wheel, moves
[docs] def get_main_probe_sync(session_path, bin_exists=False): """ From 3A or 3B multiprobe session, returns the main probe (3A) or nidq sync pulses with the attached channel map (default chmap if none) Parameters ---------- session_path : str, pathlib.Path The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'. bin_exists : bool Whether there is a .bin file present. Returns ------- one.alf.io.AlfBunch A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. dict A map of channel names and their corresponding indices. """ ephys_files = _get_all_probes_sync(session_path, bin_exists=bin_exists) if not ephys_files: raise FileNotFoundError(f"No ephys files found in {session_path}") version = spikeglx.get_neuropixel_version_from_files(ephys_files) if version == '3A': # the sync master is the probe with the most sync pulses sync_box_ind = np.argmax([ef.sync.times.size for ef in ephys_files]) elif version == '3B': # the sync master is the nidq breakout box sync_box_ind = np.argmax([1 if ef.get('nidq') else 0 for ef in ephys_files]) sync = ephys_files[sync_box_ind].sync sync_chmap = ephys_files[sync_box_ind].sync_map return sync, sync_chmap
[docs] def get_protocol_period(session_path, protocol_number, bpod_sync): """ Parameters ---------- session_path : str, pathlib.Path The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'. protocol_number : int The order that the protocol was run in. bpod_sync : dict The sync times and polarities for Bpod BNC1. Returns ------- float The time of the detected spacer for the protocol number. float, None The time of the next detected spacer or None if this is the last protocol run. """ # The spacers are TTLs generated by Bpod at the start of each protocol spacer_times = Spacer().find_spacers_from_fronts(bpod_sync) # Ensure that the number of detected spacers matched the number of expected tasks if acquisition_description := session_params.read_params(session_path): n_tasks = len(acquisition_description.get('tasks', [])) assert len(spacer_times) >= protocol_number, (f'expected {n_tasks} spacers, found only {len(spacer_times)} - ' f'can not return protocol number {protocol_number}.') assert n_tasks > protocol_number >= 0, f'protocol number must be between 0 and {n_tasks}' else: assert protocol_number < len(spacer_times) start = spacer_times[int(protocol_number)] end = None if len(spacer_times) - 1 == protocol_number else spacer_times[int(protocol_number + 1)] return start, end
[docs] class FpgaTrials(extractors_base.BaseExtractor): save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', '_ibl_trials.stimOffTrigger_times.npy', None, None, None, None, None, '_ibl_trials.stimOff_times.npy', None, None, None, '_ibl_trials.quiescencePeriod.npy', '_ibl_trials.table.pqt', '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None) var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'errorCue_times', 'itiIn_times', 'stimFreeze_times', 'stimOff_times', 'valveOpen_times', 'phase', 'position', 'quiescence', 'table', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times') bpod_rsync_fields = ('intervals', 'response_times', 'goCueTrigger_times', 'stimOnTrigger_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times') """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA.""" bpod_fields = ('feedbackType', 'choice', 'rewardVolume', 'contrastLeft', 'contrastRight', 'probabilityLeft', 'phase', 'position', 'quiescence') """tuple of str: Fields from bpod extractor that we want to save.""" sync_field = 'intervals_0' # trial start events """str: The trial event to synchronize (must be present in extracted trials).""" bpod = None """dict of numpy.array: The Bpod out TTLs recorded on the DAQ. Used in the QC viewer plot.""" def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs): """An extractor for ephysChoiceWorld trials data, in FPGA time. This class may be subclassed to handle moderate variations in hardware and task protocol, however there is flexible """ super().__init__(*args, **kwargs) self.bpod2fpga = None self.bpod_trials = bpod_trials self.frame2ttl = self.audio = self.bpod = self.settings = None if bpod_extractor: self.bpod_extractor = bpod_extractor self._update_var_names() def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None): """ Updates this object's attributes based on the Bpod trials extractor. Fields updated: bpod_fields, bpod_rsync_fields, save_names, and var_names. Parameters ---------- bpod_fields : tuple A set of Bpod trials fields to keep. bpod_rsync_fields : tuple A set of Bpod trials fields to sync to the DAQ times. """ if self.bpod_extractor: for var_name, save_name in zip(self.bpod_extractor.var_names, self.bpod_extractor.save_names): if var_name not in self.var_names: self.var_names += (var_name,) self.save_names += (save_name,) # self.var_names = self.bpod_extractor.var_names # self.save_names = self.bpod_extractor.save_names self.settings = self.bpod_extractor.settings # This is used by the TaskQC self.bpod_rsync_fields = bpod_rsync_fields if self.bpod_rsync_fields is None: self.bpod_rsync_fields = tuple(self._time_fields(self.bpod_extractor.var_names)) if 'table' in self.bpod_extractor.var_names: if not self.bpod_trials: self.bpod_trials = self.bpod_extractor.extract(save=False) table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() self.bpod_rsync_fields += tuple(self._time_fields(table_keys)) elif bpod_rsync_fields: self.bpod_rsync_fields = bpod_rsync_fields excluded = (*self.bpod_rsync_fields, 'table') if bpod_fields: assert not set(self.bpod_fields).intersection(excluded), 'bpod_fields must not also be bpod_rsync_fields' self.bpod_fields = bpod_fields elif self.bpod_extractor: self.bpod_fields = tuple(x for x in self.bpod_extractor.var_names if x not in excluded) if 'table' in self.bpod_extractor.var_names: if not self.bpod_trials: self.bpod_trials = self.bpod_extractor.extract(save=False) table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() self.bpod_fields += tuple([x for x in table_keys if x not in excluded]) @staticmethod def _time_fields(trials_attr) -> set: """ Iterates over Bpod trials attributes returning those that correspond to times for syncing. Parameters ---------- trials_attr : iterable of str The Bpod field names. Returns ------- set The field names that contain timestamps. """ FIELDS = ('times', 'timestamps', 'intervals') pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$') return set(filter(pattern.match, trials_attr))
[docs] def load_sync(self, sync_collection='raw_ephys_data', **kwargs): """Load the DAQ sync and channel map data. This method may be subclassed for novel DAQ systems. The sync must contain the following keys: 'times' - an array timestamps in seconds; 'polarities' - an array of {-1, 1} corresponding to TTL LOW and TTL HIGH, respectively; 'channels' - an array of ints corresponding to channel number. Parameters ---------- sync_collection : str The session subdirectory where the sync data are located. kwargs Optional arguments used by subclass methods. Returns ------- one.alf.io.AlfBunch A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. dict A map of channel names and their corresponding indices. """ return get_sync_and_chn_map(self.session_path, sync_collection)
def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task_collection='raw_behavior_data', **kwargs) -> dict: """Extracts ephys trials by combining Bpod and FPGA sync pulses. It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field` attributes are all correct for the bpod protocol used. Below are the steps involved: 0. Load sync and bpod trials, if required. 1. Determine protocol period and discard sync events outside the task. 2. Classify multiplexed TTL events based on length (see :meth:`FpgaTrials.build_trials`). 3. Sync the Bpod clock to the DAQ clock using one of the assigned trial events. 4. Assign classified TTL events to trial events based on order within the trial. 4. Convert Bpod software event times to DAQ clock. 5. Extract the wheel from the DAQ rotary encoder signal, if required. Parameters ---------- sync : dict A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. If None, the sync is loaded using the `load_sync` method. chmap : dict A map of channel names and their corresponding indices. If None, the channel map is loaded using the :meth:`FpgaTrials.load_sync` method. sync_collection : str The session subdirectory where the sync data are located. This is only used if the sync or channel maps are not provided. task_collection : str The session subdirectory where the raw Bpod data are located. This is used for loading the task settings and extracting the bpod trials, if not already done. protocol_number : int The protocol number if multiple protocols were run during the session. If provided, a spacer signal must be present in order to determine the correct period. kwargs Optional arguments for subclass methods to use. Returns ------- dict A dictionary of numpy arrays with `FpgaTrials.var_names` as keys. """ if sync is None or chmap is None: _sync, _chmap = self.load_sync(sync_collection) sync = sync or _sync chmap = chmap or _chmap if not self.bpod_trials: # extract the behaviour data from bpod self.extractor = get_bpod_extractor(self.session_path, task_collection=task_collection) _logger.info('Bpod trials extractor: %s.%s', self.extractor.__module__, self.extractor.__class__.__name__) self.bpod_trials, *_ = self.extractor.extract(task_collection=task_collection, save=False, **kwargs) # Explode trials table df if 'table' in self.var_names: trials_table = alfio.AlfBunch.from_df(self.bpod_trials.pop('table')) table_columns = trials_table.keys() self.bpod_trials.update(trials_table) else: if 'table' in self.bpod_trials: _logger.error( '"table" found in Bpod trials but missing from `var_names` attribute and will' 'therefore not be extracted. This is likely in error.') table_columns = None bpod = get_sync_fronts(sync, chmap['bpod']) # Get the spacer times for this protocol if any(arg in kwargs for arg in ('tmin', 'tmax')): tmin, tmax = kwargs.get('tmin'), kwargs.get('tmax') elif (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer # The spacers are TTLs generated by Bpod at the start of each protocol tmin, tmax = get_protocol_period(self.session_path, protocol_number, bpod) else: # Older sessions don't have protocol spacers so we sync the Bpod intervals here to # find the approximate end time of the protocol (this will exclude the passive signals # in ephysChoiceWorld that tend to ruin the final trial extraction). _, trial_ints = self.get_bpod_event_times(sync, chmap, **kwargs) t_trial_start = trial_ints.get('trial_start', np.array([[np.nan, np.nan]]))[:, 0] bpod_start = self.bpod_trials['intervals'][:, 0] if len(t_trial_start) > len(bpod_start) / 2: # if least half the trial start TTLs detected _logger.warning('Attempting to get protocol period from aligning trial start TTLs') fcn, *_ = ibldsp.utils.sync_timestamps(bpod_start, t_trial_start) buffer = 2.5 # the number of seconds to include before/after task start, end = fcn(self.bpod_trials['intervals'].flat[[0, -1]]) tmin = min(sync['times'][0], start - buffer) tmax = max(sync['times'][-1], end + buffer) else: # This type of alignment fails for some sessions, e.g. mesoscope tmin = tmax = None # Remove unnecessary data from sync selection = np.logical_and( sync['times'] <= (tmax if tmax is not None else sync['times'][-1]), sync['times'] >= (tmin if tmin is not None else sync['times'][0]), ) sync = alfio.AlfBunch({k: v[selection] for k, v in sync.items()}) _logger.debug('Protocol period from %.2fs to %.2fs (~%.0f min duration)', *sync['times'][[0, -1]], np.diff(sync['times'][[0, -1]]) / 60) # Get the trial events from the DAQ sync TTLs, sync clocks and build final trials datasets out = self.build_trials(sync=sync, chmap=chmap, **kwargs) # extract the wheel data if any(x.startswith('wheel') for x in self.var_names): wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) from ibllib.io.extractors.training_wheel import extract_first_movement_times if not self.settings: self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) min_qt = self.settings.get('QUIESCENT_PERIOD', None) first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) out.update({'firstMovement_times': first_move_onsets}) out.update({f'wheel_{k}': v for k, v in wheel.items()}) out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) # Re-create trials table if table_columns: trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns}) out['table'] = trials_table.to_df() out = alfio.AlfBunch({k: out[k] for k in self.var_names if k in out}) # Reorder output assert self.var_names == tuple(out.keys()) return out def _is_trials_object_attribute(self, var_name, variable_length_vars=None): """ Check if variable name is expected to have the same length as trials.intervals. Parameters ---------- var_name : str The variable name to check. variable_length_vars : list Set of variable names that are not expected to have the same length as trials.intervals. This list may be passed by superclasses. Returns ------- bool True if variable is a trials dataset. Examples -------- >>> assert self._is_trials_object_attribute('stimOnTrigger_times') is True >>> assert self._is_trials_object_attribute('wheel_position') is False """ save_name = self.save_names[self.var_names.index(var_name)] if var_name in self.var_names else None if save_name: return filename_parts(save_name)[1] == 'trials' else: return var_name not in (variable_length_vars or [])
[docs] def build_trials(self, sync, chmap, display=False, **kwargs): """ Extract task related event times from the sync. The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The first trial start TTL of the session is longer and must be handled differently. The trial start TTL is used to assign the other trial events to each trial. The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio tones. The first of these after each trial start is taken to be the go cue time. Error tones are longer audio TTLs and assigned as the last of such occurrence after each trial start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial. The feedback times are times of either valve open or error tone as there should be only one such event per trial. The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs removed): the first TTL after each trial start is assumed to be the stim onset time; the second to last and last are taken as the stimulus freeze and offset times, respectively. Parameters ---------- sync : dict 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' chmap : dict Map of channel names and their corresponding index. Default to constant. display : bool, matplotlib.pyplot.Axes Show the full session sync pulses display. Returns ------- dict A map of trial event timestamps. """ # Get the events from the sync. # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}: raise ValueError( 'Expected at least "ready_tone" and "error_tone" audio events.' '`audio_event_ttls` kwarg may be incorrect.') self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) if not set(bpod_event_intervals.keys()) >= {'trial_start', 'valve_open', 'trial_end'}: raise ValueError( 'Expected at least "trial_start", "trial_end", and "valve_open" audio events. ' '`bpod_event_ttls` kwarg may be incorrect.') t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T fpga_events = alfio.AlfBunch({ 'goCue_times': audio_event_intervals['ready_tone'][:, 0], 'errorCue_times': audio_event_intervals['error_tone'][:, 0], 'valveOpen_times': bpod_event_intervals['valve_open'][:, 0], 'valveClose_times': bpod_event_intervals['valve_open'][:, 1], 'itiIn_times': t_iti_in, 'intervals_0': bpod_event_intervals['trial_start'][:, 0], 'intervals_1': t_trial_end }) # Sync the Bpod clock to the DAQ. # NB: The Bpod extractor typically drops the final, incomplete, trial. Hence there is # usually at least one extra FPGA event. This shouldn't affect the sync. The final trial is # dropped after assigning the FPGA events, using the `ifpga` index. Doing this after # assigning the FPGA trial events ensures the last trial has the correct timestamps. self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': # One issue is that sometimes pulses may not have been detected, in this case # add the events that have not been detected and re-extract the behaviour sync. # This is only really relevant for the Bpod interval events as the other TTLs are # from devices where a missing TTL likely means the Bpod event was truly absent. _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') bpod_start = self.bpod_trials['intervals'][:, 0] missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) t_trial_start = np.sort(np.r_[fpga_events['intervals_0'][:, 0], missing_bpod]) else: t_trial_start = fpga_events['intervals_0'] t_trial_start = t_trial_start[ifpga] out = alfio.AlfBunch() # Add the Bpod trial events, converting the timestamp fields to FPGA time. # NB: The trial intervals are by default a Bpod rsync field. out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) for k in self.bpod_rsync_fields: # Some personal projects may extract non-trials object datasets that may not have 1 event per trial idx = ibpod if self._is_trials_object_attribute(k) else np.arange(len(self.bpod_trials[k]), dtype=int) out[k] = self.bpod2fpga(self.bpod_trials[k][idx]) f2ttl_t = self.frame2ttl['times'] # Assign the FPGA events to individual trials fpga_trials = { 'goCue_times': _assign_events_to_trial(t_trial_start, fpga_events['goCue_times'], take='first'), 'errorCue_times': _assign_events_to_trial(t_trial_start, fpga_events['errorCue_times']), 'valveOpen_times': _assign_events_to_trial(t_trial_start, fpga_events['valveOpen_times']), 'itiIn_times': _assign_events_to_trial(t_trial_start, fpga_events['itiIn_times']), 'stimOn_times': np.full_like(t_trial_start, np.nan), 'stimOff_times': np.full_like(t_trial_start, np.nan), 'stimFreeze_times': np.full_like(t_trial_start, np.nan) } # f2ttl times are unreliable owing to calibration and Bonsai sync square update issues. # Take the first event after the FPGA aligned stimulus trigger time. fpga_trials['stimOn_times'] = _assign_events_to_trial( out['stimOnTrigger_times'], f2ttl_t, take='first', t_trial_end=out['stimOffTrigger_times']) fpga_trials['stimOff_times'] = _assign_events_to_trial( out['stimOffTrigger_times'], f2ttl_t, take='first', t_trial_end=out['intervals'][:, 1]) # For stim freeze we take the last event before the stim off trigger time. # To avoid assigning early events (e.g. for sessions where there are few flips due to # mis-calibration), we discount events before stim freeze trigger times (or stim on trigger # times for versions below 6.2.5). We take the last event rather than the first after stim # freeze trigger because often there are multiple flips after the trigger, presumably # before the stim actually stops. stim_freeze = np.copy(out['stimFreezeTrigger_times']) go_trials = np.where(out['choice'] != 0)[0] # NB: versions below 6.2.5 have no trigger times so use stim on trigger times lims = np.copy(out['stimOnTrigger_times']) if not np.isnan(stim_freeze).all(): # Stim freeze times are NaN for nogo trials, but for all others use stim freeze trigger # times. _assign_events_to_trial requires ascending timestamps so no NaNs allowed. lims[go_trials] = stim_freeze[go_trials] # take last event after freeze/stim on trigger, before stim off trigger stim_freeze = _assign_events_to_trial(lims, f2ttl_t, take='last', t_trial_end=out['stimOffTrigger_times']) fpga_trials['stimFreeze_times'][go_trials] = stim_freeze[go_trials] # Feedback times are valve open on correct trials and error tone in on incorrect trials fpga_trials['feedback_times'] = np.copy(fpga_trials['valveOpen_times']) ind_err = np.isnan(fpga_trials['valveOpen_times']) fpga_trials['feedback_times'][ind_err] = fpga_trials['errorCue_times'][ind_err] out.update({k: fpga_trials[k] for k in fpga_trials.keys()}) if display: # pragma: no cover width = 0.5 ymax = 5 if isinstance(display, bool): plt.figure('Bpod FPGA Sync') ax = plt.gca() else: ax = display plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') color_map = TABLEAU_COLORS.keys() for (event_name, event_times), c in zip(fpga_events.items(), cycle(color_map)): plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) # Plot the stimulus events along with the trigger times stim_events = filter(lambda t: 'stim' in t[0], fpga_trials.items()) for (event_name, event_times), c in zip(stim_events, cycle(color_map)): plots.vertical_lines( event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width, linestyle='--') nm = event_name.replace('_times', 'Trigger_times') plots.vertical_lines( out[nm], ymin=0, ymax=ymax, ax=ax, color=c, label=nm, linewidth=width, linestyle=':') ax.legend() ax.set_yticks([0, 1, 2, 3]) ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) ax.set_ylim([0, 5]) return out
[docs] def get_wheel_positions(self, *args, **kwargs): """Extract wheel and wheelMoves objects. This method is called by the main extract method and may be overloaded by subclasses. """ return get_wheel_positions(*args, **kwargs)
[docs] def get_stimulus_update_times(self, sync, chmap, display=False, **_): """ Extract stimulus update times from sync. Gets the stimulus times from the frame2ttl channel and cleans the signal. Parameters ---------- sync : dict A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. chmap : dict A map of channel names and their corresponding indices. Must contain a 'frame2ttl' key. display : bool If true, plots the input TTLs and the cleaned output. Returns ------- dict A dictionary with keys {'times', 'polarities'} containing stimulus TTL fronts. """ frame2ttl = get_sync_fronts(sync, chmap['frame2ttl']) frame2ttl = _clean_frame2ttl(frame2ttl, display=display) return frame2ttl
[docs] def get_audio_event_times(self, sync, chmap, audio_event_ttls=None, display=False, **_): """ Extract audio times from sync. Gets the TTL times from the 'audio' channel, cleans the signal, and classifies each TTL event by length. Parameters ---------- sync : dict A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. chmap : dict A map of channel names and their corresponding indices. Must contain an 'audio' key. audio_event_ttls : dict A map of event names to (min, max) TTL length. display : bool If true, plots the input TTLs and the cleaned output. Returns ------- dict A dictionary with keys {'times', 'polarities'} containing audio TTL fronts. dict A dictionary of events (from `audio_event_ttls`) and their intervals as an Nx2 array. """ audio = get_sync_fronts(sync, chmap['audio']) audio = _clean_audio(audio) if audio['times'].size == 0: _logger.error('No audio sync fronts found.') if audio_event_ttls is None: # For training/biased/ephys protocols, the ready tone should be below 110 ms. The error # tone should be between 400ms and 1200ms audio_event_ttls = {'ready_tone': (0, 0.1101), 'error_tone': (0.4, 1.2)} audio_event_intervals = self._assign_events(audio['times'], audio['polarities'], audio_event_ttls, display=display) return audio, audio_event_intervals
[docs] def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): """ Extract Bpod times from sync. Gets the Bpod TTL times from the sync 'bpod' channel and classifies each TTL event by length. NB: The first trial has an abnormal trial_start TTL that is usually mis-assigned. This method accounts for this. Parameters ---------- sync : dict A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. Must contain a 'bpod' key. chmap : dict A map of channel names and their corresponding indices. bpod_event_ttls : dict of tuple A map of event names to (min, max) TTL length. Returns ------- dict A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. dict A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. """ bpod = get_sync_fronts(sync, chmap['bpod']) if bpod.times.size == 0: raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. ' 'Check channel maps.') # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these # lengths are defined by the state machine of the task protocol and therefore vary. if bpod_event_ttls is None: # For training/biased/ephys protocols, the trial start TTL length is 0.1ms but this has # proven to drift on some Bpods and this is the highest possible value that # discriminates trial start from valve. Valve open events are between 50ms to 300 ms. # ITI events are above 400 ms. bpod_event_ttls = { 'trial_start': (0, 2.33e-4), 'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)} bpod_event_intervals = self._assign_events( bpod['times'], bpod['polarities'], bpod_event_ttls, display=display) if 'trial_start' not in bpod_event_intervals or bpod_event_intervals['trial_start'].size == 0: return bpod, bpod_event_intervals # The first trial pulse is longer and often assigned to another event. # Here we move the earliest non-trial_start event to the trial_start array. t0 = bpod_event_intervals['trial_start'][0, 0] # expect 1st event to be trial_start pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0] if pretrial: (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt) bpod_event_intervals['trial_start'] = np.r_[ bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start'] ] bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] return bpod, bpod_event_intervals
@staticmethod def _assign_events(ts, polarities, event_lengths, precedence='shortest', display=False): """ Classify TTL events by length. Outputs the synchronisation events such as trial intervals, valve opening, and audio. Parameters ---------- ts : numpy.array Numpy vector containing times of TTL fronts. polarities : numpy.array Numpy vector containing polarity of TTL fronts (1 rise, -1 fall). event_lengths : dict of tuple A map of TTL events and the range of permissible lengths, where l0 < ttl <= l1. precedence : str {'shortest', 'longest', 'dict order'} In the case of overlapping event TTL lengths, assign shortest/longest first or go by the `event_lengths` dict order. display : bool If true, plots the TTLs with coloured lines delineating the assigned events. Returns ------- Dict[str, numpy.array] A dictionary of events and their intervals as an Nx2 array. See Also -------- _assign_events_to_trial - classify TTLs by event order within a given trial period. """ event_intervals = dict.fromkeys(event_lengths) assert 'unassigned' not in event_lengths.keys() if len(ts) == 0: return {k: np.array([[], []]).T for k in (*event_lengths.keys(), 'unassigned')} # make sure that there are no 2 consecutive fall or consecutive rise events assert np.all(np.abs(np.diff(polarities)) == 2) if polarities[0] == -1: ts = np.delete(ts, 0) if polarities[-1] == 1: # if the final TTL is left HIGH, insert a NaN ts = np.r_[ts, np.nan] # take only even time differences: i.e. from rising to falling fronts dt = np.diff(ts)[::2] # Assign events from shortest TTL to largest assigned = np.zeros(ts.shape, dtype=bool) if precedence.lower() == 'shortest': event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1])) elif precedence.lower() == 'longest': event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1]), reverse=True) elif precedence.lower() == 'dict order': event_items = event_lengths.items() else: raise ValueError(f'Precedence must be one of "shortest", "longest", "dict order", got "{precedence}".') for event, (min_len, max_len) in event_items: _logger.debug('%s: %.4G < ttl <= %.4G', event, min_len, max_len) i_event = np.where(np.logical_and(dt > min_len, dt <= max_len))[0] * 2 i_event = i_event[np.where(~assigned[i_event])[0]] # remove those already assigned event_intervals[event] = np.c_[ts[i_event], ts[i_event + 1]] assigned[np.r_[i_event, i_event + 1]] = True # Include the unassigned events for convenience and debugging event_intervals['unassigned'] = ts[~assigned].reshape(-1, 2) # Assert that event TTLs mutually exclusive all_assigned = np.concatenate(list(event_intervals.values())).flatten() assert all_assigned.size == np.unique(all_assigned).size, 'TTLs assigned to multiple events' # some debug plots when needed if display: # pragma: no cover plt.figure() plots.squares(ts, polarities, label='raw fronts') for event, intervals in event_intervals.items(): plots.vertical_lines(intervals[:, 0], ymin=-0.2, ymax=1.1, linewidth=0.5, label=event) plt.legend() # Return map of event intervals in the same order as `event_lengths` dict return {k: event_intervals[k] for k in (*event_lengths, 'unassigned')}
[docs] @staticmethod def sync_bpod_clock(bpod_trials, fpga_trials, sync_field): """ Sync the Bpod clock to FPGA one using the provided trial event. It assumes that `sync_field` is in both `fpga_trials` and `bpod_trials`. Syncing on both intervals is not supported so to sync on trial start times, `sync_field` should be 'intervals_0'. Parameters ---------- bpod_trials : dict A dictionary of extracted Bpod trial events. fpga_trials : dict A dictionary of TTL events extracted from FPGA sync (see `extract_behaviour_sync` method). sync_field : str The trials key to use for syncing clocks. For intervals (i.e. Nx2 arrays) append the column index, e.g. 'intervals_0'. Returns ------- function Interpolation function such that f(timestamps_bpod) = timestamps_fpga. float The clock drift in parts per million. numpy.array of int The indices of the Bpod trial events in the FPGA trial events array. numpy.array of int The indices of the FPGA trial events in the Bpod trial events array. Raises ------ ValueError The key `sync_field` was not found in either the `bpod_trials` or `fpga_trials` dicts. """ _logger.info(f'Attempting to align Bpod clock to DAQ using trial event "{sync_field}"') bpod_fpga_timestamps = [None, None] for i, trials in enumerate((bpod_trials, fpga_trials)): if sync_field not in trials: # handle syncing on intervals if not (m := re.match(r'(.*)_(\d)', sync_field)): # If missing from bpod trials, either the sync field is incorrect, # or the Bpod extractor is incorrect. If missing from the fpga events, check # the sync field and the `extract_behaviour_sync` method. raise ValueError( f'Sync field "{sync_field}" not in extracted {"fpga" if i else "bpod"} events') _sync_field, n = m.groups() bpod_fpga_timestamps[i] = trials[_sync_field][:, int(n)] else: bpod_fpga_timestamps[i] = trials[sync_field] # Sync the two timestamps fcn, drift, ibpod, ifpga = ibldsp.utils.sync_timestamps(*bpod_fpga_timestamps, return_indices=True) # If it's drifting too much throw warning or error _logger.info('N trials: %i bpod, %i FPGA, %i merged, sync %.5f ppm', *map(len, bpod_fpga_timestamps), len(ibpod), drift) if drift > 200 and bpod_fpga_timestamps[0].size != bpod_fpga_timestamps[1].size: raise err.SyncBpodFpgaException('sync cluster f*ck') elif drift > BPOD_FPGA_DRIFT_THRESHOLD_PPM: _logger.warning('BPOD/FPGA synchronization shows values greater than %.2f ppm', BPOD_FPGA_DRIFT_THRESHOLD_PPM) return fcn, drift, ibpod, ifpga
[docs] class FpgaTrialsHabituation(FpgaTrials): """Extract habituationChoiceWorld trial events from an NI DAQ.""" save_names = ('_ibl_trials.stimCenter_times.npy', '_ibl_trials.feedbackType.npy', '_ibl_trials.rewardVolume.npy', '_ibl_trials.stimOff_times.npy', '_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy', '_ibl_trials.feedback_times.npy', '_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOnTrigger_times.npy', '_ibl_trials.intervals.npy', '_ibl_trials.goCue_times.npy', '_ibl_trials.goCueTrigger_times.npy', None, None, None, None, None) """tuple of str: The filenames of each extracted dataset, or None if array should not be saved.""" var_names = ('stimCenter_times', 'feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight', 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals', 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimCenterTrigger_times', 'position', 'phase') """tuple of str: A list of names for the extracted variables. These become the returned output keys.""" bpod_rsync_fields = ('intervals', 'stimOn_times', 'feedback_times', 'stimCenterTrigger_times', 'goCue_times', 'itiIn_times', 'stimOffTrigger_times', 'stimOff_times', 'stimCenter_times', 'stimOnTrigger_times', 'goCueTrigger_times') """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA.""" bpod_fields = ('feedbackType', 'rewardVolume', 'contrastLeft', 'contrastRight', 'position', 'phase') """tuple of str: Fields from Bpod extractor that we want to save.""" sync_field = 'feedback_times' # valve open events """str: The trial event to synchronize (must be present in extracted trials).""" def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task_collection='raw_behavior_data', **kwargs) -> dict: """ Extract habituationChoiceWorld trial events from an NI DAQ. It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field` attributes are all correct for the bpod protocol used. Unlike FpgaTrials, this class assumes different Bpod TTL events and syncs the Bpod clock using the valve open times, instead of the trial start times. Parameters ---------- sync : dict A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. If None, the sync is loaded using the `load_sync` method. dict A map of channel names and their corresponding indices. If None, the channel map is loaded using the `load_sync` method. sync_collection : str The session subdirectory where the sync data are located. This is only used if the sync or channel maps are not provided. task_collection : str The session subdirectory where the raw Bpod data are located. This is used for loading the task settings and extracting the bpod trials, if not already done. protocol_number : int The protocol number if multiple protocols were run during the session. If provided, a spacer signal must be present in order to determine the correct period. kwargs Optional arguments for class methods, e.g. 'display', 'bpod_event_ttls'. Returns ------- dict A dictionary of numpy arrays with `FpgaTrialsHabituation.var_names` as keys. """ # Version check: the ITI in TTL was added in a later version if not self.settings: self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) iblrig_version = version.parse(self.settings.get('IBL_VERSION', '0.0.0')) if version.parse('8.9.3') <= iblrig_version < version.parse('8.12.6'): """A second 1s TTL was added in this version during the 'iti' state, however this is unrelated to the trial ITI and is unfortunately the same length as the trial start TTL.""" raise NotImplementedError('Ambiguous TTLs in 8.9.3 >= version < 8.12.6') trials = super()._extract(sync=sync, chmap=chmap, sync_collection=sync_collection, task_collection=task_collection, **kwargs) return trials
[docs] def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): """ Extract Bpod times from sync. Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse. Also the first trial pulse is incorrectly assigned due to its abnormal length. Parameters ---------- sync : dict A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. Must contain a 'bpod' key. chmap : dict A map of channel names and their corresponding indices. bpod_event_ttls : dict of tuple A map of event names to (min, max) TTL length. Returns ------- dict A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. dict A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. """ bpod = get_sync_fronts(sync, chmap['bpod']) if bpod.times.size == 0: raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. ' 'Check channel maps.') # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these # lengths are defined by the state machine of the task protocol and therefore vary. if bpod_event_ttls is None: # Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse bpod_event_ttls = {'trial_iti': (1, 1.1), 'valve_open': (0, 0.4)} bpod_event_intervals = self._assign_events( bpod['times'], bpod['polarities'], bpod_event_ttls, display=display) # The first trial pulse is shorter and assigned to valve_open. Here we remove the first # valve event, prepend a 0 to the trial_start events, and drop the last trial if it was # incomplete in Bpod. bpod_event_intervals['trial_iti'] = np.r_[bpod_event_intervals['valve_open'][0:1, :], bpod_event_intervals['trial_iti']] bpod_event_intervals['valve_open'] = bpod_event_intervals['valve_open'][1:, :] return bpod, bpod_event_intervals
[docs] def build_trials(self, sync, chmap, display=False, **kwargs): """ Extract task related event times from the sync. This is called by the superclass `_extract` method. The key difference here is that the `trial_start` LOW->HIGH is the trial end, and HIGH->LOW is trial start. Parameters ---------- sync : dict 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' chmap : dict Map of channel names and their corresponding index. Default to constant. display : bool, matplotlib.pyplot.Axes Show the full session sync pulses display. Returns ------- dict A map of trial event timestamps. """ # Get the events from the sync. # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_iti'}: raise ValueError( 'Expected at least "trial_iti" and "valve_open" Bpod events. `bpod_event_ttls` kwarg may be incorrect.') fpga_events = alfio.AlfBunch({ 'feedback_times': bpod_event_intervals['valve_open'][:, 0], 'valveClose_times': bpod_event_intervals['valve_open'][:, 1], 'intervals_0': bpod_event_intervals['trial_iti'][:, 1], 'intervals_1': bpod_event_intervals['trial_iti'][:, 0], 'goCue_times': audio_event_intervals['ready_tone'][:, 0] }) # Sync the Bpod clock to the DAQ. self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) out = alfio.AlfBunch() # Add the Bpod trial events, converting the timestamp fields to FPGA time. # NB: The trial intervals are by default a Bpod rsync field. out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) # Assigning each event to a trial ensures exactly one event per trial (missing events are NaN) assign_to_trial = partial(_assign_events_to_trial, fpga_events['intervals_0']) trials = alfio.AlfBunch({ 'goCue_times': assign_to_trial(fpga_events['goCue_times'], take='first'), 'feedback_times': assign_to_trial(fpga_events['feedback_times']), 'stimCenter_times': assign_to_trial(self.frame2ttl['times'], take=-2), 'stimOn_times': assign_to_trial(self.frame2ttl['times'], take='first'), 'stimOff_times': assign_to_trial(self.frame2ttl['times']), }) out.update({k: trials[k][ifpga] for k in trials.keys()}) # If stim on occurs before trial end, use stim on time. Likewise for trial end and stim off to_correct = ~np.isnan(out['stimOn_times']) & (out['stimOn_times'] < out['intervals'][:, 0]) if np.any(to_correct): _logger.warning('%i/%i stim on events occurring outside trial intervals', sum(to_correct), len(to_correct)) out['intervals'][to_correct, 0] = out['stimOn_times'][to_correct] to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1]) if np.any(to_correct): _logger.debug( '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end', sum(to_correct), len(to_correct)) out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct] if display: # pragma: no cover width = 0.5 ymax = 5 if isinstance(display, bool): plt.figure('Bpod FPGA Sync') ax = plt.gca() else: ax = display plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') color_map = TABLEAU_COLORS.keys() for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) ax.legend() ax.set_yticks([0, 1, 2, 3]) ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) ax.set_ylim([0, 4]) return out
[docs] def get_sync_and_chn_map(session_path, sync_collection): """ Return sync and channel map for session based on collection where main sync is stored. Parameters ---------- session_path : str, pathlib.Path The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'. sync_collection : str The session subdirectory where the sync data are located. Returns ------- one.alf.io.AlfBunch A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. dict A map of channel names and their corresponding indices. """ if sync_collection == 'raw_ephys_data': # Check to see if we have nidq files, if we do just go with this otherwise go into other function that deals with # 3A probes nidq_meta = next(session_path.joinpath(sync_collection).glob('*nidq.meta'), None) if not nidq_meta: sync, chmap = get_main_probe_sync(session_path) else: sync = load_sync(session_path, sync_collection) ef = Bunch() ef['path'] = session_path.joinpath(sync_collection) ef['nidq'] = nidq_meta chmap = get_ibl_sync_map(ef, '3B') else: sync = load_sync(session_path, sync_collection) chmap = load_channel_map(session_path, sync_collection) return sync, chmap
[docs] def load_channel_map(session_path, sync_collection): """ Load syncing channel map for session path and collection Parameters ---------- session_path : str, pathlib.Path The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'. sync_collection : str The session subdirectory where the sync data are located. Returns ------- dict A map of channel names and their corresponding indices. """ device = sync_collection.split('_')[1] default_chmap = DEFAULT_MAPS[device]['nidq'] # Try to load channel map from file chmap = spikeglx.get_sync_map(session_path.joinpath(sync_collection)) # If chmap provided but not with all keys, fill up with default values if not chmap: return default_chmap else: if data_for_keys(default_chmap.keys(), chmap): return chmap else: _logger.warning('Keys missing from provided channel map, ' 'setting missing keys from default channel map') return {**default_chmap, **chmap}
[docs] def load_sync(session_path, sync_collection): """ Load sync files from session path and collection. Parameters ---------- session_path : str, pathlib.Path The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'. sync_collection : str The session subdirectory where the sync data are located. Returns ------- one.alf.io.AlfBunch A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. """ sync = alfio.load_object(session_path.joinpath(sync_collection), 'sync', namespace='spikeglx', short_keys=True) return sync