"""Mesoscope (timeline) data extraction."""
import logging
import numpy as np
from scipy.signal import find_peaks
import one.alf.io as alfio
from one.alf.path import session_path_parts
from iblutil.util import ensure_list
import matplotlib.pyplot as plt
from packaging import version
from ibllib.plots.misc import squares, vertical_lines
from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel,
correct_counter_discontinuities, load_timeline_sync_and_chmap)
import ibllib.io.extractors.base as extractors_base
from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, _assign_events_to_trial
from ibllib.io.extractors.training_wheel import extract_wheel_moves
from ibllib.io.extractors.camera import attribute_times
from brainbox.behavior.wheel import velocity_filtered
_logger = logging.getLogger(__name__)
def plot_timeline(timeline, channels=None, raw=True):
"""
Plot the timeline data.
Parameters
----------
timeline : one.alf.io.AlfBunch
The timeline data object.
channels : list of str
An iterable of channel names to plot.
raw : bool
If true, plot the raw DAQ samples; if false, apply TTL thresholds and plot changes.
Returns
-------
matplotlib.pyplot.Figure
The figure containing timeline subplots.
list of matplotlib.pyplot.Axes
The axes for each timeline channel plotted.
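
Examples
--------
A minimal sketch, assuming `session_path` contains a raw_sync_data collection
with 'bpod' and 'audio' DAQ channels:

>>> timeline = alfio.load_object(session_path / 'raw_sync_data', 'DAQdata', namespace='timeline')
>>> fig, axes = plot_timeline(timeline, channels=['bpod', 'audio'], raw=False)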
"""
meta = {x.copy().pop('name'): x for x in timeline['meta']['inputs']}
channels = channels or meta.keys()
fig, axes = plt.subplots(len(channels), 1, sharex=True)
axes = ensure_list(axes)
if not raw:
chmap = {ch: meta[ch]['arrayColumn'] for ch in channels}
sync = extract_sync_timeline(timeline, chmap=chmap)
for i, (ax, ch) in enumerate(zip(axes, channels)):
if raw:
# axesScale controls vertical scaling of each trace (multiplicative)
values = timeline['raw'][:, meta[ch]['arrayColumn'] - 1] * meta[ch]['axesScale']
ax.plot(timeline['timestamps'], values)
elif np.any(idx := sync['channels'] == chmap[ch]):
squares(sync['times'][idx], sync['polarities'][idx], ax=ax)
ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
ax.spines['bottom'].set_visible(False), ax.spines['left'].set_visible(True)
ax.set_ylabel(ch, rotation=45, fontsize=8)
# Add back x-axis ticks to the last plot
axes[-1].tick_params(axis='x', which='both', bottom=True, labelbottom=True)
axes[-1].spines['bottom'].set_visible(True)
plt.get_current_fig_manager().window.showMaximized() # full screen
fig.tight_layout(h_pad=0)
return fig, axes
class TimelineTrials(FpgaTrials):
"""Similar extraction to the FPGA, however counter and position channels are treated differently."""
timeline = None
"""one.alf.io.AlfBunch: The timeline data object."""
sync_field = 'itiIn_times'
"""str: The trial event to synchronize (must be present in extracted trials)."""
def __init__(self, *args, sync_collection='raw_sync_data', **kwargs):
"""An extractor for all ephys trial data, in Timeline time"""
super().__init__(*args, **kwargs)
self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')
def load_sync(self, sync_collection='raw_sync_data', chmap=None, **_):
"""Load the DAQ sync and channel map data.
Parameters
----------
sync_collection : str
The session subdirectory where the sync data are located.
chmap : dict
A map of channel names and their corresponding indices. If None, the channel map is
loaded using the :func:`ibllib.io.raw_daq_loaders.timeline_meta2chmap` method.
Returns
-------
one.alf.io.AlfBunch
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
and the corresponding channel numbers.
dict
A map of channel names and their corresponding indices.
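
Examples
--------
A minimal sketch, assuming a session with a raw_sync_data collection:

>>> extractor = TimelineTrials(session_path, sync_collection='raw_sync_data')
>>> sync, chmap = extractor.load_sync('raw_sync_data')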
"""
if not self.timeline:
self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')
sync, chmap = load_timeline_sync_and_chmap(
self.session_path / sync_collection, timeline=self.timeline, chmap=chmap)
return sync, chmap
def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict:
trials = super()._extract(sync, chmap, sync_collection=sync_collection, **kwargs)
if kwargs.get('display', False):
plot_timeline(self.timeline, channels=chmap.keys(), raw=True)
return trials
def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
"""
Extract Bpod times from sync.
Unlike the superclass method, this one doesn't reassign the first trial pulse.
Parameters
----------
sync : dict
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
and the corresponding channel numbers. Must contain a 'bpod' key.
chmap : dict
A map of channel names and their corresponding indices.
bpod_event_ttls : dict of tuple
A map of event names to (min, max) TTL length.
Returns
-------
dict
A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
dict
A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
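
Examples
--------
Overriding the TTL length map (these values are the defaults; the true
lengths depend on the task protocol's state machine):

>>> ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
>>> bpod, bpod_event_intervals = extractor.get_bpod_event_times(sync, chmap, bpod_event_ttls=ttls)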
"""
# Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
# lengths are defined by the state machine of the task protocol and therefore vary.
if bpod_event_ttls is None:
# The trial start TTLs are often too short for the low sampling rate of the DAQ and are
# therefore not used in extraction
bpod_event_ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
bpod, bpod_event_intervals = super().get_bpod_event_times(
sync=sync, chmap=chmap, bpod_event_ttls=bpod_event_ttls, display=display, **kwargs)
# TODO Here we can make use of the 'bpod_rising_edge' channel, if available
return bpod, bpod_event_intervals
def build_trials(self, sync=None, chmap=None, **kwargs):
"""
Extract task related event times from the sync.
The two major differences are that the sampling rate is lower for imaging so the short Bpod
trial start TTLs are often absent. For this reason, the sync happens using the ITI_in TTL.
Second, the valve used at the mesoscope has a way to record the raw voltage across the
solenoid, giving a more accurate readout of the valve's activity. If the reward_valve
channel is present on the DAQ, this is used to extract the valve open times.
Parameters
----------
sync : dict
'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
chmap : dict
Map of channel names and their corresponding index. Default to constant.
Returns
-------
dict
A map of trial event timestamps.
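
Examples
--------
A minimal sketch of the extraction flow (`extractor` is an instantiated
TimelineTrials object with the Bpod trials already extracted):

>>> sync, chmap = extractor.load_sync('raw_sync_data')
>>> trials = extractor.build_trials(sync, chmap)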
"""
# Get the events from the sync.
# Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC
self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs)
self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs)
if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}:
raise ValueError(
'Expected at least "ready_tone" and "error_tone" audio events. '
'`audio_event_ttls` kwarg may be incorrect.')
self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_end'}:
raise ValueError(
'Expected at least "trial_end" and "valve_open" Bpod events. '
'`bpod_event_ttls` kwarg may be incorrect.')
t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T
fpga_events = alfio.AlfBunch({
'itiIn_times': t_iti_in,
'intervals_1': t_trial_end,
'goCue_times': audio_event_intervals['ready_tone'][:, 0],
'errorTone_times': audio_event_intervals['error_tone'][:, 0]
})
# Sync the Bpod clock to the DAQ
self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field)
out = dict()
out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields})
out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields})
start_times = out['intervals'][:, 0]
last_trial_end = out['intervals'][-1, 1]
def assign_to_trial(events, take='last', starts=start_times, **kwargs):
"""Assign DAQ events to trials.
Because we may not have trial start TTLs on the DAQ (because of the low sampling rate),
there may be an extra last trial that's not in the Bpod intervals as the extractor
ignores the last trial. This function trims the input array before assigning so that
the last trial's events are correctly assigned.
"""
return _assign_events_to_trial(starts, events[events <= last_trial_end], take, **kwargs)
out['itiIn_times'] = assign_to_trial(fpga_events['itiIn_times'][ifpga])
# Extract valve open times from the DAQ
valve_driver_ttls = bpod_event_intervals['valve_open']
correct = self.bpod_trials['feedbackType'] == 1
# If there is a reward_valve channel, the raw solenoid voltage was recorded and gives a
# more accurate readout of the valve open times
if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']):
# TODO Let's look at the expected open length based on calibration and reward volume
# import scipy.interpolate
# # FIXME support v7 settings?
# fcn_vol2time = scipy.interpolate.pchip(
# self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_WEIGHT_PERDROP'],
# self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_OPEN_TIMES']
# )
# reward_time = fcn_vol2time(self.bpod_extractor.settings.get('REWARD_AMOUNT_UL')) / 1e3
# Use the driver TTLs to find the valve open times that correspond to the valve opening
valve_intervals, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls)
if valve_open_times.size != np.sum(correct):
_logger.warning(
'Number of valve open times does not equal number of correct trials (%i != %i)',
valve_open_times.size, np.sum(correct))
out['valveOpen_times'] = assign_to_trial(valve_open_times)
else:
# Use the valve controller TTLs recorded on the Bpod channel as the reward time
out['valveOpen_times'] = assign_to_trial(valve_driver_ttls[:, 0])
# Stimulus times extracted based on trigger times
# When assigning events, all start times must be non-NaN. The freeze trigger is NaN on
# nogo trials, so there we substitute the stim on trigger times, then set the resulting
# stimFreeze times back to NaN afterwards
go_trials = np.where(out['choice'] != 0)[0]
lims = np.copy(out['stimOnTrigger_times'])
lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
out['stimFreeze_times'] = assign_to_trial(
self.frame2ttl['times'], 'last',
starts=lims, t_trial_end=out['stimOffTrigger_times'])
out['stimFreeze_times'][out['choice'] == 0] = np.nan
# Here we do the same but use stim off trigger times
lims = np.copy(out['stimOffTrigger_times'])
lims[go_trials] = out['stimFreezeTrigger_times'][go_trials]
out['stimOn_times'] = assign_to_trial(
self.frame2ttl['times'], 'first',
starts=out['stimOnTrigger_times'], t_trial_end=lims)
out['stimOff_times'] = assign_to_trial(
self.frame2ttl['times'], 'first',
starts=out['stimOffTrigger_times'], t_trial_end=out['intervals'][:, 1]
)
# Audio times
error_cue = fpga_events['errorTone_times']
if error_cue.size != np.sum(~correct):
_logger.warning(
'N detected error tones does not match number of incorrect trials (%i != %i)',
error_cue.size, np.sum(~correct))
go_cue = fpga_events['goCue_times']
out['goCue_times'] = assign_to_trial(go_cue, take='first')
out['errorCue_times'] = assign_to_trial(error_cue)
if go_cue.size > start_times.size:
_logger.warning(
'More go cue tones detected than trials! (%i vs %i)', go_cue.size, start_times.size)
elif go_cue.size < start_times.size:
"""
If the error cues are all assigned and some go cues are missed it may be that some
responses were so fast that the go cue and error tone merged, or the go cue TTL was too
long.
"""
_logger.warning('%i go cue tones missed', start_times.size - go_cue.size)
err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times'])
go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times'])
assert not np.any(np.isnan(go_trig))
assert err_trig.size == go_trig.size # should be length of n trials with NaNs
# Find which trials are missing a go cue
_go_cue = assign_to_trial(go_cue, take='first')
error_cue = assign_to_trial(error_cue)
missing = np.isnan(_go_cue)
# Get all the DAQ timestamps where audio channel was HIGH
raw = timeline_get_channel(self.timeline, 'audio')
raw = (raw - raw.min()) / (raw.max() - raw.min()) # min-max normalize
ups = self.timeline.timestamps[raw > .5] # timestamps where input HIGH
# Get the timestamps of the first HIGH after the trigger times (allow up to 200ms after).
# Indices of ups directly following a go trigger, or -1 if none found (or trigger NaN)
idx = attribute_times(ups, go_trig, tol=0.2, take='after')
# Trial indices that didn't have detected goCue and now has been assigned an `ups` index
assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
_go_cue[assigned] = ups[idx[assigned]]
# Remove mis-assigned error tone times (i.e. those that have now been assigned to goCue)
error_cue_without_trig, = np.where(~np.isnan(error_cue) & np.isnan(err_trig))
i_to_remove = np.intersect1d(assigned, error_cue_without_trig, assume_unique=True)
error_cue[i_to_remove] = np.nan
# For those trials where go cue was merged with the error cue and therefore mis-assigned,
# we must re-assign the error cue times as the first HIGH after the error trigger.
idx = attribute_times(ups, err_trig, tol=0.2, take='after')
assigned = np.where((idx != -1) & missing)[0]  # ignore unassigned
error_cue[assigned] = ups[idx[assigned]]
out['goCue_times'] = _go_cue
out['errorCue_times'] = error_cue
# Sanity check: no timestamp should be assigned as both a go cue and an error cue
assert np.intersect1d(out['goCue_times'], out['errorCue_times']).size == 0, \
'audio tones not assigned correctly; tones likely missed'
# Feedback times
out['feedback_times'] = np.copy(out['valveOpen_times'])
ind_err = np.isnan(out['valveOpen_times'])
out['feedback_times'][ind_err] = out['errorCue_times'][ind_err]
return out
def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4',
tmin=None, tmax=None, display=False, **kwargs):
"""
Gets the wheel position and detected movements from Timeline counter channel.
Called by the super class extractor (FPGATrials._extract).
Parameters
----------
ticks : int
Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder).
radius : float
Radius of the wheel in cm. Pass 1 for an output in radians; defaults to the IBL wheel radius.
coding : str {'x1', 'x2', 'x4'}
Rotary encoder encoding (IBL default is x4).
tmin : float
The minimum time from which to extract the sync pulses.
tmax : float
The maximum time up to which we extract the sync pulses.
display : bool
If true, plot the wheel positions from bpod and the DAQ.
Returns
-------
dict
wheel object with keys ('timestamps', 'position').
dict
wheelMoves object with keys ('intervals', 'peakAmplitude').
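
Examples
--------
A minimal sketch (`extractor` is an instantiated TimelineTrials object):

>>> wheel, moves = extractor.get_wheel_positions(ticks=1024, coding='x4')
>>> timestamps, position = wheel['timestamps'], wheel['position']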
"""
wheel = self.extract_wheel_sync(ticks=ticks, radius=radius, coding=coding, tmin=tmin, tmax=tmax)
wheel = dict(zip(('timestamps', 'position'), wheel))
moves = extract_wheel_moves(wheel['timestamps'], wheel['position'])
if display:
assert self.bpod_trials, 'no bpod trials to compare'
fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
bpod_ts = self.bpod_trials['wheel_timestamps']
bpod_pos = self.bpod_trials['wheel_position']
ax0.plot(self.bpod2fpga(bpod_ts), bpod_pos)
ax0.set_ylabel('Bpod wheel position / rad')
ax1.plot(wheel['timestamps'], wheel['position'])
ax1.set_ylabel('DAQ wheel position / rad'), ax1.set_xlabel('Time / s')
return wheel, moves
def get_valve_open_times(self, display=False, threshold=100, driver_ttls=None):
"""
Get the valve open times from the raw timeline voltage trace.
Parameters
----------
display : bool
Plot detected times on the raw voltage trace.
threshold : float
The voltage-change threshold to apply. The default was set by eye; units are nominally
volts per sample, though in practice this does not appear to hold.
driver_ttls : numpy.array
An optional array of driver TTLs to use for assigning with the valve times.
Returns
-------
numpy.array
The detected valve open intervals.
numpy.array
If driver_ttls is not None, returns an array of open times that occurred directly after
the driver TTLs.
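
Examples
--------
A minimal sketch; `valve_driver_ttls` is assumed to be the Nx2 'valve_open'
intervals returned by :meth:`get_bpod_event_times`:

>>> intervals = extractor.get_valve_open_times()
>>> intervals, open_times = extractor.get_valve_open_times(driver_ttls=valve_driver_ttls)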
"""
WARN_THRESH = 10e-3 # open time threshold below which to log warning
tl = self.timeline
info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve')
values = tl['raw'][:, info['arrayColumn'] - 1] # Timeline indices start from 1
# The voltage changes over ~1ms and can therefore occur over two DAQ samples at 2kHz
# making simple thresholding an issue. For this reason we convolve the signal with a
# window and detect the peaks and troughs.
if (Fs := tl['meta']['daqSampleRate']) != 2000: # e.g. 2kHz
_logger.warning('Reward valve detection not tested with a DAQ sample rate of %i', Fs)
dt = 1e-3 # change in voltage takes ~1ms when changing valve open state
N = dt / (1 / Fs) # this means voltage change occurs over N samples
vel, _ = velocity_filtered(values, int(Fs / N)) # filtered voltage change over time
ups, _ = find_peaks(vel, height=threshold) # valve closes (-5V -> 0V)
downs, _ = find_peaks(-1 * vel, height=threshold) # valve opens (0V -> -5V)
# Convert these times into intervals
ixs = np.argsort(np.r_[downs, ups]) # sort indices
times = tl['timestamps'][np.r_[downs, ups]][ixs] # ordered valve event times
polarities = np.r_[np.zeros_like(downs) - 1, np.ones_like(ups)][ixs] # polarity sorted
missing = np.where(np.diff(polarities) == 0)[0] # if some changes were missed insert NaN
times = np.insert(times, missing + int(polarities[0] == -1), np.nan)
if polarities[-1] == -1: # ensure ends with a valve close
times = np.r_[times, np.nan]
if polarities[0] == 1: # ensure starts with a valve open
# It seems it can start out at -5V (open), then when the reward happens it closes and
# immediately opens. In this case we insert NaN as the unknown first open time.
times = np.r_[np.nan, times]
intervals = times.reshape(-1, 2)
# Log warning of improbably short intervals
short = np.sum(np.diff(intervals) < WARN_THRESH)
if short > 0:
_logger.warning('%i valve open intervals shorter than %.0f ms', short, WARN_THRESH * 1e3)
# The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL
if driver_ttls is not None:
# Returns an array of open_times indices, one for each driver TTL
ind = attribute_times(intervals[:, 0], driver_ttls[:, 0], tol=.1, take='after')
open_times = intervals[ind[ind >= 0], 0]
# TODO Log any > 40ms? Difficult to report missing valve times because of calibration
if display:
fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), color='grey', linestyle='-')
if driver_ttls is not None:
x = np.empty_like(driver_ttls.flatten())
x[0::2] = driver_ttls[:, 0]
x[1::2] = driver_ttls[:, 1]
y = np.ones_like(x)
y[1::2] -= 2
squares(x, y, ax=ax0, yrange=[0, 5])
# vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b')
ax0.plot(open_times, np.ones_like(open_times) * 4.5, 'g*')
ax1.plot(tl['timestamps'], values, 'k-o')
ax1.set_ylabel('Voltage / V'), ax1.set_xlabel('Time / s')
ax2 = ax1.twinx()
ax2.set_ylabel('dV', color='grey')
ax2.plot(tl['timestamps'], vel, linestyle='-', color='grey')
ax2.plot(intervals[:, 1], np.ones(len(intervals)) * threshold, 'r*', label='close')
ax2.plot(intervals[:, 0], np.ones(len(intervals)) * threshold, 'g*', label='open')
return intervals if driver_ttls is None else (intervals, open_times)
def _assign_events_audio(self, audio_times, audio_polarities, display=False):
"""
This is identical to ephys_fpga._assign_events_audio, except for the ready tone threshold.
Parameters
----------
audio_times : numpy.array
An array of audio TTL front times.
audio_polarities : numpy.array
An array of audio TTL front polarities (1 for rises, -1 for falls).
display : bool
If true, display audio pulses and the assigned onsets.
Returns
-------
numpy.array
The times of the go cue onsets.
numpy.array
The times of the error tone onsets.
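
Examples
--------
With synthetic fronts, a 100 ms pulse is assigned to the ready tone and a
600 ms pulse to the error tone (illustrative values only):

>>> times = np.array([10., 10.1, 12., 12.6])
>>> polarities = np.array([1, -1, 1, -1])
>>> extractor._assign_events_audio(times, polarities)
(array([10.]), array([12.]))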
"""
# make sure that there are no 2 consecutive fall or consecutive rise events
assert np.all(np.abs(np.diff(audio_polarities)) == 2)
# keep only durations measured from a rising front to the next falling front
dt = np.diff(audio_times)
onsets = audio_polarities[:-1] == 1
# error tones are events lasting from 400ms to 1200ms
i_error_tone_in = np.where(np.logical_and(0.4 < dt, dt < 1.2) & onsets)[0]
t_error_tone_in = audio_times[i_error_tone_in]
# detect ready tone by a duration of 300 ms or less
i_ready_tone_in = np.where(np.logical_and(dt <= 0.3, onsets))[0]
t_ready_tone_in = audio_times[i_ready_tone_in]
if display: # pragma: no cover
fig, ax = plt.subplots(nrows=2, sharex=True)
ax[0].plot(self.timeline.timestamps, timeline_get_channel(self.timeline, 'audio'), 'k-o')
ax[0].set_ylabel('Voltage / V')
squares(audio_times, audio_polarities, yrange=[-1, 1], ax=ax[1])
vertical_lines(t_ready_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='go cue')
vertical_lines(t_error_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='error tone')
ax[1].set_xlabel('Time / s')
ax[1].legend()
return t_ready_tone_in, t_error_tone_in
class MesoscopeSyncTimeline(extractors_base.BaseExtractor):
"""Extraction of mesoscope imaging times."""
var_names = ('mpci_times', 'mpciStack_timeshift')
save_names = ('mpci.times.npy', 'mpciStack.timeshift.npy')
"""one.alf.io.AlfBunch: The raw imaging meta data and frame times"""
rawImagingData = None
def __init__(self, session_path, n_FOVs):
"""
Extract the mesoscope frame times from DAQ data acquired through Timeline.
Parameters
----------
session_path : str, pathlib.Path
The session path to extract times from.
n_FOVs : int
The number of fields of view acquired.
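
Examples
--------
>>> extractor = MesoscopeSyncTimeline(session_path, n_FOVs=4)
>>> extractor.var_names[:2]
['mpci_times_fov_00', 'mpci_times_fov_01']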
"""
super().__init__(session_path)
self.n_FOVs = n_FOVs
fov = list(map(lambda n: f'FOV_{n:02}', range(self.n_FOVs)))
self.var_names = [f'{x}_{y.lower()}' for x in self.var_names for y in fov]
self.save_names = [f'{y}/{x}' for x in self.save_names for y in fov]
def _extract(self, sync=None, chmap=None, device_collection='raw_imaging_data', events=None):
"""
Extract the frame timestamps for each individual field of view (FOV) and the time offsets
for each line scan.
The detected frame times from the 'neural_frames' channel of the DAQ are split into bouts
corresponding to the number of raw_imaging_data folders. These timestamps should match the
number of frame timestamps extracted from the image file headers (found in the
rawImagingData.times file). The field of view (FOV) shifts are then applied to these
timestamps for each field of view and provided together with the line shifts.
Parameters
----------
sync : one.alf.io.AlfBunch
A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
and the corresponding channel numbers.
chmap : dict
A map of channel names and their corresponding indices. Only the 'neural_frames'
channel is required.
device_collection : str, iterable of str
The location of the raw imaging data.
events : pandas.DataFrame
A table of software events, with columns {'time_timeline' 'name_timeline',
'event_timeline'}.
Returns
-------
list of numpy.array
A list of timestamps for each FOV and the time offsets for each line scan.
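
Examples
--------
A minimal sketch, with sync and chmap as returned by
:func:`ibllib.io.raw_daq_loaders.load_timeline_sync_and_chmap`:

>>> extractor = MesoscopeSyncTimeline(session_path, n_FOVs=4)
>>> outputs = extractor._extract(sync, chmap, device_collection=['raw_imaging_data_00'])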
"""
frame_times = sync['times'][sync['channels'] == chmap['neural_frames']]
# imaging_start_time = datetime.datetime(*map(round, self.rawImagingData.meta['acquisitionStartTime']))
if isinstance(device_collection, str):
device_collection = [device_collection]
if events is not None:
events = events[events.name == 'mpepUDP']
edges = self.get_bout_edges(frame_times, device_collection, events)
fov_times = []
line_shifts = []
for (tmin, tmax), collection in zip(edges, sorted(device_collection)):
imaging_data = alfio.load_object(self.session_path / collection, 'rawImagingData')
imaging_data['meta'] = patch_imaging_meta(imaging_data['meta'])
# Calculate line shifts
_, fov_time_shifts, line_time_shifts = self.get_timeshifts(imaging_data['meta'])
assert len(fov_time_shifts) == self.n_FOVs, f'unexpected number of FOVs for {collection}'
ts = frame_times[np.logical_and(frame_times >= tmin, frame_times <= tmax)]
assert ts.size >= imaging_data['times_scanImage'].size, f'fewer DAQ timestamps for {collection} than expected'
if ts.size > imaging_data['times_scanImage'].size:
_logger.warning(
'More DAQ frame times detected for %s than were found in the raw image data.\n'
'N DAQ frame times:\t%i\nN raw image data times:\t%i.\n'
'This may occur if the bout detection fails (e.g. UDPs recorded late), '
'when image data is corrupt, or when frames are not written to file.',
collection, ts.size, imaging_data['times_scanImage'].size)
_logger.info('Dropping last %i frame times for %s', ts.size - imaging_data['times_scanImage'].size, collection)
ts = ts[:imaging_data['times_scanImage'].size]
fov_times.append([ts + offset for offset in fov_time_shifts])
if not line_shifts:
line_shifts = line_time_shifts
else: # The line shifts should be the same across all imaging bouts
[np.testing.assert_array_equal(x, y) for x, y in zip(line_time_shifts, line_shifts)]
# Concatenate imaging timestamps across all bouts for each field of view
fov_times = list(map(np.concatenate, zip(*fov_times)))
n_fov_times, = set(map(len, fov_times))
if n_fov_times != frame_times.size:
# This may happen if an experimenter deletes a raw_imaging_data folder
_logger.debug('FOV timestamps length does not match neural frame count; imaging bout(s) likely missing')
return fov_times + line_shifts
def get_bout_edges(self, frame_times, collections=None, events=None, min_gap=1., display=False):
"""
Return an array of edge times for each imaging bout corresponding to a raw_imaging_data
collection.
Parameters
----------
frame_times : numpy.array
An array of all neural frame count times.
collections : iterable of str
A set of raw_imaging_data collections, used to extract selected imaging periods.
events : pandas.DataFrame
A table of UDP event times, corresponding to times when recordings start and end.
min_gap : float
If start or end events not present, split bouts by finding gaps larger than this value.
display : bool
If true, plot the detected bout edges and raw frame times.
Returns
-------
numpy.array
An array of imaging bout intervals.
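
Examples
--------
Without UDP events, bouts are split on gaps in the frame times larger than
`min_gap` (synthetic, evenly spaced frame times for illustration):

>>> frame_times = np.r_[np.arange(0., 10.), np.arange(20., 30.)]
>>> extractor.get_bout_edges(frame_times)
array([[ 0.,  9.],
       [20., 29.]])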
"""
if events is None or events.empty:
# No UDP events to mark blocks so separate based on gaps in frame rate
idx = np.where(np.diff(frame_times) > min_gap)[0]
starts = np.r_[frame_times[0], frame_times[idx + 1]]
ends = np.r_[frame_times[idx], frame_times[-1]]
else:
# Split using Exp/BlockStart and Exp/BlockEnd times
_, subject, date, _ = session_path_parts(self.session_path)
pattern = rf'(Exp|Block)%s\s{subject}\s{date.replace("-", "")}\s\d+'
# Get start times
UDP_start = events[events['info'].str.match(pattern % 'Start')]
if len(UDP_start) > 1 and UDP_start.loc[0, 'info'].startswith('Exp'):
# Use ExpStart instead of the first BlockStart
UDP_start = UDP_start.copy().drop(1)
starts = frame_times[[np.where(frame_times >= t)[0][0] for t in UDP_start.time]]
# Get end times
UDP_end = events[events['info'].str.match(pattern % 'End')]
if len(UDP_end) > 1 and UDP_end['info'].values[-1].startswith('Exp'):
# Use last BlockEnd instead of ExpEnd
UDP_end = UDP_end.copy().drop(UDP_end.index[-1])
if not UDP_end.empty:
ends = frame_times[[np.where(frame_times <= t)[0][-1] for t in UDP_end.time]]
else:
# Get index of last frame to occur within a second of the previous frame
consec = np.r_[np.diff(frame_times) > min_gap, True]
idx = [np.where(np.logical_and(frame_times > t, consec))[0][0] for t in starts]
ends = frame_times[idx]
# Remove any missing imaging bout collections
edges = np.c_[starts, ends]
if collections:
if edges.shape[0] > len(collections):
# Remove any bouts that correspond to a skipped collection
# e.g. if {raw_imaging_data_00, raw_imaging_data_02}, remove middle bout
include = sorted(int(c.rsplit('_', 1)[-1]) for c in collections)
edges = edges[include, :]
elif edges.shape[0] < len(collections):
raise ValueError('More raw imaging folders than detected bouts')
if display:
_, ax = plt.subplots(1)
ax.step(frame_times, np.arange(frame_times.size), label='frame times', color='k')
vertical_lines(edges[:, 0], ax=ax, ymin=0, ymax=frame_times.size, label='bout start', color='b')
vertical_lines(edges[:, 1], ax=ax, ymin=0, ymax=frame_times.size, label='bout end', color='orange')
if edges.shape[0] != len(starts):
vertical_lines(np.setdiff1d(starts, edges[:, 0]), ax=ax, ymin=0, ymax=frame_times.size,
label='missing bout start', linestyle=':', color='b')
vertical_lines(np.setdiff1d(ends, edges[:, 1]), ax=ax, ymin=0, ymax=frame_times.size,
label='missing bout end', linestyle=':', color='orange')
ax.set_xlabel('Time / s'), ax.set_ylabel('Frame #'), ax.legend(loc='lower right')
return edges
@staticmethod
def get_timeshifts(raw_imaging_meta):
"""
Calculate the time shifts for each field of view (FOV) and the relative offsets for each
scan line.
For a 2 scan field, 2 depth recording (so 4 FOVs):
Frame 1, lines 1-512 correspond to FOV_00
Frame 1, lines 551-1062 correspond to FOV_01
Frame 2, lines 1-512 correspond to FOV_02
Frame 2, lines 551-1062 correspond to FOV_03
Frame 3, lines 1-512 correspond to FOV_00
...
Parameters
----------
raw_imaging_meta : dict
Extracted ScanImage meta data (_ibl_rawImagingData.meta.json).
Returns
-------
list of numpy.array
A list of arrays, one per FOV, containing indices of each image scan line.
numpy.array
An array of FOV time offsets (one value per FOV) relative to each frame acquisition
time.
list of numpy.array
A list of arrays, one per FOV, containing the time offsets for each scan line, relative
to each FOV offset.
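
Examples
--------
For the 4-FOV recording described above with a hypothetical 2 Hz scan frame
rate, the slice 0 FOVs have a frame time shift of 0 s and the slice 1 FOVs
of 0.5 s, with line offsets in multiples of the line period:

>>> line_idx, fov_shifts, line_shifts = MesoscopeSyncTimeline.get_timeshifts(raw_imaging_meta)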
"""
FOVs = raw_imaging_meta['FOV']
# Double-check meta extracted properly
# assert meta.FOV.Zs is ascending but use slice_id field. This may not be necessary but is expected.
slice_ids = np.array([fov['slice_id'] for fov in FOVs])
assert np.all(np.diff([x['Zs'] for x in FOVs]) >= 0), 'FOV depths not in ascending order'
assert np.all(np.diff(slice_ids) >= 0), 'slice IDs not ordered'
# Number of scan lines per FOV, i.e. number of Y pixels / image height
n_lines = np.array([x['nXnYnZ'][1] for x in FOVs])
# We get indices from MATLAB extracted metadata so below two lines are no longer needed
# n_valid_lines = np.sum(n_lines) # Number of lines imaged excluding flybacks
# n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1)) # N lines during flyback
line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod']
frame_time_shifts = slice_ids / raw_imaging_meta['scanImageParams']['hRoiManager']['scanFrameRate']
# Line indices are now extracted by the MATLAB function mesoscopeMetadataExtraction.m
# They are indexed from 1 so we subtract 1 to convert to zero-indexed
line_indices = [np.array(fov['lineIdx']) - 1 for fov in FOVs] # Convert to zero-indexed from MATLAB 1-indexed
assert all(lns.size == n for lns, n in zip(line_indices, n_lines)), 'unexpected number of scan lines'
# The start indices of each FOV in the raw images
fov_start_idx = np.array([lns[0] for lns in line_indices])
roi_time_shifts = fov_start_idx * line_period # The time offset for each FOV
fov_time_shifts = roi_time_shifts + frame_time_shifts
line_time_shifts = [(lns - ln0) * line_period for lns, ln0 in zip(line_indices, fov_start_idx)]
return line_indices, fov_time_shifts, line_time_shifts