Source code for ibllib.io.extractors.mesoscope

"""Mesoscope (timeline) data extraction."""
import logging

import numpy as np
from scipy.signal import find_peaks
import one.alf.io as alfio
from one.alf.path import session_path_parts
from iblutil.util import ensure_list
import matplotlib.pyplot as plt
from packaging import version

from ibllib.plots.misc import squares, vertical_lines
from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel,
                                       correct_counter_discontinuities, load_timeline_sync_and_chmap)
import ibllib.io.extractors.base as extractors_base
from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, _assign_events_to_trial
from ibllib.io.extractors.training_wheel import extract_wheel_moves
from ibllib.io.extractors.camera import attribute_times
from brainbox.behavior.wheel import velocity_filtered

_logger = logging.getLogger(__name__)


[docs] def patch_imaging_meta(meta: dict) -> dict: """ Patch imaging metadata for compatibility across versions. A copy of the dict is NOT returned. Parameters ---------- meta : dict A folder path that contains a rawImagingData.meta file. Returns ------- dict The loaded metadata file, updated to the most recent version. """ # 2023-05-17 (unversioned) adds nFrames, channelSaved keys, MM and Deg keys ver = version.parse(meta.get('version') or '0.0.0') if ver <= version.parse('0.0.0'): if 'channelSaved' not in meta: meta['channelSaved'] = next((x['channelIdx'] for x in meta['FOV'] if 'channelIdx' in x), []) fields = ('topLeft', 'topRight', 'bottomLeft', 'bottomRight') for fov in meta.get('FOV', []): for unit in ('Deg', 'MM'): if unit not in fov: # topLeftDeg, etc. -> Deg[topLeft] fov[unit] = {f: fov.pop(f + unit, None) for f in fields} elif ver == version.parse('0.1.0'): for fov in meta.get('FOV', []): if 'roiUuid' in fov: fov['roiUUID'] = fov.pop('roiUuid') # 2024-09-17 Modified the 2 unit vectors for the positive ML axis and the positive AP axis, # which then transform [X,Y] coordinates (in degrees) to [ML,AP] coordinates (in MM). if ver < version.Version('0.1.5') and 'imageOrientation' in meta: pos_ml, pos_ap = meta['imageOrientation']['positiveML'], meta['imageOrientation']['positiveAP'] center_ml, center_ap = meta['centerMM']['ML'], meta['centerMM']['AP'] res = meta['scanImageParams']['objectiveResolution'] # previously [[0, res/1000], [-res/1000, 0], [0, 0]] TF = np.linalg.pinv(np.c_[np.vstack([pos_ml, pos_ap, [0, 0]]), [1, 1, 1]]) @ \ (np.array([[res / 1000, 0], [0, res / 1000], [0, 0]]) + np.array([center_ml, center_ap])) TF = np.round(TF, 3) # handle floating-point error by rounding if not np.allclose(TF, meta['coordsTF']): meta['coordsTF'] = TF.tolist() centerDegXY = np.array([meta['centerDeg']['x'], meta['centerDeg']['y']]) for fov in meta.get('FOV', []): fov['MM'] = {k: (np.r_[np.array(v) - centerDegXY, 1] @ TF).tolist() for k, v in fov['Deg'].items()} assert 'nFrames' in meta, '"nFrames" key missing from meta data; rawImagingData.meta.json likely an old version' return meta
[docs] def plot_timeline(timeline, channels=None, raw=True): """ Plot the timeline data. Parameters ---------- timeline : one.alf.io.AlfBunch The timeline data object. channels : list of str An iterable of channel names to plot. raw : bool If true, plot the raw DAQ samples; if false, apply TTL thresholds and plot changes. Returns ------- matplotlib.pyplot.Figure The figure containing timeline subplots. list of matplotlib.pyplot.Axes The axes for each timeline channel plotted. """ meta = {x.copy().pop('name'): x for x in timeline['meta']['inputs']} channels = channels or meta.keys() fig, axes = plt.subplots(len(channels), 1, sharex=True) axes = ensure_list(axes) if not raw: chmap = {ch: meta[ch]['arrayColumn'] for ch in channels} sync = extract_sync_timeline(timeline, chmap=chmap) for i, (ax, ch) in enumerate(zip(axes, channels)): if raw: # axesScale controls vertical scaling of each trace (multiplicative) values = timeline['raw'][:, meta[ch]['arrayColumn'] - 1] * meta[ch]['axesScale'] ax.plot(timeline['timestamps'], values) elif np.any(idx := sync['channels'] == chmap[ch]): squares(sync['times'][idx], sync['polarities'][idx], ax=ax) ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False) ax.spines['bottom'].set_visible(False), ax.spines['left'].set_visible(True) ax.set_ylabel(ch, rotation=45, fontsize=8) # Add back x-axis ticks to the last plot axes[-1].tick_params(axis='x', which='both', bottom=True, labelbottom=True) axes[-1].spines['bottom'].set_visible(True) plt.get_current_fig_manager().window.showMaximized() # full screen fig.tight_layout(h_pad=0) return fig, axes
[docs] class TimelineTrials(FpgaTrials): """Similar extraction to the FPGA, however counter and position channels are treated differently.""" timeline = None """one.alf.io.AlfBunch: The timeline data object.""" sync_field = 'itiIn_times' """str: The trial event to synchronize (must be present in extracted trials).""" def __init__(self, *args, sync_collection='raw_sync_data', **kwargs): """An extractor for all ephys trial data, in Timeline time""" super().__init__(*args, **kwargs) self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline')
[docs] def load_sync(self, sync_collection='raw_sync_data', chmap=None, **_): """Load the DAQ sync and channel map data. Parameters ---------- sync_collection : str The session subdirectory where the sync data are located. chmap : dict A map of channel names and their corresponding indices. If None, the channel map is loaded using the :func:`ibllib.io.raw_daq_loaders.timeline_meta2chmap` method. Returns ------- one.alf.io.AlfBunch A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. dict A map of channel names and their corresponding indices. """ if not self.timeline: self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline') sync, chmap = load_timeline_sync_and_chmap( self.session_path / sync_collection, timeline=self.timeline, chmap=chmap) return sync, chmap
def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict: trials = super()._extract(sync, chmap, sync_collection='raw_sync_data', **kwargs) if kwargs.get('display', False): plot_timeline(self.timeline, channels=chmap.keys(), raw=True) return trials
[docs] def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): """ Extract Bpod times from sync. Unlike the superclass method. This one doesn't reassign the first trial pulse. Parameters ---------- sync : dict A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. Must contain a 'bpod' key. chmap : dict A map of channel names and their corresponding indices. bpod_event_ttls : dict of tuple A map of event names to (min, max) TTL length. Returns ------- dict A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. dict A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. """ # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these # lengths are defined by the state machine of the task protocol and therefore vary. if bpod_event_ttls is None: # The trial start TTLs are often too short for the low sampling rate of the DAQ and are # therefore not used in extraction bpod_event_ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)} bpod, bpod_event_intervals = super().get_bpod_event_times( sync=sync, chmap=chmap, bpod_event_ttls=bpod_event_ttls, display=display, **kwargs) # TODO Here we can make use of the 'bpod_rising_edge' channel, if available return bpod, bpod_event_intervals
[docs] def build_trials(self, sync=None, chmap=None, **kwargs): """ Extract task related event times from the sync. The two major differences are that the sampling rate is lower for imaging so the short Bpod trial start TTLs are often absent. For this reason, the sync happens using the ITI_in TTL. Second, the valve used at the mesoscope has a way to record the raw voltage across the solenoid, giving a more accurate readout of the valve's activity. If the reward_valve channel is present on the DAQ, this is used to extract the valve open times. Parameters ---------- sync : dict 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' chmap : dict Map of channel names and their corresponding index. Default to constant. Returns ------- dict A map of trial event timestamps. """ # Get the events from the sync. # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}: raise ValueError( 'Expected at least "ready_tone" and "error_tone" audio events.' '`audio_event_ttls` kwarg may be incorrect.') self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_end'}: raise ValueError( 'Expected at least "trial_end" and "valve_open" audio events. ' '`bpod_event_ttls` kwarg may be incorrect.') t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T fpga_events = alfio.AlfBunch({ 'itiIn_times': t_iti_in, 'intervals_1': t_trial_end, 'goCue_times': audio_event_intervals['ready_tone'][:, 0], 'errorTone_times': audio_event_intervals['error_tone'][:, 0] }) # Sync the Bpod clock to the DAQ self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) out = dict() out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) start_times = out['intervals'][:, 0] last_trial_end = out['intervals'][-1, 1] def assign_to_trial(events, take='last', starts=start_times, **kwargs): """Assign DAQ events to trials. Because we may not have trial start TTLs on the DAQ (because of the low sampling rate), there may be an extra last trial that's not in the Bpod intervals as the extractor ignores the last trial. This function trims the input array before assigning so that the last trial's events are correctly assigned. """ return _assign_events_to_trial(starts, events[events <= last_trial_end], take, **kwargs) out['itiIn_times'] = assign_to_trial(fpga_events['itiIn_times'][ifpga]) # Extract valve open times from the DAQ valve_driver_ttls = bpod_event_intervals['valve_open'] correct = self.bpod_trials['feedbackType'] == 1 # If there is a reward_valve channel, the valve has if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']): # TODO Let's look at the expected open length based on calibration and reward volume # import scipy.interpolate # # FIXME support v7 settings? # fcn_vol2time = scipy.interpolate.pchip( # self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_WEIGHT_PERDROP'], # self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_OPEN_TIMES'] # ) # reward_time = fcn_vol2time(self.bpod_extractor.settings.get('REWARD_AMOUNT_UL')) / 1e3 # Use the driver TTLs to find the valve open times that correspond to the valve opening valve_intervals, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls) if valve_open_times.size != np.sum(correct): _logger.warning( 'Number of valve open times does not equal number of correct trials (%i != %i)', valve_open_times.size, np.sum(correct)) out['valveOpen_times'] = assign_to_trial(valve_open_times) else: # Use the valve controller TTLs recorded on the Bpod channel as the reward time out['valveOpen_times'] = assign_to_trial(valve_driver_ttls[:, 0]) # Stimulus times extracted based on trigger times # When assigning events all start times must not be NaN so here we substitute freeze # trigger times on nogo trials for stim on trigger times, then replace with NaN again go_trials = np.where(out['choice'] != 0)[0] lims = np.copy(out['stimOnTrigger_times']) lims[go_trials] = out['stimFreezeTrigger_times'][go_trials] out['stimFreeze_times'] = assign_to_trial( self.frame2ttl['times'], 'last', starts=lims, t_trial_end=out['stimOffTrigger_times']) out['stimFreeze_times'][out['choice'] == 0] = np.nan # Here we do the same but use stim off trigger times lims = np.copy(out['stimOffTrigger_times']) lims[go_trials] = out['stimFreezeTrigger_times'][go_trials] out['stimOn_times'] = assign_to_trial( self.frame2ttl['times'], 'first', starts=out['stimOnTrigger_times'], t_trial_end=lims) out['stimOff_times'] = assign_to_trial( self.frame2ttl['times'], 'first', starts=out['stimOffTrigger_times'], t_trial_end=out['intervals'][:, 1] ) # Audio times error_cue = fpga_events['errorTone_times'] if error_cue.size != np.sum(~correct): _logger.warning( 'N detected error tones does not match number of incorrect trials (%i != %i)', error_cue.size, np.sum(~correct)) go_cue = fpga_events['goCue_times'] out['goCue_times'] = assign_to_trial(go_cue, take='first') out['errorCue_times'] = assign_to_trial(error_cue) if go_cue.size > start_times.size: _logger.warning( 'More go cue tones detected than trials! (%i vs %i)', go_cue.size, start_times.size) elif go_cue.size < start_times.size: """ If the error cues are all assigned and some go cues are missed it may be that some responses were so fast that the go cue and error tone merged, or the go cue TTL was too long. """ _logger.warning('%i go cue tones missed', start_times.size - go_cue.size) err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times']) go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times']) assert not np.any(np.isnan(go_trig)) assert err_trig.size == go_trig.size # should be length of n trials with NaNs # Find which trials are missing a go cue _go_cue = assign_to_trial(go_cue, take='first') error_cue = assign_to_trial(error_cue) missing = np.isnan(_go_cue) # Get all the DAQ timestamps where audio channel was HIGH raw = timeline_get_channel(self.timeline, 'audio') raw = (raw - raw.min()) / (raw.max() - raw.min()) # min-max normalize ups = self.timeline.timestamps[raw > .5] # timestamps where input HIGH # Get the timestamps of the first HIGH after the trigger times (allow up to 200ms after). # Indices of ups directly following a go trigger, or -1 if none found (or trigger NaN) idx = attribute_times(ups, go_trig, tol=0.2, take='after') # Trial indices that didn't have detected goCue and now has been assigned an `ups` index assigned = np.where(idx != -1 & missing)[0] # ignore unassigned _go_cue[assigned] = ups[idx[assigned]] # Remove mis-assigned error tone times (i.e. those that have now been assigned to goCue) error_cue_without_trig, = np.where(~np.isnan(error_cue) & np.isnan(err_trig)) i_to_remove = np.intersect1d(assigned, error_cue_without_trig, assume_unique=True) error_cue[i_to_remove] = np.nan # For those trials where go cue was merged with the error cue and therefore mis-assigned, # we must re-assign the error cue times as the first HIGH after the error trigger. idx = attribute_times(ups, err_trig, tol=0.2, take='after') assigned = np.where(idx != -1 & missing)[0] # ignore unassigned error_cue[assigned] = ups[idx[assigned]] out['goCue_times'] = _go_cue out['errorCue_times'] = error_cue # Because we're not assert np.intersect1d(out['goCue_times'], out['errorCue_times']).size == 0, \ 'audio tones not assigned correctly; tones likely missed' # Feedback times out['feedback_times'] = np.copy(out['valveOpen_times']) ind_err = np.isnan(out['valveOpen_times']) out['feedback_times'][ind_err] = out['errorCue_times'][ind_err] return out
[docs] def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None): """ Gets the wheel position from Timeline counter channel. Parameters ---------- ticks : int Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder). radius : float Radius of the wheel. Defaults to 1 for an output in radians. coding : str {'x1', 'x2', 'x4'} Rotary encoder encoding (IBL default is x4). tmin : float The minimum time from which to extract the sync pulses. tmax : float The maximum time up to which we extract the sync pulses. Returns ------- np.array Wheel timestamps in seconds. np.array Wheel positions in radians. See Also -------- ibllib.io.extractors.ephys_fpga.extract_wheel_sync """ if coding not in ('x1', 'x2', 'x4'): raise ValueError('Unsupported coding; must be one of x1, x2 or x4') raw = correct_counter_discontinuities(timeline_get_channel(self.timeline, 'rotary_encoder')) # Timeline evenly samples counter so we extract only change points d = np.diff(raw) ind, = np.where(~np.isclose(d, 0)) pos = raw[ind + 1] pos -= pos[0] # Start from zero pos = pos / ticks * np.pi * 2 * radius / int(coding[1]) # Convert to radians # Get timestamps of changes and trim based on protocol spacers ts = self.timeline['timestamps'][ind + 1] tmin = ts.min() if tmin is None else tmin tmax = ts.max() if tmax is None else tmax mask = np.logical_and(ts >= tmin, ts <= tmax) return ts[mask], pos[mask]
[docs] def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None, display=False, **kwargs): """ Gets the wheel position and detected movements from Timeline counter channel. Called by the super class extractor (FPGATrials._extract). Parameters ---------- ticks : int Number of ticks corresponding to a full revolution (1024 for IBL rotary encoder). radius : float Radius of the wheel. Defaults to 1 for an output in radians. coding : str {'x1', 'x2', 'x4'} Rotary encoder encoding (IBL default is x4). tmin : float The minimum time from which to extract the sync pulses. tmax : float The maximum time up to which we extract the sync pulses. display : bool If true, plot the wheel positions from bpod and the DAQ. Returns ------- dict wheel object with keys ('timestamps', 'position'). dict wheelMoves object with keys ('intervals' 'peakAmplitude'). """ wheel = self.extract_wheel_sync(ticks=ticks, radius=radius, coding=coding, tmin=tmin, tmax=tmax) wheel = dict(zip(('timestamps', 'position'), wheel)) moves = extract_wheel_moves(wheel['timestamps'], wheel['position']) if display: assert self.bpod_trials, 'no bpod trials to compare' fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True) bpod_ts = self.bpod_trials['wheel_timestamps'] bpod_pos = self.bpod_trials['wheel_position'] ax0.plot(self.bpod2fpga(bpod_ts), bpod_pos) ax0.set_ylabel('Bpod wheel position / rad') ax1.plot(wheel['timestamps'], wheel['position']) ax1.set_ylabel('DAQ wheel position / rad'), ax1.set_xlabel('Time / s') return wheel, moves
[docs] def get_valve_open_times(self, display=False, threshold=100, driver_ttls=None): """ Get the valve open times from the raw timeline voltage trace. Parameters ---------- display : bool Plot detected times on the raw voltage trace. threshold : float The threshold of voltage change to apply. The default was set by eye; units should be Volts per sample but doesn't appear to be. driver_ttls : numpy.array An optional array of driver TTLs to use for assigning with the valve times. Returns ------- numpy.array The detected valve open intervals. numpy.array If driver_ttls is not None, returns an array of open times that occurred directly after the driver TTLs. """ WARN_THRESH = 10e-3 # open time threshold below which to log warning tl = self.timeline info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve') values = tl['raw'][:, info['arrayColumn'] - 1] # Timeline indices start from 1 # The voltage changes over ~1ms and can therefore occur over two DAQ samples at 2kHz # making simple thresholding an issue. For this reason we convolve the signal with a # window and detect the peaks and troughs. if (Fs := tl['meta']['daqSampleRate']) != 2000: # e.g. 2kHz _logger.warning('Reward valve detection not tested with a DAQ sample rate of %i', Fs) dt = 1e-3 # change in voltage takes ~1ms when changing valve open state N = dt / (1 / Fs) # this means voltage change occurs over N samples vel, _ = velocity_filtered(values, int(Fs / N)) # filtered voltage change over time ups, _ = find_peaks(vel, height=threshold) # valve closes (-5V -> 0V) downs, _ = find_peaks(-1 * vel, height=threshold) # valve opens (0V -> -5V) # Convert these times into intervals ixs = np.argsort(np.r_[downs, ups]) # sort indices times = tl['timestamps'][np.r_[downs, ups]][ixs] # ordered valve event times polarities = np.r_[np.zeros_like(downs) - 1, np.ones_like(ups)][ixs] # polarity sorted missing = np.where(np.diff(polarities) == 0)[0] # if some changes were missed insert NaN times = np.insert(times, missing + int(polarities[0] == -1), np.nan) if polarities[-1] == -1: # ensure ends with a valve close times = np.r_[times, np.nan] if polarities[0] == 1: # ensure starts with a valve open # It seems it can start out at -5V (open), then when the reward happens it closes and # immediately opens. In this case we insert discard the first open time. times = np.r_[np.nan, times] intervals = times.reshape(-1, 2) # Log warning of improbably short intervals short = np.sum(np.diff(intervals) < WARN_THRESH) if short > 0: _logger.warning('%i valve open intervals shorter than %i ms', short, WARN_THRESH) # The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL if driver_ttls is not None: # Returns an array of open_times indices, one for each driver TTL ind = attribute_times(intervals[:, 0], driver_ttls[:, 0], tol=.1, take='after') open_times = intervals[ind[ind >= 0], 0] # TODO Log any > 40ms? Difficult to report missing valve times because of calibration if display: fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True) ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), color='grey', linestyle='-') if driver_ttls is not None: x = np.empty_like(driver_ttls.flatten()) x[0::2] = driver_ttls[:, 0] x[1::2] = driver_ttls[:, 1] y = np.ones_like(x) y[1::2] -= 2 squares(x, y, ax=ax0, yrange=[0, 5]) # vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b') ax0.plot(open_times, np.ones_like(open_times) * 4.5, 'g*') ax1.plot(tl['timestamps'], values, 'k-o') ax1.set_ylabel('Voltage / V'), ax1.set_xlabel('Time / s') ax2 = ax1.twinx() ax2.set_ylabel('dV', color='grey') ax2.plot(tl['timestamps'], vel, linestyle='-', color='grey') ax2.plot(intervals[:, 1], np.ones(len(intervals)) * threshold, 'r*', label='close') ax2.plot(intervals[:, 0], np.ones(len(intervals)) * threshold, 'g*', label='open') return intervals if driver_ttls is None else (intervals, open_times)
def _assign_events_audio(self, audio_times, audio_polarities, display=False): """ This is identical to ephys_fpga._assign_events_audio, except for the ready tone threshold. Parameters ---------- audio_times : numpy.array An array of audio TTL front times. audio_polarities : numpy.array An array of audio TTL front polarities (1 for rises, -1 for falls). display : bool If true, display audio pulses and the assigned onsets. Returns ------- numpy.array The times of the go cue onsets. numpy.array The times of the error tone onsets. """ # make sure that there are no 2 consecutive fall or consecutive rise events assert np.all(np.abs(np.diff(audio_polarities)) == 2) # take only even time differences: i.e. from rising to falling fronts dt = np.diff(audio_times) onsets = audio_polarities[:-1] == 1 # error tones are events lasting from 400ms to 1200ms i_error_tone_in = np.where(np.logical_and(0.4 < dt, dt < 1.2) & onsets)[0] t_error_tone_in = audio_times[i_error_tone_in] # detect ready tone by length below 300 ms i_ready_tone_in = np.where(np.logical_and(dt <= 0.3, onsets))[0] t_ready_tone_in = audio_times[i_ready_tone_in] if display: # pragma: no cover fig, ax = plt.subplots(nrows=2, sharex=True) ax[0].plot(self.timeline.timestamps, timeline_get_channel(self.timeline, 'audio'), 'k-o') ax[0].set_ylabel('Voltage / V') squares(audio_times, audio_polarities, yrange=[-1, 1], ax=ax[1]) vertical_lines(t_ready_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='go cue') vertical_lines(t_error_tone_in, ymin=-.8, ymax=.8, ax=ax[1], label='error tone') ax[1].set_xlabel('Time / s') ax[1].legend() return t_ready_tone_in, t_error_tone_in
[docs] class MesoscopeSyncTimeline(extractors_base.BaseExtractor): """Extraction of mesoscope imaging times.""" var_names = ('mpci_times', 'mpciStack_timeshift') save_names = ('mpci.times.npy', 'mpciStack.timeshift.npy') """one.alf.io.AlfBunch: The raw imaging meta data and frame times""" rawImagingData = None def __init__(self, session_path, n_FOVs): """ Extract the mesoscope frame times from DAQ data acquired through Timeline. Parameters ---------- session_path : str, pathlib.Path The session path to extract times from. n_FOVs : int The number of fields of view acquired. """ super().__init__(session_path) self.n_FOVs = n_FOVs fov = list(map(lambda n: f'FOV_{n:02}', range(self.n_FOVs))) self.var_names = [f'{x}_{y.lower()}' for x in self.var_names for y in fov] self.save_names = [f'{y}/{x}' for x in self.save_names for y in fov] def _extract(self, sync=None, chmap=None, device_collection='raw_imaging_data', events=None): """ Extract the frame timestamps for each individual field of view (FOV) and the time offsets for each line scan. The detected frame times from the 'neural_frames' channel of the DAQ are split into bouts corresponding to the number of raw_imaging_data folders. These timestamps should match the number of frame timestamps extracted from the image file headers (found in the rawImagingData.times file). The field of view (FOV) shifts are then applied to these timestamps for each field of view and provided together with the line shifts. Parameters ---------- sync : one.alf.io.AlfBunch A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and the corresponding channel numbers. chmap : dict A map of channel names and their corresponding indices. Only the 'neural_frames' channel is required. device_collection : str, iterable of str The location of the raw imaging data. events : pandas.DataFrame A table of software events, with columns {'time_timeline' 'name_timeline', 'event_timeline'}. Returns ------- list of numpy.array A list of timestamps for each FOV and the time offsets for each line scan. """ frame_times = sync['times'][sync['channels'] == chmap['neural_frames']] # imaging_start_time = datetime.datetime(*map(round, self.rawImagingData.meta['acquisitionStartTime'])) if isinstance(device_collection, str): device_collection = [device_collection] if events is not None: events = events[events.name == 'mpepUDP'] edges = self.get_bout_edges(frame_times, device_collection, events) fov_times = [] line_shifts = [] for (tmin, tmax), collection in zip(edges, sorted(device_collection)): imaging_data = alfio.load_object(self.session_path / collection, 'rawImagingData') imaging_data['meta'] = patch_imaging_meta(imaging_data['meta']) # Calculate line shifts _, fov_time_shifts, line_time_shifts = self.get_timeshifts(imaging_data['meta']) assert len(fov_time_shifts) == self.n_FOVs, f'unexpected number of FOVs for {collection}' ts = frame_times[np.logical_and(frame_times >= tmin, frame_times <= tmax)] assert ts.size >= imaging_data['times_scanImage'].size, f'fewer DAQ timestamps for {collection} than expected' if ts.size > imaging_data['times_scanImage'].size: _logger.warning( 'More DAQ frame times detected for %s than were found in the raw image data.\n' 'N DAQ frame times:\t%i\nN raw image data times:\t%i.\n' 'This may occur if the bout detection fails (e.g. UDPs recorded late), ' 'when image data is corrupt, or when frames are not written to file.', collection, ts.size, imaging_data['times_scanImage'].size) _logger.info('Dropping last %i frame times for %s', ts.size - imaging_data['times_scanImage'].size, collection) ts = ts[:imaging_data['times_scanImage'].size] fov_times.append([ts + offset for offset in fov_time_shifts]) if not line_shifts: line_shifts = line_time_shifts else: # The line shifts should be the same across all imaging bouts [np.testing.assert_array_equal(x, y) for x, y in zip(line_time_shifts, line_shifts)] # Concatenate imaging timestamps across all bouts for each field of view fov_times = list(map(np.concatenate, zip(*fov_times))) n_fov_times, = set(map(len, fov_times)) if n_fov_times != frame_times.size: # This may happen if an experimenter deletes a raw_imaging_data folder _logger.debug('FOV timestamps length does not match neural frame count; imaging bout(s) likely missing') return fov_times + line_shifts
[docs] def get_bout_edges(self, frame_times, collections=None, events=None, min_gap=1., display=False): """ Return an array of edge times for each imaging bout corresponding to a raw_imaging_data collection. Parameters ---------- frame_times : numpy.array An array of all neural frame count times. collections : iterable of str A set of raw_imaging_data collections, used to extract selected imaging periods. events : pandas.DataFrame A table of UDP event times, corresponding to times when recordings start and end. min_gap : float If start or end events not present, split bouts by finding gaps larger than this value. display : bool If true, plot the detected bout edges and raw frame times. Returns ------- numpy.array An array of imaging bout intervals. """ if events is None or events.empty: # No UDP events to mark blocks so separate based on gaps in frame rate idx = np.where(np.diff(frame_times) > min_gap)[0] starts = np.r_[frame_times[0], frame_times[idx + 1]] ends = np.r_[frame_times[idx], frame_times[-1]] else: # Split using Exp/BlockStart and Exp/BlockEnd times _, subject, date, _ = session_path_parts(self.session_path) pattern = rf'(Exp|Block)%s\s{subject}\s{date.replace("-", "")}\s\d+' # Get start times UDP_start = events[events['info'].str.match(pattern % 'Start')] if len(UDP_start) > 1 and UDP_start.loc[0, 'info'].startswith('Exp'): # Use ExpStart instead of first bout start UDP_start = UDP_start.copy().drop(1) # Use ExpStart/End instead of first/last BlockStart/End starts = frame_times[[np.where(frame_times >= t)[0][0] for t in UDP_start.time]] # Get end times UDP_end = events[events['info'].str.match(pattern % 'End')] if len(UDP_end) > 1 and UDP_end['info'].values[-1].startswith('Exp'): # Use last BlockEnd instead of ExpEnd UDP_end = UDP_end.copy().drop(UDP_end.index[-1]) if not UDP_end.empty: ends = frame_times[[np.where(frame_times <= t)[0][-1] for t in UDP_end.time]] else: # Get index of last frame to occur within a second of the previous frame consec = np.r_[np.diff(frame_times) > min_gap, True] idx = [np.where(np.logical_and(frame_times > t, consec))[0][0] for t in starts] ends = frame_times[idx] # Remove any missing imaging bout collections edges = np.c_[starts, ends] if collections: if edges.shape[0] > len(collections): # Remove any bouts that correspond to a skipped collection # e.g. if {raw_imaging_data_00, raw_imaging_data_02}, remove middle bout include = sorted(int(c.rsplit('_', 1)[-1]) for c in collections) edges = edges[include, :] elif edges.shape[0] < len(collections): raise ValueError('More raw imaging folders than detected bouts') if display: _, ax = plt.subplots(1) ax.step(frame_times, np.arange(frame_times.size), label='frame times', color='k', ) vertical_lines(edges[:, 0], ax=ax, ymin=0, ymax=frame_times.size, label='bout start', color='b') vertical_lines(edges[:, 1], ax=ax, ymin=0, ymax=frame_times.size, label='bout end', color='orange') if edges.shape[0] != len(starts): vertical_lines(np.setdiff1d(starts, edges[:, 0]), ax=ax, ymin=0, ymax=frame_times.size, label='missing bout start', linestyle=':', color='b') vertical_lines(np.setdiff1d(ends, edges[:, 1]), ax=ax, ymin=0, ymax=frame_times.size, label='missing bout end', linestyle=':', color='orange') ax.set_xlabel('Time / s'), ax.set_ylabel('Frame #'), ax.legend(loc='lower right') return edges
[docs] @staticmethod def get_timeshifts(raw_imaging_meta): """ Calculate the time shifts for each field of view (FOV) and the relative offsets for each scan line. For a 2 scan field, 2 depth recording (so 4 FOVs): Frame 1, lines 1-512 correspond to FOV_00 Frame 1, lines 551-1062 correspond to FOV_01 Frame 2, lines 1-512 correspond to FOV_02 Frame 2, lines 551-1062 correspond to FOV_03 Frame 3, lines 1-512 correspond to FOV_00 ... Parameters ---------- raw_imaging_meta : dict Extracted ScanImage meta data (_ibl_rawImagingData.meta.json). Returns ------- list of numpy.array A list of arrays, one per FOV, containing indices of each image scan line. numpy.array An array of FOV time offsets (one value per FOV) relative to each frame acquisition time. list of numpy.array A list of arrays, one per FOV, containing the time offsets for each scan line, relative to each FOV offset. """ FOVs = raw_imaging_meta['FOV'] # Double-check meta extracted properly # assert meta.FOV.Zs is ascending but use slice_id field. This may not be necessary but is expected. slice_ids = np.array([fov['slice_id'] for fov in FOVs]) assert np.all(np.diff([x['Zs'] for x in FOVs]) >= 0), 'FOV depths not in ascending order' assert np.all(np.diff(slice_ids) >= 0), 'slice IDs not ordered' # Number of scan lines per FOV, i.e. number of Y pixels / image height n_lines = np.array([x['nXnYnZ'][1] for x in FOVs]) # We get indices from MATLAB extracted metadata so below two lines are no longer needed # n_valid_lines = np.sum(n_lines) # Number of lines imaged excluding flybacks # n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1)) # N lines during flyback line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod'] frame_time_shifts = slice_ids / raw_imaging_meta['scanImageParams']['hRoiManager']['scanFrameRate'] # Line indices are now extracted by the MATLAB function mesoscopeMetadataExtraction.m # They are indexed from 1 so we subtract 1 to convert to zero-indexed line_indices = [np.array(fov['lineIdx']) - 1 for fov in FOVs] # Convert to zero-indexed from MATLAB 1-indexed assert all(lns.size == n for lns, n in zip(line_indices, n_lines)), 'unexpected number of scan lines' # The start indices of each FOV in the raw images fov_start_idx = np.array([lns[0] for lns in line_indices]) roi_time_shifts = fov_start_idx * line_period # The time offset for each FOV fov_time_shifts = roi_time_shifts + frame_time_shifts line_time_shifts = [(lns - ln0) * line_period for lns, ln0 in zip(line_indices, fov_start_idx)] return line_indices, fov_time_shifts, line_time_shifts