"""Data extraction from raw FPGA output.
The behaviour extraction happens in the following stages:
1. The NI DAQ events are extracted into a map of event times and TTL polarities.
2. The Bpod trial events are extracted from the raw Bpod data, depending on the task protocol.
3. As protocols may be chained together within a given recording, the period of a given task
protocol is determined using the 'spacer' DAQ signal (see `get_protocol_period`).
4. Physical behaviour events such as stim on and reward time are separated out by TTL length or
sequence within the trial.
5. The Bpod clock is sync'd with the FPGA using one of the extracted trial events.
6. The Bpod software events are then converted to FPGA time.
Examples
--------
For simple extraction, use the FPGATrials class:
>>> extractor = FpgaTrials(session_path)
>>> trials, _ = extractor.extract(update=False, save=False)
Notes
-----
Sync extraction in this module only supports FPGA data acquired with an NI DAQ as part of a
Neuropixels recording system, however a sync and channel map extracted from a different DAQ format
can be passed to the FpgaTrials class.
See Also
--------
For dynamic pipeline sessions it is best to call the extractor via the BehaviorTask class.
TODO notes on subclassing various methods of FpgaTrials for custom hardware.
"""
import logging
from itertools import cycle
from pathlib import Path
import uuid
import re
from functools import partial
import matplotlib.pyplot as plt
from matplotlib.colors import TABLEAU_COLORS
import numpy as np
from packaging import version
import spikeglx
import ibldsp.utils
import one.alf.io as alfio
from one.alf.path import filename_parts
from iblutil.util import Bunch
from iblutil.spacer import Spacer
import ibllib.exceptions as err
from ibllib.io import raw_data_loaders as raw, session_params
from ibllib.io.extractors.bpod_trials import get_bpod_extractor
import ibllib.io.extractors.base as extractors_base
from ibllib.io.extractors.training_wheel import extract_wheel_moves
from ibllib import plots
from ibllib.io.extractors.default_channel_maps import DEFAULT_MAPS
_logger = logging.getLogger(__name__)
SYNC_BATCH_SIZE_SECS = 100
"""int: Number of samples to read at once in bin file for sync."""
WHEEL_RADIUS_CM = 1 # stay in radians
"""float: The radius of the wheel used in the task. A value of 1 ensures units remain in radians."""
WHEEL_TICKS = 1024
"""int: The number of encoder pulses per channel for one complete rotation."""
BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150
"""int: Throws an error if Bpod to FPGA clock drift is higher than this value."""
CHMAPS = {'3A':
{'ap':
{'left_camera': 2,
'right_camera': 3,
'body_camera': 4,
'bpod': 7,
'frame2ttl': 12,
'rotary_encoder_0': 13,
'rotary_encoder_1': 14,
'audio': 15
}
},
'3B':
{'nidq':
{'left_camera': 0,
'right_camera': 1,
'body_camera': 2,
'imec_sync': 3,
'frame2ttl': 4,
'rotary_encoder_0': 5,
'rotary_encoder_1': 6,
'audio': 7,
'bpod': 16,
'laser': 17,
'laser_ttl': 18},
'ap':
{'imec_sync': 6}
},
}
"""dict: The default channel indices corresponding to various devices for different recording systems."""
[docs]
def data_for_keys(keys, data):
"""Check keys exist in 'data' dict and contain values other than None."""
return data is not None and all(k in data and data.get(k, None) is not None for k in keys)
[docs]
def get_ibl_sync_map(ef, version):
"""
Gets default channel map for the version/binary file type combination
:param ef: ibllib.io.spikeglx.glob_ephys_file dictionary with field 'ap' or 'nidq'
:return: channel map dictionary
"""
# Determine default channel map
if version == '3A':
default_chmap = CHMAPS['3A']['ap']
elif version == '3B':
if ef.get('nidq', None):
default_chmap = CHMAPS['3B']['nidq']
elif ef.get('ap', None):
default_chmap = CHMAPS['3B']['ap']
# Try to load channel map from file
chmap = spikeglx.get_sync_map(ef['path'])
# If chmap provided but not with all keys, fill up with default values
if not chmap:
return default_chmap
else:
if data_for_keys(default_chmap.keys(), chmap):
return chmap
else:
_logger.warning("Keys missing from provided channel map, "
"setting missing keys from default channel map")
return {**default_chmap, **chmap}
def _sync_to_alf(raw_ephys_apfile, output_path=None, save=False, parts=''):
"""
Extracts sync.times, sync.channels and sync.polarities from binary ephys dataset
:param raw_ephys_apfile: bin file containing ephys data or spike
:param output_path: output directory
:param save: bool write to disk only if True
:param parts: string or list of strings that will be appended to the filename before extension
:return:
"""
# handles input argument: support ibllib.io.spikeglx.Reader, str and pathlib.Path
if isinstance(raw_ephys_apfile, spikeglx.Reader):
sr = raw_ephys_apfile
else:
raw_ephys_apfile = Path(raw_ephys_apfile)
sr = spikeglx.Reader(raw_ephys_apfile)
if not (opened := sr.is_open):
sr.open()
# if no output, need a temp folder to swap for big files
if not output_path:
output_path = raw_ephys_apfile.parent
file_ftcp = Path(output_path).joinpath(f'fronts_times_channel_polarity{uuid.uuid4()}.bin')
# loop over chunks of the raw ephys file
wg = ibldsp.utils.WindowGenerator(sr.ns, int(SYNC_BATCH_SIZE_SECS * sr.fs), overlap=1)
fid_ftcp = open(file_ftcp, 'wb')
for sl in wg.slice:
ss = sr.read_sync(sl)
ind, fronts = ibldsp.utils.fronts(ss, axis=0)
# a = sr.read_sync_analog(sl)
sav = np.c_[(ind[0, :] + sl.start) / sr.fs, ind[1, :], fronts.astype(np.double)]
sav.tofile(fid_ftcp)
# close temp file, read from it and delete
fid_ftcp.close()
tim_chan_pol = np.fromfile(str(file_ftcp))
tim_chan_pol = tim_chan_pol.reshape((int(tim_chan_pol.size / 3), 3))
file_ftcp.unlink()
sync = {'times': tim_chan_pol[:, 0],
'channels': tim_chan_pol[:, 1],
'polarities': tim_chan_pol[:, 2]}
# If opened Reader was passed into function, leave open
if not opened:
sr.close()
if save:
out_files = alfio.save_object_npy(output_path, sync, 'sync',
namespace='spikeglx', parts=parts)
return Bunch(sync), out_files
else:
return Bunch(sync)
def _rotary_encoder_positions_from_fronts(ta, pa, tb, pb, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4'):
"""
Extracts the rotary encoder absolute position as function of time from fronts detected
on the 2 channels. Outputs in units of radius parameters, by default radians
Coding options detailed here: http://www.ni.com/tutorial/7109/pt/
Here output is clockwise from subject perspective
:param ta: time of fronts on channel A
:param pa: polarity of fronts on channel A
:param tb: time of fronts on channel B
:param pb: polarity of fronts on channel B
:param ticks: number of ticks corresponding to a full revolution (1024 for IBL rotary encoder)
:param radius: radius of the wheel. Defaults to 1 for an output in radians
:param coding: x1, x2 or x4 coding (IBL default is x4)
:return: indices vector (ta) and position vector
"""
if coding == 'x1':
ia = np.searchsorted(tb, ta[pa == 1])
ia = ia[ia < ta.size]
ia = ia[pa[ia] == 1]
ib = np.searchsorted(ta, tb[pb == 1])
ib = ib[ib < tb.size]
ib = ib[pb[ib] == 1]
t = np.r_[ta[ia], tb[ib]]
p = np.r_[ia * 0 + 1, ib * 0 - 1]
ordre = np.argsort(t)
t = t[ordre]
p = p[ordre]
p = np.cumsum(p) / ticks * np.pi * 2 * radius
return t, p
elif coding == 'x2':
p = pb[np.searchsorted(tb, ta) - 1] * pa
p = - np.cumsum(p) / ticks * np.pi * 2 * radius / 2
return ta, p
elif coding == 'x4':
p = np.r_[pb[np.searchsorted(tb, ta) - 1] * pa, -pa[np.searchsorted(ta, tb) - 1] * pb]
t = np.r_[ta, tb]
ordre = np.argsort(t)
t = t[ordre]
p = p[ordre]
p = - np.cumsum(p) / ticks * np.pi * 2 * radius / 4
return t, p
def _assign_events_to_trial(t_trial_start, t_event, take='last', t_trial_end=None):
"""
Assign events to a trial given trial start times and event times.
Trials without an event result in nan value in output time vector.
The output has a consistent size with t_trial_start and ready to output to alf.
Parameters
----------
t_trial_start : numpy.array
An array of start times, used to bin edges for assigning values from `t_event`.
t_event : numpy.array
An array of event times to assign to trials.
take : str {'first', 'last'}, int
'first' takes first event > t_trial_start; 'last' takes last event < the next
t_trial_start; an int defines the index to take for events within trial bounds. The index
may be negative.
t_trial_end : numpy.array
Optional array of end times, used to bin edges for assigning values from `t_event`.
Returns
-------
numpy.array
An array the length of `t_trial_start` containing values from `t_event`. Unassigned values
are replaced with np.nan.
See Also
--------
FpgaTrials._assign_events - Assign trial events based on TTL length.
"""
# make sure the events are sorted
try:
assert np.all(np.diff(t_trial_start) >= 0)
except AssertionError:
raise ValueError('Trial starts vector not sorted')
try:
assert np.all(np.diff(t_event) >= 0)
except AssertionError:
raise ValueError('Events vector is not sorted')
# remove events that happened before the first trial start
remove = t_event < t_trial_start[0]
if t_trial_end is not None:
if not np.all(np.diff(t_trial_end) >= 0):
raise ValueError('Trial end vector not sorted')
if not np.all(t_trial_end[:-1] < t_trial_start[1:]):
raise ValueError('Trial end times must not overlap with trial start times')
# remove events between end and next start, and after last end
remove |= t_event > t_trial_end[-1]
for e, s in zip(t_trial_end[:-1], t_trial_start[1:]):
remove |= np.logical_and(s > t_event, t_event >= e)
t_event = t_event[~remove]
ind = np.searchsorted(t_trial_start, t_event) - 1
t_event_nans = np.zeros_like(t_trial_start) * np.nan
# select first or last element matching each trial start
if take == 'last':
iall, iu = np.unique(np.flip(ind), return_index=True)
t_event_nans[iall] = t_event[- (iu - ind.size + 1)]
elif take == 'first':
iall, iu = np.unique(ind, return_index=True)
t_event_nans[iall] = t_event[iu]
else: # if the index is arbitrary, needs to be numeric (could be negative if from the end)
iall = np.unique(ind)
minsize = take + 1 if take >= 0 else - take
# for each trial, take the take nth element if there are enough values in trial
for iu in iall:
match = t_event[iu == ind]
if len(match) >= minsize:
t_event_nans[iu] = match[take]
return t_event_nans
[docs]
def get_sync_fronts(sync, channel_nb, tmin=None, tmax=None):
"""
Return the sync front polarities and times for a given channel.
Parameters
----------
sync : dict
'polarities' of fronts detected on sync trace for all 16 channels and their 'times'.
channel_nb : int
The integer corresponding to the desired sync channel.
tmin : float
The minimum time from which to extract the sync pulses.
tmax : float
The maximum time up to which we extract the sync pulses.
Returns
-------
Bunch
Channel times and polarities.
"""
selection = sync['channels'] == channel_nb
selection = np.logical_and(selection, sync['times'] <= tmax) if tmax else selection
selection = np.logical_and(selection, sync['times'] >= tmin) if tmin else selection
return Bunch({'times': sync['times'][selection],
'polarities': sync['polarities'][selection]})
def _clean_audio(audio, display=False):
"""
one guy wired the 150 Hz camera output onto the soundcard. The effect is to get 150 Hz periodic
square pulses, 2ms up and 4.666 ms down. When this happens we remove all of the intermediate
pulses to repair the audio trace
Here is some helper code
dd = np.diff(audio['times'])
1 / np.median(dd[::2]) # 2ms up
1 / np.median(dd[1::2]) # 4.666 ms down
1 / (np.median(dd[::2]) + np.median(dd[1::2])) # both sum to 150 Hz
This only runs on sessions when the bug is detected and leaves others untouched
"""
DISCARD_THRESHOLD = 0.01
average_150_hz = np.mean(1 / np.diff(audio['times'][audio['polarities'] == 1]) > 140)
naudio = audio['times'].size
if average_150_hz > 0.7 and naudio > 100:
_logger.warning('Soundcard signal on FPGA seems to have been mixed with 150Hz camera')
keep_ind = np.r_[np.diff(audio['times']) > DISCARD_THRESHOLD, False]
keep_ind = np.logical_and(keep_ind, audio['polarities'] == -1)
keep_ind = np.where(keep_ind)[0]
keep_ind = np.sort(np.r_[0, keep_ind, keep_ind + 1, naudio - 1])
if display: # pragma: no cover
from ibllib.plots import squares
squares(audio['times'], audio['polarities'], ax=None, yrange=[-1, 1])
squares(audio['times'][keep_ind], audio['polarities'][keep_ind], yrange=[-1, 1])
audio = {'times': audio['times'][keep_ind],
'polarities': audio['polarities'][keep_ind]}
return audio
def _clean_frame2ttl(frame2ttl, threshold=0.01, display=False):
"""
Clean the frame2ttl events.
Frame 2ttl calibration can be unstable and the fronts may be flickering at an unrealistic
pace. This removes the consecutive frame2ttl pulses happening too fast, below a threshold
of F2TTL_THRESH.
Parameters
----------
frame2ttl : dict
A dictionary of frame2TTL events, with keys {'times', 'polarities'}.
threshold : float
Consecutive pulses occurring with this many seconds ignored.
display : bool
If true, plots the input TTLs and the cleaned output.
Returns
-------
"""
dt = np.diff(frame2ttl['times'])
iko = np.where(np.logical_and(dt < threshold, frame2ttl['polarities'][:-1] == -1))[0]
iko = np.unique(np.r_[iko, iko + 1])
frame2ttl_ = {'times': np.delete(frame2ttl['times'], iko),
'polarities': np.delete(frame2ttl['polarities'], iko)}
if iko.size > (0.1 * frame2ttl['times'].size):
_logger.warning(f'{iko.size} ({iko.size / frame2ttl["times"].size:.2%}) '
f'frame to TTL polarity switches below {threshold} secs')
if display: # pragma: no cover
fig, (ax0, ax1) = plt.subplots(2, sharex=True)
plots.squares(frame2ttl['times'] * 1000, frame2ttl['polarities'], yrange=[0.1, 0.9], ax=ax0)
plots.squares(frame2ttl_['times'] * 1000, frame2ttl_['polarities'], yrange=[1.1, 1.9], ax=ax1)
import seaborn as sns
sns.displot(dt[dt < 0.05], binwidth=0.0005)
return frame2ttl_
def _get_all_probes_sync(session_path, bin_exists=True):
# round-up of all bin ephys files in the session, infer revision and get sync map
ephys_files = spikeglx.glob_ephys_files(session_path, bin_exists=bin_exists)
version = spikeglx.get_neuropixel_version_from_files(ephys_files)
# attach the sync information to each binary file found
for ef in ephys_files:
ef['sync'] = alfio.load_object(ef.path, 'sync', namespace='spikeglx', short_keys=True)
ef['sync_map'] = get_ibl_sync_map(ef, version)
return ephys_files
[docs]
def get_wheel_positions(sync, chmap, tmin=None, tmax=None):
"""
Gets the wheel position from synchronisation pulses
Parameters
----------
sync : dict
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
the corresponding channel numbers.
chmap : dict[str, int]
A map of channel names and their corresponding indices.
tmin : float
The minimum time from which to extract the sync pulses.
tmax : float
The maximum time up to which we extract the sync pulses.
Returns
-------
Bunch
A dictionary with keys ('timestamps', 'position'), containing the wheel event timestamps and
position in radians
Bunch
A dictionary of detected movement times with keys ('intervals', 'peakAmplitude', 'peakVelocity_times').
"""
ts, pos = extract_wheel_sync(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax)
moves = Bunch(extract_wheel_moves(ts, pos))
wheel = Bunch({'timestamps': ts, 'position': pos})
return wheel, moves
[docs]
def get_main_probe_sync(session_path, bin_exists=False):
"""
From 3A or 3B multiprobe session, returns the main probe (3A) or nidq sync pulses
with the attached channel map (default chmap if none)
Parameters
----------
session_path : str, pathlib.Path
The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
bin_exists : bool
Whether there is a .bin file present.
Returns
-------
one.alf.io.AlfBunch
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
the corresponding channel numbers.
dict
A map of channel names and their corresponding indices.
"""
ephys_files = _get_all_probes_sync(session_path, bin_exists=bin_exists)
if not ephys_files:
raise FileNotFoundError(f"No ephys files found in {session_path}")
version = spikeglx.get_neuropixel_version_from_files(ephys_files)
if version == '3A':
# the sync master is the probe with the most sync pulses
sync_box_ind = np.argmax([ef.sync.times.size for ef in ephys_files])
elif version == '3B':
# the sync master is the nidq breakout box
sync_box_ind = np.argmax([1 if ef.get('nidq') else 0 for ef in ephys_files])
sync = ephys_files[sync_box_ind].sync
sync_chmap = ephys_files[sync_box_ind].sync_map
return sync, sync_chmap
[docs]
def get_protocol_period(session_path, protocol_number, bpod_sync):
"""
Parameters
----------
session_path : str, pathlib.Path
The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
protocol_number : int
The order that the protocol was run in.
bpod_sync : dict
The sync times and polarities for Bpod BNC1.
Returns
-------
float
The time of the detected spacer for the protocol number.
float, None
The time of the next detected spacer or None if this is the last protocol run.
"""
# The spacers are TTLs generated by Bpod at the start of each protocol
spacer_times = Spacer().find_spacers_from_fronts(bpod_sync)
# Ensure that the number of detected spacers matched the number of expected tasks
if acquisition_description := session_params.read_params(session_path):
n_tasks = len(acquisition_description.get('tasks', []))
assert len(spacer_times) >= protocol_number, (f'expected {n_tasks} spacers, found only {len(spacer_times)} - '
f'can not return protocol number {protocol_number}.')
assert n_tasks > protocol_number >= 0, f'protocol number must be between 0 and {n_tasks}'
else:
assert protocol_number < len(spacer_times)
start = spacer_times[int(protocol_number)]
end = None if len(spacer_times) - 1 == protocol_number else spacer_times[int(protocol_number + 1)]
return start, end
[docs]
class FpgaTrials(extractors_base.BaseExtractor):
save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy',
'_ibl_trials.stimOffTrigger_times.npy', None, None, None, None, None,
'_ibl_trials.stimOff_times.npy', None, None, None, '_ibl_trials.quiescencePeriod.npy',
'_ibl_trials.table.pqt', '_ibl_wheel.timestamps.npy',
'_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy',
'_ibl_wheelMoves.peakAmplitude.npy', None)
var_names = ('goCueTrigger_times', 'stimOnTrigger_times',
'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times',
'errorCue_times', 'itiIn_times', 'stimFreeze_times', 'stimOff_times',
'valveOpen_times', 'phase', 'position', 'quiescence', 'table',
'wheel_timestamps', 'wheel_position',
'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times')
bpod_rsync_fields = ('intervals', 'response_times', 'goCueTrigger_times',
'stimOnTrigger_times', 'stimOffTrigger_times',
'stimFreezeTrigger_times', 'errorCueTrigger_times')
"""tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA."""
bpod_fields = ('feedbackType', 'choice', 'rewardVolume', 'contrastLeft', 'contrastRight',
'probabilityLeft', 'phase', 'position', 'quiescence')
"""tuple of str: Fields from bpod extractor that we want to save."""
sync_field = 'intervals_0' # trial start events
"""str: The trial event to synchronize (must be present in extracted trials)."""
bpod = None
"""dict of numpy.array: The Bpod out TTLs recorded on the DAQ. Used in the QC viewer plot."""
def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs):
"""An extractor for ephysChoiceWorld trials data, in FPGA time.
This class may be subclassed to handle moderate variations in hardware and task protocol,
however there is flexible
"""
super().__init__(*args, **kwargs)
self.bpod2fpga = None
self.bpod_trials = bpod_trials
self.frame2ttl = self.audio = self.bpod = self.settings = None
if bpod_extractor:
self.bpod_extractor = bpod_extractor
self._update_var_names()
def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None):
"""
Updates this object's attributes based on the Bpod trials extractor.
Fields updated: bpod_fields, bpod_rsync_fields, save_names, and var_names.
Parameters
----------
bpod_fields : tuple
A set of Bpod trials fields to keep.
bpod_rsync_fields : tuple
A set of Bpod trials fields to sync to the DAQ times.
"""
if self.bpod_extractor:
for var_name, save_name in zip(self.bpod_extractor.var_names, self.bpod_extractor.save_names):
if var_name not in self.var_names:
self.var_names += (var_name,)
self.save_names += (save_name,)
# self.var_names = self.bpod_extractor.var_names
# self.save_names = self.bpod_extractor.save_names
self.settings = self.bpod_extractor.settings # This is used by the TaskQC
self.bpod_rsync_fields = bpod_rsync_fields
if self.bpod_rsync_fields is None:
self.bpod_rsync_fields = tuple(self._time_fields(self.bpod_extractor.var_names))
if 'table' in self.bpod_extractor.var_names:
if not self.bpod_trials:
self.bpod_trials = self.bpod_extractor.extract(save=False)
table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys()
self.bpod_rsync_fields += tuple(self._time_fields(table_keys))
elif bpod_rsync_fields:
self.bpod_rsync_fields = bpod_rsync_fields
excluded = (*self.bpod_rsync_fields, 'table')
if bpod_fields:
assert not set(self.bpod_fields).intersection(excluded), 'bpod_fields must not also be bpod_rsync_fields'
self.bpod_fields = bpod_fields
elif self.bpod_extractor:
self.bpod_fields = tuple(x for x in self.bpod_extractor.var_names if x not in excluded)
if 'table' in self.bpod_extractor.var_names:
if not self.bpod_trials:
self.bpod_trials = self.bpod_extractor.extract(save=False)
table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys()
self.bpod_fields += tuple([x for x in table_keys if x not in excluded])
@staticmethod
def _time_fields(trials_attr) -> set:
"""
Iterates over Bpod trials attributes returning those that correspond to times for syncing.
Parameters
----------
trials_attr : iterable of str
The Bpod field names.
Returns
-------
set
The field names that contain timestamps.
"""
FIELDS = ('times', 'timestamps', 'intervals')
pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$')
return set(filter(pattern.match, trials_attr))
[docs]
def load_sync(self, sync_collection='raw_ephys_data', **kwargs):
"""Load the DAQ sync and channel map data.
This method may be subclassed for novel DAQ systems. The sync must contain the following
keys: 'times' - an array timestamps in seconds; 'polarities' - an array of {-1, 1}
corresponding to TTL LOW and TTL HIGH, respectively; 'channels' - an array of ints
corresponding to channel number.
Parameters
----------
sync_collection : str
The session subdirectory where the sync data are located.
kwargs
Optional arguments used by subclass methods.
Returns
-------
one.alf.io.AlfBunch
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
and the corresponding channel numbers.
dict
A map of channel names and their corresponding indices.
"""
return get_sync_and_chn_map(self.session_path, sync_collection)
def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data',
task_collection='raw_behavior_data', **kwargs) -> dict:
"""Extracts ephys trials by combining Bpod and FPGA sync pulses.
It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field`
attributes are all correct for the bpod protocol used.
Below are the steps involved:
0. Load sync and bpod trials, if required.
1. Determine protocol period and discard sync events outside the task.
2. Classify multiplexed TTL events based on length (see :meth:`FpgaTrials.build_trials`).
3. Sync the Bpod clock to the DAQ clock using one of the assigned trial events.
4. Assign classified TTL events to trial events based on order within the trial.
4. Convert Bpod software event times to DAQ clock.
5. Extract the wheel from the DAQ rotary encoder signal, if required.
Parameters
----------
sync : dict
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
and the corresponding channel numbers. If None, the sync is loaded using the
`load_sync` method.
chmap : dict
A map of channel names and their corresponding indices. If None, the channel map is
loaded using the :meth:`FpgaTrials.load_sync` method.
sync_collection : str
The session subdirectory where the sync data are located. This is only used if the
sync or channel maps are not provided.
task_collection : str
The session subdirectory where the raw Bpod data are located. This is used for loading
the task settings and extracting the bpod trials, if not already done.
protocol_number : int
The protocol number if multiple protocols were run during the session. If provided, a
spacer signal must be present in order to determine the correct period.
kwargs
Optional arguments for subclass methods to use.
Returns
-------
dict
A dictionary of numpy arrays with `FpgaTrials.var_names` as keys.
"""
if sync is None or chmap is None:
_sync, _chmap = self.load_sync(sync_collection)
sync = sync or _sync
chmap = chmap or _chmap
if not self.bpod_trials: # extract the behaviour data from bpod
self.extractor = get_bpod_extractor(self.session_path, task_collection=task_collection)
_logger.info('Bpod trials extractor: %s.%s', self.extractor.__module__, self.extractor.__class__.__name__)
self.bpod_trials, *_ = self.extractor.extract(task_collection=task_collection, save=False, **kwargs)
# Explode trials table df
if 'table' in self.var_names:
trials_table = alfio.AlfBunch.from_df(self.bpod_trials.pop('table'))
table_columns = trials_table.keys()
self.bpod_trials.update(trials_table)
else:
if 'table' in self.bpod_trials:
_logger.error(
'"table" found in Bpod trials but missing from `var_names` attribute and will'
'therefore not be extracted. This is likely in error.')
table_columns = None
bpod = get_sync_fronts(sync, chmap['bpod'])
# Get the spacer times for this protocol
if any(arg in kwargs for arg in ('tmin', 'tmax')):
tmin, tmax = kwargs.get('tmin'), kwargs.get('tmax')
elif (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer
# The spacers are TTLs generated by Bpod at the start of each protocol
tmin, tmax = get_protocol_period(self.session_path, protocol_number, bpod)
else:
# Older sessions don't have protocol spacers so we sync the Bpod intervals here to
# find the approximate end time of the protocol (this will exclude the passive signals
# in ephysChoiceWorld that tend to ruin the final trial extraction).
_, trial_ints = self.get_bpod_event_times(sync, chmap, **kwargs)
t_trial_start = trial_ints.get('trial_start', np.array([[np.nan, np.nan]]))[:, 0]
bpod_start = self.bpod_trials['intervals'][:, 0]
if len(t_trial_start) > len(bpod_start) / 2: # if least half the trial start TTLs detected
_logger.warning('Attempting to get protocol period from aligning trial start TTLs')
fcn, *_ = ibldsp.utils.sync_timestamps(bpod_start, t_trial_start)
buffer = 2.5 # the number of seconds to include before/after task
start, end = fcn(self.bpod_trials['intervals'].flat[[0, -1]])
tmin = min(sync['times'][0], start - buffer)
tmax = max(sync['times'][-1], end + buffer)
else: # This type of alignment fails for some sessions, e.g. mesoscope
tmin = tmax = None
# Remove unnecessary data from sync
selection = np.logical_and(
sync['times'] <= (tmax if tmax is not None else sync['times'][-1]),
sync['times'] >= (tmin if tmin is not None else sync['times'][0]),
)
sync = alfio.AlfBunch({k: v[selection] for k, v in sync.items()})
_logger.debug('Protocol period from %.2fs to %.2fs (~%.0f min duration)',
*sync['times'][[0, -1]], np.diff(sync['times'][[0, -1]]) / 60)
# Get the trial events from the DAQ sync TTLs, sync clocks and build final trials datasets
out = self.build_trials(sync=sync, chmap=chmap, **kwargs)
# extract the wheel data
if any(x.startswith('wheel') for x in self.var_names):
wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax)
from ibllib.io.extractors.training_wheel import extract_first_movement_times
if not self.settings:
self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection)
min_qt = self.settings.get('QUIESCENT_PERIOD', None)
first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt)
out.update({'firstMovement_times': first_move_onsets})
out.update({f'wheel_{k}': v for k, v in wheel.items()})
out.update({f'wheelMoves_{k}': v for k, v in moves.items()})
# Re-create trials table
if table_columns:
trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns})
out['table'] = trials_table.to_df()
out = alfio.AlfBunch({k: out[k] for k in self.var_names if k in out}) # Reorder output
assert self.var_names == tuple(out.keys())
return out
def _is_trials_object_attribute(self, var_name, variable_length_vars=None):
"""
Check if variable name is expected to have the same length as trials.intervals.
Parameters
----------
var_name : str
The variable name to check.
variable_length_vars : list
Set of variable names that are not expected to have the same length as trials.intervals.
This list may be passed by superclasses.
Returns
-------
bool
True if variable is a trials dataset.
Examples
--------
>>> assert self._is_trials_object_attribute('stimOnTrigger_times') is True
>>> assert self._is_trials_object_attribute('wheel_position') is False
"""
save_name = self.save_names[self.var_names.index(var_name)] if var_name in self.var_names else None
if save_name:
return filename_parts(save_name)[1] == 'trials'
else:
return var_name not in (variable_length_vars or [])
[docs]
def build_trials(self, sync, chmap, display=False, **kwargs):
"""
Extract task related event times from the sync.
The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The
first trial start TTL of the session is longer and must be handled differently. The trial
start TTL is used to assign the other trial events to each trial.
The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest
of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio
tones. The first of these after each trial start is taken to be the go cue time. Error
tones are longer audio TTLs and assigned as the last of such occurrence after each trial
start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial.
The feedback times are times of either valve open or error tone as there should be only one
such event per trial.
The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs
removed): the first TTL after each trial start is assumed to be the stim onset time; the
second to last and last are taken as the stimulus freeze and offset times, respectively.
Parameters
----------
sync : dict
'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
chmap : dict
Map of channel names and their corresponding index. Default to constant.
display : bool, matplotlib.pyplot.Axes
Show the full session sync pulses display.
Returns
-------
dict
A map of trial event timestamps.
"""
# Get the events from the sync.
# Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC
self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs)
self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs)
if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}:
raise ValueError(
'Expected at least "ready_tone" and "error_tone" audio events.'
'`audio_event_ttls` kwarg may be incorrect.')
self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
if not set(bpod_event_intervals.keys()) >= {'trial_start', 'valve_open', 'trial_end'}:
raise ValueError(
'Expected at least "trial_start", "trial_end", and "valve_open" audio events. '
'`bpod_event_ttls` kwarg may be incorrect.')
t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T
fpga_events = alfio.AlfBunch({
'goCue_times': audio_event_intervals['ready_tone'][:, 0],
'errorCue_times': audio_event_intervals['error_tone'][:, 0],
'valveOpen_times': bpod_event_intervals['valve_open'][:, 0],
'valveClose_times': bpod_event_intervals['valve_open'][:, 1],
'itiIn_times': t_iti_in,
'intervals_0': bpod_event_intervals['trial_start'][:, 0],
'intervals_1': t_trial_end
})
# Sync the Bpod clock to the DAQ.
# NB: The Bpod extractor typically drops the final, incomplete, trial. Hence there is
# usually at least one extra FPGA event. This shouldn't affect the sync. The final trial is
# dropped after assigning the FPGA events, using the `ifpga` index. Doing this after
# assigning the FPGA trial events ensures the last trial has the correct timestamps.
self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field)
if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0':
# One issue is that sometimes pulses may not have been detected, in this case
# add the events that have not been detected and re-extract the behaviour sync.
# This is only really relevant for the Bpod interval events as the other TTLs are
# from devices where a missing TTL likely means the Bpod event was truly absent.
_logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times')
bpod_start = self.bpod_trials['intervals'][:, 0]
missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))])
t_trial_start = np.sort(np.r_[fpga_events['intervals_0'][:, 0], missing_bpod])
else:
t_trial_start = fpga_events['intervals_0']
t_trial_start = t_trial_start[ifpga]
out = alfio.AlfBunch()
# Add the Bpod trial events, converting the timestamp fields to FPGA time.
# NB: The trial intervals are by default a Bpod rsync field.
out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields})
for k in self.bpod_rsync_fields:
# Some personal projects may extract non-trials object datasets that may not have 1 event per trial
idx = ibpod if self._is_trials_object_attribute(k) else np.arange(len(self.bpod_trials[k]), dtype=int)
out[k] = self.bpod2fpga(self.bpod_trials[k][idx])
f2ttl_t = self.frame2ttl['times']
# Assign the FPGA events to individual trials
fpga_trials = {
'goCue_times': _assign_events_to_trial(t_trial_start, fpga_events['goCue_times'], take='first'),
'errorCue_times': _assign_events_to_trial(t_trial_start, fpga_events['errorCue_times']),
'valveOpen_times': _assign_events_to_trial(t_trial_start, fpga_events['valveOpen_times']),
'itiIn_times': _assign_events_to_trial(t_trial_start, fpga_events['itiIn_times']),
'stimOn_times': np.full_like(t_trial_start, np.nan),
'stimOff_times': np.full_like(t_trial_start, np.nan),
'stimFreeze_times': np.full_like(t_trial_start, np.nan)
}
# f2ttl times are unreliable owing to calibration and Bonsai sync square update issues.
# Take the first event after the FPGA aligned stimulus trigger time.
fpga_trials['stimOn_times'] = _assign_events_to_trial(
out['stimOnTrigger_times'], f2ttl_t, take='first', t_trial_end=out['stimOffTrigger_times'])
fpga_trials['stimOff_times'] = _assign_events_to_trial(
out['stimOffTrigger_times'], f2ttl_t, take='first', t_trial_end=out['intervals'][:, 1])
# For stim freeze we take the last event before the stim off trigger time.
# To avoid assigning early events (e.g. for sessions where there are few flips due to
# mis-calibration), we discount events before stim freeze trigger times (or stim on trigger
# times for versions below 6.2.5). We take the last event rather than the first after stim
# freeze trigger because often there are multiple flips after the trigger, presumably
# before the stim actually stops.
stim_freeze = np.copy(out['stimFreezeTrigger_times'])
go_trials = np.where(out['choice'] != 0)[0]
# NB: versions below 6.2.5 have no trigger times so use stim on trigger times
lims = np.copy(out['stimOnTrigger_times'])
if not np.isnan(stim_freeze).all():
# Stim freeze times are NaN for nogo trials, but for all others use stim freeze trigger
# times. _assign_events_to_trial requires ascending timestamps so no NaNs allowed.
lims[go_trials] = stim_freeze[go_trials]
# take last event after freeze/stim on trigger, before stim off trigger
stim_freeze = _assign_events_to_trial(lims, f2ttl_t, take='last', t_trial_end=out['stimOffTrigger_times'])
fpga_trials['stimFreeze_times'][go_trials] = stim_freeze[go_trials]
# Feedback times are valve open on correct trials and error tone in on incorrect trials
fpga_trials['feedback_times'] = np.copy(fpga_trials['valveOpen_times'])
ind_err = np.isnan(fpga_trials['valveOpen_times'])
fpga_trials['feedback_times'][ind_err] = fpga_trials['errorCue_times'][ind_err]
out.update({k: fpga_trials[k] for k in fpga_trials.keys()})
if display: # pragma: no cover
width = 0.5
ymax = 5
if isinstance(display, bool):
plt.figure('Bpod FPGA Sync')
ax = plt.gca()
else:
ax = display
plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k')
plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k')
plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k')
color_map = TABLEAU_COLORS.keys()
for (event_name, event_times), c in zip(fpga_events.items(), cycle(color_map)):
plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width)
# Plot the stimulus events along with the trigger times
stim_events = filter(lambda t: 'stim' in t[0], fpga_trials.items())
for (event_name, event_times), c in zip(stim_events, cycle(color_map)):
plots.vertical_lines(
event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width, linestyle='--')
nm = event_name.replace('_times', 'Trigger_times')
plots.vertical_lines(
out[nm], ymin=0, ymax=ymax, ax=ax, color=c, label=nm, linewidth=width, linestyle=':')
ax.legend()
ax.set_yticks([0, 1, 2, 3])
ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio'])
ax.set_ylim([0, 5])
return out
[docs]
def get_wheel_positions(self, *args, **kwargs):
"""Extract wheel and wheelMoves objects.
This method is called by the main extract method and may be overloaded by subclasses.
"""
return get_wheel_positions(*args, **kwargs)
[docs]
def get_stimulus_update_times(self, sync, chmap, display=False, **_):
"""
Extract stimulus update times from sync.
Gets the stimulus times from the frame2ttl channel and cleans the signal.
Parameters
----------
sync : dict
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
and the corresponding channel numbers.
chmap : dict
A map of channel names and their corresponding indices. Must contain a 'frame2ttl' key.
display : bool
If true, plots the input TTLs and the cleaned output.
Returns
-------
dict
A dictionary with keys {'times', 'polarities'} containing stimulus TTL fronts.
"""
frame2ttl = get_sync_fronts(sync, chmap['frame2ttl'])
frame2ttl = _clean_frame2ttl(frame2ttl, display=display)
return frame2ttl
[docs]
def get_audio_event_times(self, sync, chmap, audio_event_ttls=None, display=False, **_):
"""
Extract audio times from sync.
Gets the TTL times from the 'audio' channel, cleans the signal, and classifies each TTL
event by length.
Parameters
----------
sync : dict
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
and the corresponding channel numbers.
chmap : dict
A map of channel names and their corresponding indices. Must contain an 'audio' key.
audio_event_ttls : dict
A map of event names to (min, max) TTL length.
display : bool
If true, plots the input TTLs and the cleaned output.
Returns
-------
dict
A dictionary with keys {'times', 'polarities'} containing audio TTL fronts.
dict
A dictionary of events (from `audio_event_ttls`) and their intervals as an Nx2 array.
"""
audio = get_sync_fronts(sync, chmap['audio'])
audio = _clean_audio(audio)
if audio['times'].size == 0:
_logger.error('No audio sync fronts found.')
if audio_event_ttls is None:
# For training/biased/ephys protocols, the ready tone should be below 110 ms. The error
# tone should be between 400ms and 1200ms
audio_event_ttls = {'ready_tone': (0, 0.1101), 'error_tone': (0.4, 1.2)}
audio_event_intervals = self._assign_events(audio['times'], audio['polarities'], audio_event_ttls, display=display)
return audio, audio_event_intervals
[docs]
def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
"""
Extract Bpod times from sync.
Gets the Bpod TTL times from the sync 'bpod' channel and classifies each TTL event by
length. NB: The first trial has an abnormal trial_start TTL that is usually mis-assigned.
This method accounts for this.
Parameters
----------
sync : dict
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
and the corresponding channel numbers. Must contain a 'bpod' key.
chmap : dict
A map of channel names and their corresponding indices.
bpod_event_ttls : dict of tuple
A map of event names to (min, max) TTL length.
Returns
-------
dict
A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
dict
A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
"""
bpod = get_sync_fronts(sync, chmap['bpod'])
if bpod.times.size == 0:
raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. '
'Check channel maps.')
# Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
# lengths are defined by the state machine of the task protocol and therefore vary.
if bpod_event_ttls is None:
# For training/biased/ephys protocols, the trial start TTL length is 0.1ms but this has
# proven to drift on some Bpods and this is the highest possible value that
# discriminates trial start from valve. Valve open events are between 50ms to 300 ms.
# ITI events are above 400 ms.
bpod_event_ttls = {
'trial_start': (0, 2.33e-4), 'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
bpod_event_intervals = self._assign_events(
bpod['times'], bpod['polarities'], bpod_event_ttls, display=display)
if 'trial_start' not in bpod_event_intervals or bpod_event_intervals['trial_start'].size == 0:
return bpod, bpod_event_intervals
# The first trial pulse is longer and often assigned to another event.
# Here we move the earliest non-trial_start event to the trial_start array.
t0 = bpod_event_intervals['trial_start'][0, 0] # expect 1st event to be trial_start
pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0]
if pretrial:
(pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event
dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log
_logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt)
bpod_event_intervals['trial_start'] = np.r_[
bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start']
]
bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :]
return bpod, bpod_event_intervals
@staticmethod
def _assign_events(ts, polarities, event_lengths, precedence='shortest', display=False):
"""
Classify TTL events by length.
Outputs the synchronisation events such as trial intervals, valve opening, and audio.
Parameters
----------
ts : numpy.array
Numpy vector containing times of TTL fronts.
polarities : numpy.array
Numpy vector containing polarity of TTL fronts (1 rise, -1 fall).
event_lengths : dict of tuple
A map of TTL events and the range of permissible lengths, where l0 < ttl <= l1.
precedence : str {'shortest', 'longest', 'dict order'}
In the case of overlapping event TTL lengths, assign shortest/longest first or go by
the `event_lengths` dict order.
display : bool
If true, plots the TTLs with coloured lines delineating the assigned events.
Returns
-------
Dict[str, numpy.array]
A dictionary of events and their intervals as an Nx2 array.
See Also
--------
_assign_events_to_trial - classify TTLs by event order within a given trial period.
"""
event_intervals = dict.fromkeys(event_lengths)
assert 'unassigned' not in event_lengths.keys()
if len(ts) == 0:
return {k: np.array([[], []]).T for k in (*event_lengths.keys(), 'unassigned')}
# make sure that there are no 2 consecutive fall or consecutive rise events
assert np.all(np.abs(np.diff(polarities)) == 2)
if polarities[0] == -1:
ts = np.delete(ts, 0)
if polarities[-1] == 1: # if the final TTL is left HIGH, insert a NaN
ts = np.r_[ts, np.nan]
# take only even time differences: i.e. from rising to falling fronts
dt = np.diff(ts)[::2]
# Assign events from shortest TTL to largest
assigned = np.zeros(ts.shape, dtype=bool)
if precedence.lower() == 'shortest':
event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1]))
elif precedence.lower() == 'longest':
event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1]), reverse=True)
elif precedence.lower() == 'dict order':
event_items = event_lengths.items()
else:
raise ValueError(f'Precedence must be one of "shortest", "longest", "dict order", got "{precedence}".')
for event, (min_len, max_len) in event_items:
_logger.debug('%s: %.4G < ttl <= %.4G', event, min_len, max_len)
i_event = np.where(np.logical_and(dt > min_len, dt <= max_len))[0] * 2
i_event = i_event[np.where(~assigned[i_event])[0]] # remove those already assigned
event_intervals[event] = np.c_[ts[i_event], ts[i_event + 1]]
assigned[np.r_[i_event, i_event + 1]] = True
# Include the unassigned events for convenience and debugging
event_intervals['unassigned'] = ts[~assigned].reshape(-1, 2)
# Assert that event TTLs mutually exclusive
all_assigned = np.concatenate(list(event_intervals.values())).flatten()
assert all_assigned.size == np.unique(all_assigned).size, 'TTLs assigned to multiple events'
# some debug plots when needed
if display: # pragma: no cover
plt.figure()
plots.squares(ts, polarities, label='raw fronts')
for event, intervals in event_intervals.items():
plots.vertical_lines(intervals[:, 0], ymin=-0.2, ymax=1.1, linewidth=0.5, label=event)
plt.legend()
# Return map of event intervals in the same order as `event_lengths` dict
return {k: event_intervals[k] for k in (*event_lengths, 'unassigned')}
[docs]
@staticmethod
def sync_bpod_clock(bpod_trials, fpga_trials, sync_field):
"""
Sync the Bpod clock to FPGA one using the provided trial event.
It assumes that `sync_field` is in both `fpga_trials` and `bpod_trials`. Syncing on both
intervals is not supported so to sync on trial start times, `sync_field` should be
'intervals_0'.
Parameters
----------
bpod_trials : dict
A dictionary of extracted Bpod trial events.
fpga_trials : dict
A dictionary of TTL events extracted from FPGA sync (see `extract_behaviour_sync`
method).
sync_field : str
The trials key to use for syncing clocks. For intervals (i.e. Nx2 arrays) append the
column index, e.g. 'intervals_0'.
Returns
-------
function
Interpolation function such that f(timestamps_bpod) = timestamps_fpga.
float
The clock drift in parts per million.
numpy.array of int
The indices of the Bpod trial events in the FPGA trial events array.
numpy.array of int
The indices of the FPGA trial events in the Bpod trial events array.
Raises
------
ValueError
The key `sync_field` was not found in either the `bpod_trials` or `fpga_trials` dicts.
"""
_logger.info(f'Attempting to align Bpod clock to DAQ using trial event "{sync_field}"')
bpod_fpga_timestamps = [None, None]
for i, trials in enumerate((bpod_trials, fpga_trials)):
if sync_field not in trials:
# handle syncing on intervals
if not (m := re.match(r'(.*)_(\d)', sync_field)):
# If missing from bpod trials, either the sync field is incorrect,
# or the Bpod extractor is incorrect. If missing from the fpga events, check
# the sync field and the `extract_behaviour_sync` method.
raise ValueError(
f'Sync field "{sync_field}" not in extracted {"fpga" if i else "bpod"} events')
_sync_field, n = m.groups()
bpod_fpga_timestamps[i] = trials[_sync_field][:, int(n)]
else:
bpod_fpga_timestamps[i] = trials[sync_field]
# Sync the two timestamps
fcn, drift, ibpod, ifpga = ibldsp.utils.sync_timestamps(*bpod_fpga_timestamps, return_indices=True)
# If it's drifting too much throw warning or error
_logger.info('N trials: %i bpod, %i FPGA, %i merged, sync %.5f ppm',
*map(len, bpod_fpga_timestamps), len(ibpod), drift)
if drift > 200 and bpod_fpga_timestamps[0].size != bpod_fpga_timestamps[1].size:
raise err.SyncBpodFpgaException('sync cluster f*ck')
elif drift > BPOD_FPGA_DRIFT_THRESHOLD_PPM:
_logger.warning('BPOD/FPGA synchronization shows values greater than %.2f ppm',
BPOD_FPGA_DRIFT_THRESHOLD_PPM)
return fcn, drift, ibpod, ifpga
[docs]
class FpgaTrialsHabituation(FpgaTrials):
"""Extract habituationChoiceWorld trial events from an NI DAQ."""
save_names = ('_ibl_trials.stimCenter_times.npy', '_ibl_trials.feedbackType.npy', '_ibl_trials.rewardVolume.npy',
'_ibl_trials.stimOff_times.npy', '_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy',
'_ibl_trials.feedback_times.npy', '_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOnTrigger_times.npy',
'_ibl_trials.intervals.npy', '_ibl_trials.goCue_times.npy', '_ibl_trials.goCueTrigger_times.npy',
None, None, None, None, None)
"""tuple of str: The filenames of each extracted dataset, or None if array should not be saved."""
var_names = ('stimCenter_times', 'feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft',
'contrastRight', 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
'stimCenterTrigger_times', 'position', 'phase')
"""tuple of str: A list of names for the extracted variables. These become the returned output keys."""
bpod_rsync_fields = ('intervals', 'stimOn_times', 'feedback_times', 'stimCenterTrigger_times',
'goCue_times', 'itiIn_times', 'stimOffTrigger_times', 'stimOff_times',
'stimCenter_times', 'stimOnTrigger_times', 'goCueTrigger_times')
"""tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA."""
bpod_fields = ('feedbackType', 'rewardVolume', 'contrastLeft', 'contrastRight', 'position', 'phase')
"""tuple of str: Fields from Bpod extractor that we want to save."""
sync_field = 'feedback_times' # valve open events
"""str: The trial event to synchronize (must be present in extracted trials)."""
def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data',
task_collection='raw_behavior_data', **kwargs) -> dict:
"""
Extract habituationChoiceWorld trial events from an NI DAQ.
It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field`
attributes are all correct for the bpod protocol used.
Unlike FpgaTrials, this class assumes different Bpod TTL events and syncs the Bpod clock
using the valve open times, instead of the trial start times.
Parameters
----------
sync : dict
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
and the corresponding channel numbers. If None, the sync is loaded using the
`load_sync` method.
dict
A map of channel names and their corresponding indices. If None, the channel map is
loaded using the `load_sync` method.
sync_collection : str
The session subdirectory where the sync data are located. This is only used if the
sync or channel maps are not provided.
task_collection : str
The session subdirectory where the raw Bpod data are located. This is used for loading
the task settings and extracting the bpod trials, if not already done.
protocol_number : int
The protocol number if multiple protocols were run during the session. If provided, a
spacer signal must be present in order to determine the correct period.
kwargs
Optional arguments for class methods, e.g. 'display', 'bpod_event_ttls'.
Returns
-------
dict
A dictionary of numpy arrays with `FpgaTrialsHabituation.var_names` as keys.
"""
# Version check: the ITI in TTL was added in a later version
if not self.settings:
self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection)
iblrig_version = version.parse(self.settings.get('IBL_VERSION', '0.0.0'))
if version.parse('8.9.3') <= iblrig_version < version.parse('8.12.6'):
"""A second 1s TTL was added in this version during the 'iti' state, however this is
unrelated to the trial ITI and is unfortunately the same length as the trial start TTL."""
raise NotImplementedError('Ambiguous TTLs in 8.9.3 >= version < 8.12.6')
trials = super()._extract(sync=sync, chmap=chmap, sync_collection=sync_collection,
task_collection=task_collection, **kwargs)
return trials
[docs]
def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
"""
Extract Bpod times from sync.
Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse.
Also the first trial pulse is incorrectly assigned due to its abnormal length.
Parameters
----------
sync : dict
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
and the corresponding channel numbers. Must contain a 'bpod' key.
chmap : dict
A map of channel names and their corresponding indices.
bpod_event_ttls : dict of tuple
A map of event names to (min, max) TTL length.
Returns
-------
dict
A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
dict
A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
"""
bpod = get_sync_fronts(sync, chmap['bpod'])
if bpod.times.size == 0:
raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. '
'Check channel maps.')
# Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
# lengths are defined by the state machine of the task protocol and therefore vary.
if bpod_event_ttls is None:
# Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse
bpod_event_ttls = {'trial_iti': (1, 1.1), 'valve_open': (0, 0.4)}
bpod_event_intervals = self._assign_events(
bpod['times'], bpod['polarities'], bpod_event_ttls, display=display)
# The first trial pulse is shorter and assigned to valve_open. Here we remove the first
# valve event, prepend a 0 to the trial_start events, and drop the last trial if it was
# incomplete in Bpod.
bpod_event_intervals['trial_iti'] = np.r_[bpod_event_intervals['valve_open'][0:1, :],
bpod_event_intervals['trial_iti']]
bpod_event_intervals['valve_open'] = bpod_event_intervals['valve_open'][1:, :]
return bpod, bpod_event_intervals
[docs]
def build_trials(self, sync, chmap, display=False, **kwargs):
"""
Extract task related event times from the sync.
This is called by the superclass `_extract` method. The key difference here is that the
`trial_start` LOW->HIGH is the trial end, and HIGH->LOW is trial start.
Parameters
----------
sync : dict
'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
chmap : dict
Map of channel names and their corresponding index. Default to constant.
display : bool, matplotlib.pyplot.Axes
Show the full session sync pulses display.
Returns
-------
dict
A map of trial event timestamps.
"""
# Get the events from the sync.
# Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC
self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs)
self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs)
self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_iti'}:
raise ValueError(
'Expected at least "trial_iti" and "valve_open" Bpod events. `bpod_event_ttls` kwarg may be incorrect.')
fpga_events = alfio.AlfBunch({
'feedback_times': bpod_event_intervals['valve_open'][:, 0],
'valveClose_times': bpod_event_intervals['valve_open'][:, 1],
'intervals_0': bpod_event_intervals['trial_iti'][:, 1],
'intervals_1': bpod_event_intervals['trial_iti'][:, 0],
'goCue_times': audio_event_intervals['ready_tone'][:, 0]
})
# Sync the Bpod clock to the DAQ.
self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field)
out = alfio.AlfBunch()
# Add the Bpod trial events, converting the timestamp fields to FPGA time.
# NB: The trial intervals are by default a Bpod rsync field.
out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields})
out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields})
# Assigning each event to a trial ensures exactly one event per trial (missing events are NaN)
assign_to_trial = partial(_assign_events_to_trial, fpga_events['intervals_0'])
trials = alfio.AlfBunch({
'goCue_times': assign_to_trial(fpga_events['goCue_times'], take='first'),
'feedback_times': assign_to_trial(fpga_events['feedback_times']),
'stimCenter_times': assign_to_trial(self.frame2ttl['times'], take=-2),
'stimOn_times': assign_to_trial(self.frame2ttl['times'], take='first'),
'stimOff_times': assign_to_trial(self.frame2ttl['times']),
})
out.update({k: trials[k][ifpga] for k in trials.keys()})
# If stim on occurs before trial end, use stim on time. Likewise for trial end and stim off
to_correct = ~np.isnan(out['stimOn_times']) & (out['stimOn_times'] < out['intervals'][:, 0])
if np.any(to_correct):
_logger.warning('%i/%i stim on events occurring outside trial intervals', sum(to_correct), len(to_correct))
out['intervals'][to_correct, 0] = out['stimOn_times'][to_correct]
to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1])
if np.any(to_correct):
_logger.debug(
'%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
sum(to_correct), len(to_correct))
out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct]
if display: # pragma: no cover
width = 0.5
ymax = 5
if isinstance(display, bool):
plt.figure('Bpod FPGA Sync')
ax = plt.gca()
else:
ax = display
plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k')
plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k')
plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k')
color_map = TABLEAU_COLORS.keys()
for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)):
plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width)
ax.legend()
ax.set_yticks([0, 1, 2, 3])
ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio'])
ax.set_ylim([0, 4])
return out
[docs]
def get_sync_and_chn_map(session_path, sync_collection):
"""
Return sync and channel map for session based on collection where main sync is stored.
Parameters
----------
session_path : str, pathlib.Path
The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
sync_collection : str
The session subdirectory where the sync data are located.
Returns
-------
one.alf.io.AlfBunch
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
the corresponding channel numbers.
dict
A map of channel names and their corresponding indices.
"""
if sync_collection == 'raw_ephys_data':
# Check to see if we have nidq files, if we do just go with this otherwise go into other function that deals with
# 3A probes
nidq_meta = next(session_path.joinpath(sync_collection).glob('*nidq.meta'), None)
if not nidq_meta:
sync, chmap = get_main_probe_sync(session_path)
else:
sync = load_sync(session_path, sync_collection)
ef = Bunch()
ef['path'] = session_path.joinpath(sync_collection)
ef['nidq'] = nidq_meta
chmap = get_ibl_sync_map(ef, '3B')
else:
sync = load_sync(session_path, sync_collection)
chmap = load_channel_map(session_path, sync_collection)
return sync, chmap
[docs]
def load_channel_map(session_path, sync_collection):
"""
Load syncing channel map for session path and collection
Parameters
----------
session_path : str, pathlib.Path
The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
sync_collection : str
The session subdirectory where the sync data are located.
Returns
-------
dict
A map of channel names and their corresponding indices.
"""
device = sync_collection.split('_')[1]
default_chmap = DEFAULT_MAPS[device]['nidq']
# Try to load channel map from file
chmap = spikeglx.get_sync_map(session_path.joinpath(sync_collection))
# If chmap provided but not with all keys, fill up with default values
if not chmap:
return default_chmap
else:
if data_for_keys(default_chmap.keys(), chmap):
return chmap
else:
_logger.warning('Keys missing from provided channel map, '
'setting missing keys from default channel map')
return {**default_chmap, **chmap}
[docs]
def load_sync(session_path, sync_collection):
"""
Load sync files from session path and collection.
Parameters
----------
session_path : str, pathlib.Path
The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
sync_collection : str
The session subdirectory where the sync data are located.
Returns
-------
one.alf.io.AlfBunch
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
the corresponding channel numbers.
"""
sync = alfio.load_object(session_path.joinpath(sync_collection), 'sync', namespace='spikeglx', short_keys=True)
return sync