Source code for ibllib.io.extractors.ephys_passive

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Author: Niccolò Bonacchi
# @Date: Monday, September 7th 2020, 11:51:17 am
import json
import logging
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import ibllib.io.raw_data_loaders as rawio
from ibllib.io.extractors import ephys_fpga
from ibllib.io.extractors.base import BaseExtractor
from ibllib.io.extractors.passive_plotting import (
    plot_audio_times,
    plot_gabor_times,
    plot_passive_periods,
    plot_rfmapping,
    plot_stims_times,
    plot_sync_channels,
    plot_valve_times,
)

log = logging.getLogger("ibllib")

# hardcoded var
FRAME_FS = 60  # Sampling freq of the ipad screen, in Hertz
FS_FPGA = 30000  # Sampling freq of the neural recording system, in Hertz
NVALVE = 40  # number of expected valve clicks
NGABOR = 20 + 20 * 4 * 2  # number of expected Gabor patches
NTONES = 40
NNOISES = 40
DEBUG_PLOTS = False

dataset_types = [
    "_spikeglx_sync.times",
    "_spikeglx_sync.channels",
    "_spikeglx_sync.polarities",
    "_iblrig_RFMapStim.raw",
    "_iblrig_stimPositionScreen.raw",
    "_iblrig_syncSquareUpdate.raw",
    "ephysData.raw.meta",
    "_iblrig_taskSettings.raw",
    "_iblrig_taskData.raw",
]

min_dataset_types = [
    "_spikeglx_sync.times",
    "_spikeglx_sync.channels",
    "_spikeglx_sync.polarities",
    "_iblrig_RFMapStim.raw",
    "ephysData.raw.meta",
    "_iblrig_taskSettings.raw",
    "_iblrig_taskData.raw",
]


# load session fixtures
def _load_passive_session_fixtures(session_path: str, task_collection: str = 'raw_passive_data') -> dict:
    """load_passive_session_fixtures Loads corresponding ephys session fixtures

    :param session_path: the path to a session
    :type session_path: str
    :return: position, contrast, phase, delays and stimulus IDs
    :rtype: dict
    """

    # The pre-generated session number determines which fixture files to load
    settings = rawio.load_settings(session_path, task_collection=task_collection)
    ses_nb = settings['PREGENERATED_SESSION_NUM']
    session_order = settings.get('SESSION_ORDER', None)
    if session_order:  # TODO test this out and make sure it's okay
        assert settings["SESSION_ORDER"][settings["SESSION_IDX"]] == ses_nb

    path_fixtures = Path(ephys_fpga.__file__).parent.joinpath("ephys_sessions")

    fixture = {
        "pcs": np.load(path_fixtures.joinpath(f"session_{ses_nb}_passive_pcs.npy")),
        "delays": np.load(path_fixtures.joinpath(f"session_{ses_nb}_passive_stimDelays.npy")),
        "ids": np.load(path_fixtures.joinpath(f"session_{ses_nb}_passive_stimIDs.npy")),
    }

    return fixture
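
# Usage sketch (the session path below is hypothetical):
#   fixture = _load_passive_session_fixtures("/data/subject/2020-09-07/001")
#   fixture["pcs"]     # one row per Gabor patch: position, contrast, phase
#   fixture["delays"]  # delays between replayed stimuli
#   fixture["ids"]     # stimulus identity codes for the replay sequence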


def _load_task_protocol(session_path: str, task_collection: str = 'raw_passive_data') -> str:
    """Find the IBL rig version used for the session

    :param session_path: the path to a session
    :type session_path: str
    :return: ibl rig task protocol version
    :rtype: str

    FIXME This function has a misleading name
    """
    settings = rawio.load_settings(session_path, task_collection=task_collection)
    ses_ver = settings["IBLRIG_VERSION"]

    return ses_ver


def _load_passive_stim_meta() -> dict:
    """load_passive_stim_meta Loads the passive protocol metadata

    :return: metadata about passive protocol stimulus presentation
    :rtype: dict
    """
    path_fixtures = Path(ephys_fpga.__file__).parent.joinpath("ephys_sessions")
    with open(path_fixtures.joinpath("passive_stim_meta.json"), "r") as f:
        meta = json.load(f)

    return meta


# 1/3 Define start and end times of the 3 passive periods
def _get_spacer_times(spacer_template, jitter, ttl_signal, t_quiet, thresh=3.0):
    """
    Find timestamps of spacer signal.
    :param spacer_template: list of indices where ttl signal changes
    :type spacer_template: array-like
    :param jitter: jitter (in seconds) for matching ttl_signal with spacer_template
    :type jitter: float
    :param ttl_signal: times of the TTL fronts in which to look for the spacer
    :type ttl_signal: array-like
    :param t_quiet: seconds between the spacer and the next stimulus
    :type t_quiet: float
    :param thresh: value the convolved TTL signal (correlated with the spacer template) must exceed to detect a spacer
    :type thresh: float
    :return: times of spacer onset/offset
    :rtype: n_spacer x 2 np.ndarray; first col onset times, second col offset
    """
    diff_spacer_template = np.diff(spacer_template)
    # add jitter;
    # remove extreme values
    spacer_model = jitter + diff_spacer_template[2:-2]
    # diff ttl signal to compare to spacer_model
    dttl = np.diff(ttl_signal)
    # remove diffs larger than max diff in model to clean up signal
    dttl[dttl > np.max(spacer_model)] = 0
    # convolve cleaned diff ttl signal w/ spacer model
    conv_dttl = np.correlate(dttl, spacer_model, mode="full")
    # find spacer location
    idxs_spacer_middle = np.where(
        (conv_dttl[1:-2] < thresh) & (conv_dttl[2:-1] > thresh) & (conv_dttl[3:] < thresh)
    )[0]
    # adjust indices for
    # - `np.where` call above
    # - length of spacer_model
    spacer_around = int((np.floor(len(spacer_model) / 2)))
    idxs_spacer_middle += 2 - spacer_around

    # for each candidate spacer, check that the inter-flip intervals are monotonically
    # non-decreasing before the middle and monotonically non-increasing after it
    is_valid = np.zeros((idxs_spacer_middle.size), dtype=bool)
    for i, t in enumerate(idxs_spacer_middle):
        before = all(np.diff(dttl[t - spacer_around:t]) >= 0)
        after = all(np.diff(dttl[t + 1:t + 1 + spacer_around]) <= 0)
        is_valid[i] = np.bitwise_and(before, after)

    idxs_spacer_middle = idxs_spacer_middle[is_valid]

    # pull out spacer times (middle)
    ts_spacer_middle = ttl_signal[idxs_spacer_middle]
    # put beginning/end of spacer times into an array
    spacer_length = np.max(spacer_template)
    spacer_times = np.zeros(shape=(ts_spacer_middle.shape[0], 2))
    for i, t in enumerate(ts_spacer_middle):
        spacer_times[i, 0] = t - (spacer_length / 2) - t_quiet
        spacer_times[i, 1] = t + (spacer_length / 2) + t_quiet
    return spacer_times, conv_dttl
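
# Usage sketch (as called from _get_passive_spacers below): the template comes from the
# passive stimulus metadata and the TTL signal from the cleaned frame2ttl fronts.
#   spacer_template = np.array(meta["VISUAL_STIM_0"]["ttl_frame_nums"], dtype=np.float32) / FRAME_FS
#   spacer_times, conv_dttl = _get_spacer_times(
#       spacer_template=spacer_template, jitter=3 / FRAME_FS,
#       ttl_signal=fttl["times"], t_quiet=meta["VISUAL_STIM_0"]["delay_around"])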


def _get_passive_spacers(session_path, sync_collection='raw_ephys_data',
                         sync=None, sync_map=None, tmin=None, tmax=None):
    """
    Load the spacer template and cross-correlate it with the frame2ttl signal to find the spacer timestamps.
    Returns the passive protocol start time, and the start and end times of each passive period.
    """
    if sync is None or sync_map is None:
        sync, sync_map = ephys_fpga.get_sync_and_chn_map(session_path, sync_collection=sync_collection)
    meta = _load_passive_stim_meta()
    # t_end_ephys = passive.ephysCW_end(session_path=session_path)
    fttl = ephys_fpga.get_sync_fronts(sync, sync_map["frame2ttl"], tmin=tmin, tmax=tmax)
    fttl = ephys_fpga._clean_frame2ttl(fttl, display=False)
    spacer_template = (
        np.array(meta["VISUAL_STIM_0"]["ttl_frame_nums"], dtype=np.float32) / FRAME_FS
    )
    jitter = 3 / FRAME_FS  # allow for 3 screen refresh as jitter
    t_quiet = meta["VISUAL_STIM_0"]["delay_around"]
    spacer_times, conv_dttl = _get_spacer_times(
        spacer_template=spacer_template, jitter=jitter, ttl_signal=fttl["times"], t_quiet=t_quiet
    )

    # Check correct number of spacers found
    n_exp_spacer = np.sum(np.array(meta['STIM_ORDER']) == 0)  # Hardcoded 0 for spacer
    if n_exp_spacer != np.size(spacer_times) / 2:
        error_nspacer = True
        # sometimes the first spacer is truncated
        # assess whether the first spacer is undetected, and then launch another spacer detection on truncated fttl
        # with a lower threshold value
        # Note: take *3 for some margin
        if spacer_times[0][0] > (spacer_template[-1] + jitter) * 3 and (np.size(spacer_times) / 2) == n_exp_spacer - 1:
            # Truncate signals
            fttl_t = fttl["times"][np.where(fttl["times"] < spacer_times[0][0])]
            conv_dttl_t = conv_dttl[np.where(fttl["times"] < spacer_times[0][0])]
            ddttl = np.diff(np.diff(fttl_t))
            # Find spacer location
            # NB: cannot re-use the same algo for spacer detection as conv peaks towards spacer end
            # 1. Find time point at which conv raises above a given threshold value
            thresh = 2.0
            idx_nearend_spacer = int(np.where((conv_dttl_t[1:-2] < thresh) & (conv_dttl_t[2:-1] > thresh))[0])
            ddttl = ddttl[0:idx_nearend_spacer]
            # 2. Find the time point before this at which the fttl diff switches from increasing to decreasing
            #    (this is the middle of the spacer)
            indx_middle = np.where((ddttl[0:-1] > 0) & (ddttl[1:] < 0))[0]
            if len(indx_middle) == 1:
                # 3. Add 1/2 spacer to middle idx to get the spacer end indx
                spacer_around = int((np.floor(len(spacer_template) / 2)))
                idx_end = int(indx_middle + spacer_around) + 1
                spacer_times = np.insert(spacer_times, 0, np.array([fttl["times"][0], fttl["times"][idx_end]]), axis=0)
                error_nspacer = False

        if error_nspacer:
            raise ValueError(
                f'The expected number of spacers ({n_exp_spacer}) '
                f'differs from the number found in the raw '
                f'trace ({int(np.size(spacer_times) / 2)})'
            )

    if tmax is None:
        tmax = sync['times'][-1]

    spacer_times = np.r_[spacer_times.flatten(), tmax]
    return spacer_times[0], spacer_times[1::2], spacer_times[2::2]
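
# Illustrative note: the three returned values map onto the passive periods as used in
# extract_passive_periods below, i.e.
#   t_start_passive, t_starts, t_ends = _get_passive_spacers(session_path, sync_collection)
#   [t_starts[0], t_ends[0]]  # spontaneous activity
#   [t_starts[1], t_ends[1]]  # receptive field mapping (RFM)
#   [t_starts[2], t_ends[2]]  # task stimuli replay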


# 2/3 RFMapping stimuli
def _interpolate_rf_mapping_stimulus(idxs_up, idxs_dn, times, Xq, t_bin):
    """
    Interpolate stimulus presentation times to the screen refresh rate to match `frames`
    :param idxs_up: indices (in frames) of the TTL rises
    :type idxs_up: array-like
    :param idxs_dn: indices (in frames) of the TTL falls
    :type idxs_dn: array-like
    :param times: array of stimulus switch times
    :type times: array-like
    :param Xq: indices of the frames for which to interpolate times
    :type Xq: array-like
    :param t_bin: period of the screen refresh rate
    :type t_bin: float
    :return: interpolated stimulus presentation times, one per frame
    """

    beg_extrap_val = -10001
    end_extrap_val = -10000

    X = np.sort(np.concatenate([idxs_up, idxs_dn]))
    # make left and right extrapolations distinctive to easily find later
    Tq = np.interp(Xq, X, times, left=beg_extrap_val, right=end_extrap_val)
    # uniform spacing outside boundaries of ttl signal
    # first values
    n_beg = len(np.where(Tq == beg_extrap_val)[0])
    if 0 < n_beg < Tq.shape[0]:
        Tq[:n_beg] = times[0] - np.arange(n_beg, 0, -1) * t_bin
    # end values
    n_end = len(np.where(Tq == end_extrap_val)[0])
    if 0 < n_end < Tq.shape[0]:
        Tq[-n_end:] = times[-1] + np.arange(1, n_end + 1) * t_bin
    return Tq
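
# Worked example (made-up numbers): with TTL changes at frame indices 2 (up) and 5 (down)
# measured at 1.0 s and 1.3 s, frames inside the TTL span are filled by linear
# interpolation and frames outside it by uniform t_bin spacing:
#   _interpolate_rf_mapping_stimulus(idxs_up=[2], idxs_dn=[5], times=np.array([1.0, 1.3]),
#                                    Xq=np.arange(8), t_bin=0.1)
#   # -> approximately [0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5]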


def _get_id_raisefall_from_analogttl(ttl_01):
    """
    Get indices of rises/falls from the analog continuous TTL signal
    :param ttl_01: analog continuous TTL signal, encoded as -1, 0, 1 values
    :return: indices of rises (0 to 1), indices of falls (0 to -1), number of expected TTL fronts
    """
    # Check values are 0, 1, -1
    if not np.all(np.isin(np.unique(ttl_01), [-1, 0, 1])):
        raise ValueError("Values in input must be 0, 1, -1")
    else:
        # Find the transitions from 0 to 1 and from 0 to -1
        d_ttl_01 = np.diff(ttl_01)
        id_up = np.where(np.logical_and(ttl_01 == 0, np.append(d_ttl_01, 0) == 1))[0]
        id_dw = np.where(np.logical_and(ttl_01 == 0, np.append(d_ttl_01, 0) == -1))[0]
        n_ttl_expected = 2 * (len(id_up) + len(id_dw))  # *2 for rise/fall of ttl pulse
        return id_up, id_dw, n_ttl_expected
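
# Worked example (made-up trace): on the -1/0/1 trace [0, 1, 0, -1, 0, 1] the rises are
# at indices [0, 4], the fall at index [2], and 6 TTL fronts are expected (2 per
# detected transition):
#   _get_id_raisefall_from_analogttl(np.array([0, 1, 0, -1, 0, 1]))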


def _reshape_RF(RF_file, meta_stim):
    """
    Reshape the Receptive Field (RF) matrix. The data from the corner where the frame2ttl
    sensor is placed is used to create the TTL trace.
    :param RF_file: vector to be reshaped, containing RF info
    :param meta_stim: variable containing metadata information on RF
    :return: frames (reshaped RF), analog trace (0-1 values)
    """
    frame_array = np.fromfile(RF_file, dtype="uint8")
    y_pix, x_pix, _ = meta_stim["stim_file_shape"]
    frames = np.transpose(np.reshape(frame_array, [y_pix, x_pix, -1], order="F"), [2, 1, 0])
    ttl_trace = frames[:, 0, 0]
    # Convert values to 0,1,-1 for simplicity
    ttl_analogtrace_01 = np.zeros(np.size(ttl_trace))
    ttl_analogtrace_01[np.where(ttl_trace == 0)] = -1
    ttl_analogtrace_01[np.where(ttl_trace == 255)] = 1
    return frames, ttl_analogtrace_01
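
# Usage sketch (the stimulus shape below is hypothetical): with
# meta_stim = {"stim_file_shape": [15, 15, -1]}, the raw uint8 file is reshaped to
# (n_frames, x_pix, y_pix) and the pixel trace at [:, 0, 0] (the frame2ttl corner)
# becomes the -1/0/1 analog TTL trace:
#   frames, ttl01 = _reshape_RF(RF_file, meta_stim)
#   frames.shape  # (n_frames, 15, 15)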


# 3/3 Replay of task stimuli
def _extract_passiveGabor_df(fttl: dict, session_path: str, task_collection: str = 'raw_passive_data') -> pd.DataFrame:
    # At this stage we want to define what pulses are and not quality control them.
    # Pulses are strictly alternating with intervals
    # find min max lengths for both (we don't know which are pulses and which are intervals yet)
    # trim edges of pulses
    diff0 = (np.min(np.diff(fttl["times"])[2:-2:2]), np.max(np.diff(fttl["times"])[2:-1:2]))
    diff1 = (np.min(np.diff(fttl["times"])[3:-2:2]), np.max(np.diff(fttl["times"])[3:-1:2]))
    # Highest max is of the intervals
    if max(diff0 + diff1) in diff0:
        thresh = diff0[0]
    elif max(diff0 + diff1) in diff1:
        thresh = diff1[0]
    # Anything lower than the min length of intervals is a pulse
    idx_start_stims = np.where((np.diff(fttl["times"]) < thresh) & (np.diff(fttl["times"]) > 0.1))[0]
    # Check if any pulse has been missed
    # i.e. fewer pulses than expected (not counting the first one) and the alternation is broken
    if len(idx_start_stims) < NGABOR - 1 and np.any(np.diff(idx_start_stims) > 2):
        log.warning("Looks like one or more pulses were not detected, trying to extrapolate...")
        missing_where = np.where(np.diff(idx_start_stims) > 2)[0]
        insert_where = missing_where + 1
        missing_value = idx_start_stims[missing_where] + 2
        idx_start_stims = np.insert(idx_start_stims, insert_where, missing_value)

    idx_end_stims = idx_start_stims + 1

    start_times = fttl["times"][idx_start_stims]
    end_times = fttl["times"][idx_end_stims]
    # Check if we missed the first stim
    if len(start_times) < NGABOR:
        first_stim_off_idx = idx_start_stims[0] - 1
        # first_stim_on_idx = first_stim_off_idx - 1
        end_times = np.insert(end_times, 0, fttl["times"][first_stim_off_idx])
        start_times = np.insert(start_times, 0, end_times[0] - 0.3)

    # the intervals dataset type requires reshaping of start and end times
    passiveGabor_intervals = np.array([(x, y) for x, y in zip(start_times, end_times)])

    # Check length of presentation of stim is within 150ms of expected
    if not np.allclose([y - x for x, y in passiveGabor_intervals], 0.3, atol=0.15):
        log.warning("Some Gabor presentation lengths seem wrong.")

    assert (
        len(passiveGabor_intervals) == NGABOR
    ), f"Wrong number of Gabor stimuli detected: {len(passiveGabor_intervals)} / {NGABOR}"
    fixture = _load_passive_session_fixtures(session_path, task_collection)
    passiveGabor_properties = fixture["pcs"]
    passiveGabor_table = np.append(passiveGabor_intervals, passiveGabor_properties, axis=1)
    columns = ["start", "stop", "position", "contrast", "phase"]
    passiveGabor_df = pd.DataFrame(passiveGabor_table, columns=columns)
    return passiveGabor_df
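
# Illustrative note: the resulting table has one row per replayed Gabor (NGABOR = 180)
# with columns start, stop, position, contrast, phase, e.g.
#   passiveGabor_df = _extract_passiveGabor_df(fttl, session_path)
#   passiveGabor_df.loc[0, ["start", "stop"]]  # onset/offset times of the first Gabor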


def _extract_passiveValve_intervals(bpod: dict) -> np.array:
    # passiveValve.intervals
    # Get valve intervals from bpod channel
    # bpod channel should only contain valve output for passiveCW protocol
    # All high fronts == valve open times and low fronts == valve close times
    valveOn_times = bpod["times"][bpod["polarities"] > 0]
    valveOff_times = bpod["times"][bpod["polarities"] < 0]

    assert len(valveOn_times) == NVALVE, "Wrong number of valve ONSET times"
    assert len(valveOff_times) == NVALVE, "Wrong number of valve OFFSET times"
    assert len(bpod["times"]) == NVALVE * 2, "Wrong number of valve FRONTS detected"  # (40 * 2)

    # check all values are within bpod tolerance of 100µs
    assert np.allclose(
        valveOff_times - valveOn_times, valveOff_times[1] - valveOn_times[1], atol=0.0001
    ), "Some valve outputs are longer or shorter than others"

    return np.array([(x, y) for x, y in zip(valveOn_times, valveOff_times)])


def _extract_passiveAudio_intervals(audio: dict, rig_version: str) -> Tuple[np.array, np.array]:

    # make an exception for task version = 6.2.5 where things are strange but data is recoverable
    if rig_version == '6.2.5':
        # Get all sound onsets and offsets
        soundOn_times = audio["times"][audio["polarities"] > 0]
        soundOff_times = audio["times"][audio["polarities"] < 0]

        # Have a couple that are wayyy too long!
        time_threshold = 10
        diff = soundOff_times - soundOn_times
        stupid = np.where(diff > time_threshold)[0]
        NREMOVE = len(stupid)
        not_stupid = np.where(diff < time_threshold)[0]

        assert len(soundOn_times) == NTONES + NNOISES - NREMOVE, "Wrong number of sound ONSETS"
        assert len(soundOff_times) == NTONES + NNOISES - NREMOVE, "Wrong number of sound OFFSETS"

        soundOn_times = soundOn_times[not_stupid]
        soundOff_times = soundOff_times[not_stupid]

        diff = soundOff_times - soundOn_times
        # Tone is ~100ms so check if diff < 0.3
        toneOn_times = soundOn_times[diff <= 0.3]
        toneOff_times = soundOff_times[diff <= 0.3]
        # Noise is ~500ms so check if diff > 0.3
        noiseOn_times = soundOn_times[diff > 0.3]
        noiseOff_times = soundOff_times[diff > 0.3]

        # append with nans
        toneOn_times = np.r_[toneOn_times, np.full((NTONES - len(toneOn_times)), np.nan)]
        toneOff_times = np.r_[toneOff_times, np.full((NTONES - len(toneOff_times)), np.nan)]
        noiseOn_times = np.r_[noiseOn_times, np.full((NNOISES - len(noiseOn_times)), np.nan)]
        noiseOff_times = np.r_[noiseOff_times, np.full((NNOISES - len(noiseOff_times)), np.nan)]

    else:
        # Get all sound onsets and offsets
        soundOn_times = audio["times"][audio["polarities"] > 0]
        soundOff_times = audio["times"][audio["polarities"] < 0]

        # Check they are the correct number
        assert len(soundOn_times) == NTONES + NNOISES, f"Wrong number of sound ONSETS, " \
                                                       f"{len(soundOn_times)}/{NTONES + NNOISES}"
        assert len(soundOff_times) == NTONES + NNOISES, f"Wrong number of sound OFFSETS, " \
                                                        f"{len(soundOff_times)}/{NTONES + NNOISES}"

        diff = soundOff_times - soundOn_times
        # Tone is ~100ms so check if diff < 0.3
        toneOn_times = soundOn_times[diff <= 0.3]
        toneOff_times = soundOff_times[diff <= 0.3]
        # Noise is ~500ms so check if diff > 0.3
        noiseOn_times = soundOn_times[diff > 0.3]
        noiseOff_times = soundOff_times[diff > 0.3]

        assert len(toneOn_times) == NTONES
        assert len(toneOff_times) == NTONES
        assert len(noiseOn_times) == NNOISES
        assert len(noiseOff_times) == NNOISES

        # Fixed delays from soundcard ~500µs
        np.allclose(toneOff_times - toneOn_times, 0.1, atol=0.0006)
        np.allclose(noiseOff_times - noiseOn_times, 0.5, atol=0.0006)

    passiveTone_intervals = np.append(
        toneOn_times.reshape((len(toneOn_times), 1)),
        toneOff_times.reshape((len(toneOff_times), 1)),
        axis=1,
    )
    passiveNoise_intervals = np.append(
        noiseOn_times.reshape((len(noiseOn_times), 1)),
        noiseOff_times.reshape((len(noiseOff_times), 1)),
        axis=1,
    )
    return passiveTone_intervals, passiveNoise_intervals


# ------------------------------------------------------------------
def extract_passive_periods(session_path: str, sync_collection: str = 'raw_ephys_data', sync: dict = None,
                            sync_map: dict = None, tmin=None, tmax=None) -> pd.DataFrame:

    if sync is None or sync_map is None:
        sync, sync_map = ephys_fpga.get_sync_and_chn_map(session_path, sync_collection)

    t_start_passive, t_starts, t_ends = _get_passive_spacers(
        session_path, sync_collection, sync=sync, sync_map=sync_map, tmin=tmin, tmax=tmax
    )
    t_starts_col = np.insert(t_starts, 0, t_start_passive)
    t_ends_col = np.insert(t_ends, 0, t_ends[-1])
    # tpassive_protocol = [t_start_passive, t_ends[-1]]
    # tspontaneous = [t_starts[0], t_ends[0]]
    # trfm = [t_starts[1], t_ends[1]]
    # treplay = [t_starts[2], t_ends[2]]
    passivePeriods_df = pd.DataFrame(
        [t_starts_col, t_ends_col],
        index=["start", "stop"],
        columns=["passiveProtocol", "spontaneousActivity", "RFM", "taskReplay"],
    )

    return passivePeriods_df  # _ibl_passivePeriods.intervalsTable.csv


def extract_rfmapping(
    session_path: str, sync_collection: str = 'raw_ephys_data', task_collection: str = 'raw_passive_data',
    sync: dict = None, sync_map: dict = None, trfm: np.array = None
) -> Tuple[np.array, np.array]:
    meta = _load_passive_stim_meta()
    mkey = (
        "VISUAL_STIM_"
        + {v: k for k, v in meta["VISUAL_STIMULI"].items()}["receptive_field_mapping"]
    )
    if sync is None or sync_map is None:
        sync, sync_map = ephys_fpga.get_sync_and_chn_map(session_path, sync_collection)
    if trfm is None:
        passivePeriods_df = extract_passive_periods(session_path, sync_collection, sync=sync, sync_map=sync_map)
        trfm = passivePeriods_df.RFM.values

    fttl = ephys_fpga.get_sync_fronts(sync, sync_map["frame2ttl"], tmin=trfm[0], tmax=trfm[1])
    fttl = ephys_fpga._clean_frame2ttl(fttl)
    RF_file = Path().joinpath(session_path, task_collection, "_iblrig_RFMapStim.raw.bin")
    passiveRFM_frames, RF_ttl_trace = _reshape_RF(RF_file=RF_file, meta_stim=meta[mkey])
    rf_id_up, rf_id_dw, RF_n_ttl_expected = _get_id_raisefall_from_analogttl(RF_ttl_trace)
    meta[mkey]["ttl_num"] = RF_n_ttl_expected
    rf_times_on_idx = np.where(np.diff(fttl["times"]) < 1)[0]
    rf_times_off_idx = rf_times_on_idx + 1
    RF_times = fttl["times"][np.sort(np.concatenate([rf_times_on_idx, rf_times_off_idx]))]
    RF_times_1 = RF_times[0::2]
    # Interpolate times for RF before outputting dataset
    passiveRFM_times = _interpolate_rf_mapping_stimulus(
        idxs_up=rf_id_up,
        idxs_dn=rf_id_dw,
        times=RF_times_1,
        Xq=np.arange(passiveRFM_frames.shape[0]),
        t_bin=1 / FRAME_FS,
    )

    return passiveRFM_times  # _ibl_passiveRFM.times.npy


def extract_task_replay(
    session_path: str, sync_collection: str = 'raw_ephys_data', task_collection: str = 'raw_passive_data',
    sync: dict = None, sync_map: dict = None, treplay: np.array = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:

    if sync is None or sync_map is None:
        sync, sync_map = ephys_fpga.get_sync_and_chn_map(session_path, sync_collection)

    if treplay is None:
        passivePeriods_df = extract_passive_periods(session_path, sync_collection, sync=sync, sync_map=sync_map)
        treplay = passivePeriods_df.taskReplay.values

    # TODO need to check this is okay
    fttl = ephys_fpga.get_sync_fronts(sync, sync_map["frame2ttl"], tmin=treplay[0], tmax=treplay[1])
    fttl = ephys_fpga._clean_frame2ttl(fttl)
    passiveGabor_df = _extract_passiveGabor_df(fttl, session_path, task_collection=task_collection)

    bpod = ephys_fpga.get_sync_fronts(sync, sync_map["bpod"], tmin=treplay[0], tmax=treplay[1])
    passiveValve_intervals = _extract_passiveValve_intervals(bpod)

    task_version = _load_task_protocol(session_path, task_collection)
    audio = ephys_fpga.get_sync_fronts(sync, sync_map["audio"], tmin=treplay[0], tmax=treplay[1])
    passiveTone_intervals, passiveNoise_intervals = _extract_passiveAudio_intervals(audio, task_version)

    passiveStims_df = np.concatenate(
        [passiveValve_intervals, passiveTone_intervals, passiveNoise_intervals], axis=1
    )
    columns = ["valveOn", "valveOff", "toneOn", "toneOff", "noiseOn", "noiseOff"]
    passiveStims_df = pd.DataFrame(passiveStims_df, columns=columns)

    return (
        passiveGabor_df,
        passiveStims_df,
    )  # _ibl_passiveGabor.table.csv, _ibl_passiveStims.times_table.csv


def extract_replay_debug(
    session_path: str,
    sync_collection: str = 'raw_ephys_data',
    task_collection: str = 'raw_passive_data',
    sync: dict = None,
    sync_map: dict = None,
    treplay: np.array = None,
    ax: plt.axes = None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    # Load sessions sync channels, map
    if sync is None or sync_map is None:
        sync, sync_map = ephys_fpga.get_sync_and_chn_map(session_path, sync_collection)

    if treplay is None:
        passivePeriods_df = extract_passive_periods(session_path, sync_collection=sync_collection, sync=sync,
                                                    sync_map=sync_map)
        treplay = passivePeriods_df.taskReplay.values

    if ax is None:
        f, ax = plt.subplots(1, 1)

    f = ax.figure
    f.suptitle("/".join(str(session_path).split("/")[-5:]))
    plot_sync_channels(sync=sync, sync_map=sync_map, ax=ax)

    passivePeriods_df = extract_passive_periods(session_path, sync_collection=sync_collection, sync=sync,
                                                sync_map=sync_map)
    treplay = passivePeriods_df.taskReplay.values
    plot_passive_periods(passivePeriods_df, ax=ax)

    fttl = ephys_fpga.get_sync_fronts(sync, sync_map["frame2ttl"], tmin=treplay[0])
    passiveGabor_df = _extract_passiveGabor_df(fttl, session_path, task_collection=task_collection)
    plot_gabor_times(passiveGabor_df, ax=ax)

    bpod = ephys_fpga.get_sync_fronts(sync, sync_map["bpod"], tmin=treplay[0])
    passiveValve_intervals = _extract_passiveValve_intervals(bpod)
    plot_valve_times(passiveValve_intervals, ax=ax)

    task_version = _load_task_protocol(session_path, task_collection)
    audio = ephys_fpga.get_sync_fronts(sync, sync_map["audio"], tmin=treplay[0])
    passiveTone_intervals, passiveNoise_intervals = _extract_passiveAudio_intervals(audio, task_version)
    plot_audio_times(passiveTone_intervals, passiveNoise_intervals, ax=ax)

    passiveStims_df = np.concatenate(
        [passiveValve_intervals, passiveTone_intervals, passiveNoise_intervals], axis=1
    )
    columns = ["valveOn", "valveOff", "toneOn", "toneOff", "noiseOn", "noiseOff"]
    passiveStims_df = pd.DataFrame(passiveStims_df, columns=columns)

    return (
        passiveGabor_df,
        passiveStims_df,
    )  # _ibl_passiveGabor.table.csv, _ibl_passiveStims.table.csv


# Main passiveCW extractor, calls all others
class PassiveChoiceWorld(BaseExtractor):
    save_names = (
        "_ibl_passivePeriods.intervalsTable.csv",
        "_ibl_passiveRFM.times.npy",
        "_ibl_passiveGabor.table.csv",
        "_ibl_passiveStims.table.csv",
    )
    var_names = (
        "passivePeriods_df",
        "passiveRFM_times",
        "passiveGabor_df",
        "passiveStims_df",
    )

    def _extract(self, sync_collection: str = 'raw_ephys_data', task_collection: str = 'raw_passive_data',
                 sync: dict = None, sync_map: dict = None, plot: bool = False, **kwargs) -> tuple:
        if sync is None or sync_map is None:
            sync, sync_map = ephys_fpga.get_sync_and_chn_map(self.session_path, sync_collection)

        # Get the start and end times of this protocol
        if (protocol_number := kwargs.get('protocol_number')) is not None:  # look for spacer
            # The spacers are TTLs generated by Bpod at the start of each protocol
            bpod = ephys_fpga.get_sync_fronts(sync, sync_map['bpod'])
            tmin, tmax = ephys_fpga.get_protocol_period(self.session_path, protocol_number, bpod)
        else:
            tmin = tmax = None

        # Passive periods
        passivePeriods_df = extract_passive_periods(self.session_path, sync_collection=sync_collection, sync=sync,
                                                    sync_map=sync_map, tmin=tmin, tmax=tmax)
        trfm = passivePeriods_df.RFM.values
        treplay = passivePeriods_df.taskReplay.values

        try:  # RFMapping
            passiveRFM_times = extract_rfmapping(self.session_path, sync_collection=sync_collection,
                                                 task_collection=task_collection, sync=sync, sync_map=sync_map,
                                                 trfm=trfm)
        except Exception as e:
            log.error(f"Failed to extract RFMapping datasets: {e}")
            passiveRFM_times = None

        try:
            (passiveGabor_df, passiveStims_df,) = extract_task_replay(
                self.session_path, sync_collection=sync_collection, task_collection=task_collection, sync=sync,
                sync_map=sync_map, treplay=treplay)
        except Exception as e:
            log.error(f"Failed to extract task replay stimuli: {e}")
            passiveGabor_df, passiveStims_df = (None, None)

        if plot:
            f, ax = plt.subplots(1, 1)
            f.suptitle("/".join(str(self.session_path).split("/")[-5:]))
            plot_sync_channels(sync=sync, sync_map=sync_map, ax=ax)
            plot_passive_periods(passivePeriods_df, ax=ax)
            plot_rfmapping(passiveRFM_times, ax=ax)
            plot_gabor_times(passiveGabor_df, ax=ax)
            plot_stims_times(passiveStims_df, ax=ax)
            plt.show()

        data = (
            passivePeriods_df,  # _ibl_passivePeriods.intervalsTable.csv
            passiveRFM_times,  # _ibl_passiveRFM.times.npy
            passiveGabor_df,  # _ibl_passiveGabor.table.csv,
            passiveStims_df  # _ibl_passiveStims.table.csv
        )

        # Set save names to None if data not extracted - these will not be saved or registered
        self.save_names = tuple(None if y is None else x for x, y in zip(self.save_names, data))

        return data