#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Author: Niccolò Bonacchi
# @Date: Monday, September 7th 2020, 11:51:17 am
import json
import logging
from pathlib import Path
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ibllib.io.raw_data_loaders as rawio
from ibllib.io.extractors import ephys_fpga
from ibllib.io.extractors.base import BaseExtractor
from ibllib.io.extractors.passive_plotting import (
plot_audio_times,
plot_gabor_times,
plot_passive_periods,
plot_rfmapping,
plot_stims_times,
plot_sync_channels,
plot_valve_times,
)
log = logging.getLogger("ibllib")
# hardcoded var
FRAME_FS = 60 # Sampling freq of the ipad screen, in Hertz
FS_FPGA = 30000 # Sampling freq of the neural recording system screen, in Hertz
NVALVE = 40 # number of expected valve clicks
NGABOR = 20 + 20 * 4 * 2 # number of expected Gabor patches
NTONES = 40
NNOISES = 40
DEBUG_PLOTS = False
dataset_types = [
"_spikeglx_sync.times",
"_spikeglx_sync.channels",
"_spikeglx_sync.polarities",
"_iblrig_RFMapStim.raw",
"_iblrig_stimPositionScreen.raw",
"_iblrig_syncSquareUpdate.raw",
"ephysData.raw.meta",
"_iblrig_taskSettings.raw",
"_iblrig_taskData.raw",
]
min_dataset_types = [
"_spikeglx_sync.times",
"_spikeglx_sync.channels",
"_spikeglx_sync.polarities",
"_iblrig_RFMapStim.raw",
"ephysData.raw.meta",
"_iblrig_taskSettings.raw",
"_iblrig_taskData.raw",
]
# load session fixtures
def _load_passive_session_fixtures(session_path: str, task_collection: str = 'raw_passive_data') -> dict:
"""load_passive_session_fixtures Loads corresponding ephys session fixtures
:param session_path: the path to a session
:type session_path: str
:return: position contrast phase delays and stim id's
:rtype: dict
"""
# THIS CAN BE PREGENERATED SESSION NO
settings = rawio.load_settings(session_path, task_collection=task_collection)
ses_nb = settings['PREGENERATED_SESSION_NUM']
session_order = settings.get('SESSION_ORDER', None)
if session_order: # TODO test this out and make sure it okay
assert settings["SESSION_ORDER"][settings["SESSION_IDX"]] == ses_nb
path_fixtures = Path(ephys_fpga.__file__).parent.joinpath("ephys_sessions")
fixture = {
"pcs": np.load(path_fixtures.joinpath(f"session_{ses_nb}_passive_pcs.npy")),
"delays": np.load(path_fixtures.joinpath(f"session_{ses_nb}_passive_stimDelays.npy")),
"ids": np.load(path_fixtures.joinpath(f"session_{ses_nb}_passive_stimIDs.npy")),
}
return fixture
def _load_task_version(session_path: str, task_collection: str = 'raw_passive_data') -> str:
"""Find the IBL rig version used for the session
:param session_path: the path to a session
:type session_path: str
:return: ibl rig task protocol version
:rtype: str
"""
settings = rawio.load_settings(session_path, task_collection=task_collection)
ses_ver = settings["IBLRIG_VERSION"]
return ses_ver
[docs]
def skip_task_replay(session_path: str, task_collection: str = 'raw_passive_data') -> bool:
"""Find whether the task replay portion of the passive stimulus has been shown
:param session_path: the path to a session
:type session_path: str
:param task_collection: collection containing task data
:type task_collection: str
:return: whether or not the task replay has been run
:rtype: bool
"""
settings = rawio.load_settings(session_path, task_collection=task_collection)
# Attempt to see if SKIP_EVENT_REPLAY is available, if not assume we do have task replay
skip_replay = settings.get('SKIP_EVENT_REPLAY', False)
return skip_replay
def _load_passive_stim_meta() -> dict:
"""load_passive_stim_meta Loads the passive protocol metadata
:return: metadata about passive protocol stimulus presentation
:rtype: dict
"""
path_fixtures = Path(ephys_fpga.__file__).parent.joinpath("ephys_sessions")
with open(path_fixtures.joinpath("passive_stim_meta.json"), "r") as f:
meta = json.load(f)
return meta
# 1/3 Define start and end times of the 3 passive periods
def _get_spacer_times(spacer_template, jitter, ttl_signal, t_quiet, thresh=3.0):
"""
Find timestamps of spacer signal.
:param spacer_template: list of indices where ttl signal changes
:type spacer_template: array-like
:param jitter: jitter (in seconds) for matching ttl_signal with spacer_template
:type jitter: float
:param ttl_signal:
:type ttl_signal: array-like
:param t_quiet: seconds between spacer and next stim
:type t_quiet: float
:param thresh: threshold value for the fttl convolved signal (with spacer template) to pass over to detect a spacer
:type thresh: float
:return: times of spacer onset/offset
:rtype: n_spacer x 2 np.ndarray; first col onset times, second col offset
"""
diff_spacer_template = np.diff(spacer_template)
# add jitter;
# remove extreme values
spacer_model = jitter + diff_spacer_template[2:-2]
# diff ttl signal to compare to spacer_model
dttl = np.diff(ttl_signal)
# remove diffs larger than max diff in model to clean up signal
dttl[dttl > np.max(spacer_model)] = 0
# convolve cleaned diff ttl signal w/ spacer model
conv_dttl = np.correlate(dttl, spacer_model, mode="full")
# find spacer location
idxs_spacer_middle = np.where(
(conv_dttl[1:-2] < thresh) & (conv_dttl[2:-1] > thresh) & (conv_dttl[3:] < thresh)
)[0]
# adjust indices for
# - `np.where` call above
# - length of spacer_model
spacer_around = int((np.floor(len(spacer_model) / 2)))
idxs_spacer_middle += 2 - spacer_around
# for each spacer make sure the times are monotonically non-decreasing before
# and monotonically non-increasing afterwards
is_valid = np.zeros((idxs_spacer_middle.size), dtype=bool)
for i, t in enumerate(idxs_spacer_middle):
before = all(np.diff(dttl[t - spacer_around:t]) >= 0)
after = all(np.diff(dttl[t + 1:t + 1 + spacer_around]) <= 0)
is_valid[i] = np.bitwise_and(before, after)
idxs_spacer_middle = idxs_spacer_middle[is_valid]
# pull out spacer times (middle)
ts_spacer_middle = ttl_signal[idxs_spacer_middle]
# put beginning/end of spacer times into an array
spacer_length = np.max(spacer_template)
spacer_times = np.zeros(shape=(ts_spacer_middle.shape[0], 2))
for i, t in enumerate(ts_spacer_middle):
spacer_times[i, 0] = t - (spacer_length / 2) - t_quiet
spacer_times[i, 1] = t + (spacer_length / 2) + t_quiet
return spacer_times, conv_dttl
def _get_passive_spacers(session_path, sync_collection='raw_ephys_data',
sync=None, sync_map=None, tmin=None, tmax=None):
"""
load and get spacer information, do corr to find spacer timestamps
returns t_passive_starts, t_starts, t_ends
"""
if sync is None or sync_map is None:
sync, sync_map = ephys_fpga.get_sync_and_chn_map(session_path, sync_collection=sync_collection)
meta = _load_passive_stim_meta()
# t_end_ephys = passive.ephysCW_end(session_path=session_path)
fttl = ephys_fpga.get_sync_fronts(sync, sync_map["frame2ttl"], tmin=tmin, tmax=tmax)
fttl = ephys_fpga._clean_frame2ttl(fttl, display=False)
spacer_template = (
np.array(meta["VISUAL_STIM_0"]["ttl_frame_nums"], dtype=np.float32) / FRAME_FS
)
jitter = 3 / FRAME_FS # allow for 3 screen refresh as jitter
t_quiet = meta["VISUAL_STIM_0"]["delay_around"]
spacer_times, conv_dttl = _get_spacer_times(
spacer_template=spacer_template, jitter=jitter, ttl_signal=fttl["times"], t_quiet=t_quiet
)
# Check correct number of spacers found
n_exp_spacer = np.sum(np.array(meta['STIM_ORDER']) == 0) # Hardcoded 0 for spacer
if n_exp_spacer != np.size(spacer_times) / 2:
error_nspacer = True
# sometimes the first spacer is truncated
# assess whether the first spacer is undetected, and then launch another spacer detection on truncated fttl
# with a lower threshold value
# Note: take *3 for some margin
if spacer_times[0][0] > (spacer_template[-1] + jitter) * 3 and (np.size(spacer_times) / 2) == n_exp_spacer - 1:
# Truncate signals
fttl_t = fttl["times"][np.where(fttl["times"] < spacer_times[0][0])]
conv_dttl_t = conv_dttl[np.where(fttl["times"] < spacer_times[0][0])]
ddttl = np.diff(np.diff(fttl_t))
# Find spacer location
# NB: cannot re-use the same algo for spacer detection as conv peaks towards spacer end
# 1. Find time point at which conv raises above a given threshold value
thresh = 2.0
idx_nearend_spacer = int(np.where((conv_dttl_t[1:-2] < thresh) & (conv_dttl_t[2:-1] > thresh))[0])
ddttl = ddttl[0:idx_nearend_spacer]
# 2. Find time point before this, for which fttl diff increase/decrease (this is the middle of spacer)
indx_middle = np.where((ddttl[0:-1] > 0) & (ddttl[1:] < 0))[0]
if len(indx_middle) == 1:
# 3. Add 1/2 spacer to middle idx to get the spacer end indx
spacer_around = int((np.floor(len(spacer_template) / 2)))
idx_end = int(indx_middle + spacer_around) + 1
spacer_times = np.insert(spacer_times, 0, np.array([fttl["times"][0], fttl["times"][idx_end]]), axis=0)
error_nspacer = False
if error_nspacer:
raise ValueError(
f'The number of expected spacer ({n_exp_spacer}) '
f'is different than the one found on the raw '
f'trace ({int(np.size(spacer_times) / 2)})'
)
if tmax is None:
tmax = sync['times'][-1]
spacer_times = np.r_[spacer_times.flatten(), tmax]
return spacer_times[0], spacer_times[1::2], spacer_times[2::2]
# 2/3 RFMapping stimuli
def _interpolate_rf_mapping_stimulus(idxs_up, idxs_dn, times, Xq, t_bin):
"""
Interpolate stimulus presentation times to screen refresh rate to match `frames`
:param ttl_01:
:type ttl_01: array-like
:param times: array of stimulus switch times
:type times: array-like
:param Xq: number of times (found in frames)
:type frames: array-like
:param t_bin: period of screen refresh rate
:type t_bin: float
:return: tuple of (stim_times, stim_frames)
"""
beg_extrap_val = -10001
end_extrap_val = -10000
X = np.sort(np.concatenate([idxs_up, idxs_dn]))
# make left and right extrapolations distinctive to easily find later
Tq = np.interp(Xq, X, times, left=beg_extrap_val, right=end_extrap_val)
# uniform spacing outside boundaries of ttl signal
# first values
n_beg = len(np.where(Tq == beg_extrap_val)[0])
if 0 < n_beg < Tq.shape[0]:
Tq[:n_beg] = times[0] - np.arange(n_beg, 0, -1) * t_bin
# end values
n_end = len(np.where(Tq == end_extrap_val)[0])
if 0 < n_end < Tq.shape[0]:
Tq[-n_end:] = times[-1] + np.arange(1, n_end + 1) * t_bin
return Tq
def _get_id_raisefall_from_analogttl(ttl_01):
"""
Get index of raise/fall from analog continuous TTL signal (0-1 values)
:param ttl_01: analog continuous TTL signal (0-1 values)
:return: index up (0>1), index down (1>0), number of ttl transition
"""
# Check values are 0, 1, -1
if not np.all(np.isin(np.unique(ttl_01), [-1, 0, 1])):
raise ValueError("Values in input must be 0, 1, -1")
else:
# Find number of passage from [0 1] and [0 -1]
d_ttl_01 = np.diff(ttl_01)
id_up = np.where(np.logical_and(ttl_01 == 0, np.append(d_ttl_01, 0) == 1))[0]
id_dw = np.where(np.logical_and(ttl_01 == 0, np.append(d_ttl_01, 0) == -1))[0]
n_ttl_expected = 2 * (len(id_up) + len(id_dw)) # *2 for rise/fall of ttl pulse
return id_up, id_dw, n_ttl_expected
def _reshape_RF(RF_file, meta_stim):
"""
Reshape Receptive Field (RF) matrix. Take data associated to corner
where frame2ttl placed to create TTL trace.
:param RF_file: vector to be reshaped, containing RF info
:param meta_stim: variable containing metadata information on RF
:return: frames (reshaped RF), analog trace (0-1 values)
"""
frame_array = np.fromfile(RF_file, dtype="uint8")
y_pix, x_pix, _ = meta_stim["stim_file_shape"]
frames = np.transpose(np.reshape(frame_array, [y_pix, x_pix, -1], order="F"), [2, 1, 0])
ttl_trace = frames[:, 0, 0]
# Convert values to 0,1,-1 for simplicity
ttl_analogtrace_01 = np.zeros(np.size(ttl_trace))
ttl_analogtrace_01[np.where(ttl_trace == 0)] = -1
ttl_analogtrace_01[np.where(ttl_trace == 255)] = 1
return frames, ttl_analogtrace_01
# 3/3 Replay of task stimuli
def _extract_passiveGabor_df(fttl: dict, session_path: str, task_collection: str = 'raw_passive_data') -> pd.DataFrame:
# At this stage we want to define what pulses are and not quality control them.
# Pulses are strictly alternating with intervals
# find min max lengths for both (we don't know which are pulses and which are intervals yet)
# trim edges of pulses
diff0 = (np.min(np.diff(fttl["times"])[2:-2:2]), np.max(np.diff(fttl["times"])[2:-1:2]))
diff1 = (np.min(np.diff(fttl["times"])[3:-2:2]), np.max(np.diff(fttl["times"])[3:-1:2]))
# Highest max is of the intervals
if max(diff0 + diff1) in diff0:
thresh = diff0[0]
elif max(diff0 + diff1) in diff1:
thresh = diff1[0]
# Anything lower than the min length of intervals is a pulse
idx_start_stims = np.where((np.diff(fttl["times"]) < thresh) & (np.diff(fttl["times"]) > 0.1))[0]
# Check if any pulse has been missed
# i.e. expected length (without first pulse) and that it's alternating
if len(idx_start_stims) < NGABOR - 1 and np.any(np.diff(idx_start_stims) > 2):
log.warning("Looks like one or more pulses were not detected, trying to extrapolate...")
missing_where = np.where(np.diff(idx_start_stims) > 2)[0]
insert_where = missing_where + 1
missing_value = idx_start_stims[missing_where] + 2
idx_start_stims = np.insert(idx_start_stims, insert_where, missing_value)
idx_end_stims = idx_start_stims + 1
start_times = fttl["times"][idx_start_stims]
end_times = fttl["times"][idx_end_stims]
# Check if we missed the first stim
if len(start_times) < NGABOR:
first_stim_off_idx = idx_start_stims[0] - 1
# first_stim_on_idx = first_stim_off_idx - 1
end_times = np.insert(end_times, 0, fttl["times"][first_stim_off_idx])
start_times = np.insert(start_times, 0, end_times[0] - 0.3)
# intervals dstype requires reshaping of start and end times
passiveGabor_intervals = np.array([(x, y) for x, y in zip(start_times, end_times)])
# Check length of presentation of stim is within 150ms of expected
if not np.allclose([y - x for x, y in passiveGabor_intervals], 0.3, atol=0.15):
log.warning("Some Gabor presentation lengths seem wrong.")
assert (
len(passiveGabor_intervals) == NGABOR
), f"Wrong number of Gabor stimuli detected: {len(passiveGabor_intervals)} / {NGABOR}"
fixture = _load_passive_session_fixtures(session_path, task_collection)
passiveGabor_properties = fixture["pcs"]
passiveGabor_table = np.append(passiveGabor_intervals, passiveGabor_properties, axis=1)
columns = ["start", "stop", "position", "contrast", "phase"]
passiveGabor_df = pd.DataFrame(passiveGabor_table, columns=columns)
return passiveGabor_df
def _extract_passiveValve_intervals(bpod: dict) -> np.array:
# passiveValve.intervals
# Get valve intervals from bpod channel
# bpod channel should only contain valve output for passiveCW protocol
# All high fronts == valve open times and low fronts == valve close times
valveOn_times = bpod["times"][bpod["polarities"] > 0]
valveOff_times = bpod["times"][bpod["polarities"] < 0]
assert len(valveOn_times) == NVALVE, "Wrong number of valve ONSET times"
assert len(valveOff_times) == NVALVE, "Wrong number of valve OFFSET times"
assert len(bpod["times"]) == NVALVE * 2, "Wrong number of valve FRONTS detected" # (40 * 2)
# check all values are within bpod tolerance of 100µs
assert np.allclose(
valveOff_times - valveOn_times, valveOff_times[1] - valveOn_times[1], atol=0.0001
), "Some valve outputs are longer or shorter than others"
return np.array([(x, y) for x, y in zip(valveOn_times, valveOff_times)])
def _extract_passiveAudio_intervals(audio: dict, rig_version: str) -> Tuple[np.array, np.array]:
# make an exception for task version = 6.2.5 where things are strange but data is recoverable
if rig_version == '6.2.5':
# Get all sound onsets and offsets
soundOn_times = audio["times"][audio["polarities"] > 0]
soundOff_times = audio["times"][audio["polarities"] < 0]
# Have a couple that are wayyy too long!
time_threshold = 10
diff = soundOff_times - soundOn_times
stupid = np.where(diff > time_threshold)[0]
NREMOVE = len(stupid)
not_stupid = np.where(diff < time_threshold)[0]
assert len(soundOn_times) == NTONES + NNOISES - NREMOVE, "Wrong number of sound ONSETS"
assert len(soundOff_times) == NTONES + NNOISES - NREMOVE, "Wrong number of sound OFFSETS"
soundOn_times = soundOn_times[not_stupid]
soundOff_times = soundOff_times[not_stupid]
diff = soundOff_times - soundOn_times
# Tone is ~100ms so check if diff < 0.3
toneOn_times = soundOn_times[diff <= 0.3]
toneOff_times = soundOff_times[diff <= 0.3]
# Noise is ~500ms so check if diff > 0.3
noiseOn_times = soundOn_times[diff > 0.3]
noiseOff_times = soundOff_times[diff > 0.3]
# append with nans
toneOn_times = np.r_[toneOn_times, np.full((NTONES - len(toneOn_times)), np.nan)]
toneOff_times = np.r_[toneOff_times, np.full((NTONES - len(toneOff_times)), np.nan)]
noiseOn_times = np.r_[noiseOn_times, np.full((NNOISES - len(noiseOn_times)), np.nan)]
noiseOff_times = np.r_[noiseOff_times, np.full((NNOISES - len(noiseOff_times)), np.nan)]
else:
# Get all sound onsets and offsets
soundOn_times = audio["times"][audio["polarities"] > 0]
soundOff_times = audio["times"][audio["polarities"] < 0]
# Check they are the correct number
assert len(soundOn_times) == NTONES + NNOISES, f"Wrong number of sound ONSETS, " \
f"{len(soundOn_times)}/{NTONES + NNOISES}"
assert len(soundOff_times) == NTONES + NNOISES, f"Wrong number of sound OFFSETS, " \
f"{len(soundOn_times)}/{NTONES + NNOISES}"
diff = soundOff_times - soundOn_times
# Tone is ~100ms so check if diff < 0.3
toneOn_times = soundOn_times[diff <= 0.3]
toneOff_times = soundOff_times[diff <= 0.3]
# Noise is ~500ms so check if diff > 0.3
noiseOn_times = soundOn_times[diff > 0.3]
noiseOff_times = soundOff_times[diff > 0.3]
assert len(toneOn_times) == NTONES
assert len(toneOff_times) == NTONES
assert len(noiseOn_times) == NNOISES
assert len(noiseOff_times) == NNOISES
# Fixed delays from soundcard ~500µs
np.allclose(toneOff_times - toneOn_times, 0.1, atol=0.0006)
np.allclose(noiseOff_times - noiseOn_times, 0.5, atol=0.0006)
passiveTone_intervals = np.append(
toneOn_times.reshape((len(toneOn_times), 1)),
toneOff_times.reshape((len(toneOff_times), 1)),
axis=1,
)
passiveNoise_intervals = np.append(
noiseOn_times.reshape((len(noiseOn_times), 1)),
noiseOff_times.reshape((len(noiseOff_times), 1)),
axis=1,
)
return passiveTone_intervals, passiveNoise_intervals
# ------------------------------------------------------------------
# Main passiveCW extractor, calls all others
[docs]
class PassiveChoiceWorld(BaseExtractor):
save_names = (
"_ibl_passivePeriods.intervalsTable.csv",
"_ibl_passiveRFM.times.npy",
"_ibl_passiveGabor.table.csv",
"_ibl_passiveStims.table.csv",
)
var_names = (
"passivePeriods_df",
"passiveRFM_times",
"passiveGabor_df",
"passiveStims_df",
)
def _extract(self, sync_collection: str = 'raw_ephys_data', task_collection: str = 'raw_passive_data', sync: dict = None,
sync_map: dict = None, plot: bool = False, **kwargs) -> tuple:
if sync is None or sync_map is None:
sync, sync_map = ephys_fpga.get_sync_and_chn_map(self.session_path, sync_collection)
# Get the start and end times of this protocol
if (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer
# The spacers are TTLs generated by Bpod at the start of each protocol
bpod = ephys_fpga.get_sync_fronts(sync, sync_map['bpod'])
tmin, tmax = ephys_fpga.get_protocol_period(self.session_path, protocol_number, bpod)
else:
tmin = tmax = None
# Passive periods
passivePeriods_df = extract_passive_periods(self.session_path, sync_collection=sync_collection, sync=sync,
sync_map=sync_map, tmin=tmin, tmax=tmax)
trfm = passivePeriods_df.RFM.values
treplay = passivePeriods_df.taskReplay.values
try:
# RFMapping
passiveRFM_times = extract_rfmapping(self.session_path, sync_collection=sync_collection,
task_collection=task_collection, sync=sync, sync_map=sync_map, trfm=trfm)
except Exception as e:
log.error(f"Failed to extract RFMapping datasets: {e}")
passiveRFM_times = None
skip_replay = skip_task_replay(self.session_path, task_collection)
if not skip_replay:
try:
(passiveGabor_df, passiveStims_df,) = extract_task_replay(
self.session_path, sync_collection=sync_collection, task_collection=task_collection, sync=sync,
sync_map=sync_map, treplay=treplay)
except Exception as e:
log.error(f"Failed to extract task replay stimuli: {e}")
passiveGabor_df, passiveStims_df = (None, None)
else:
# If we don't have task replay then we set the treplay intervals to NaN in our passivePeriods_df dataset
passiveGabor_df, passiveStims_df = (None, None)
passivePeriods_df.taskReplay = np.nan
if plot:
f, ax = plt.subplots(1, 1)
f.suptitle("/".join(str(self.session_path).split("/")[-5:]))
plot_sync_channels(sync=sync, sync_map=sync_map, ax=ax)
plot_passive_periods(passivePeriods_df, ax=ax)
plot_rfmapping(passiveRFM_times, ax=ax)
plot_gabor_times(passiveGabor_df, ax=ax)
plot_stims_times(passiveStims_df, ax=ax)
plt.show()
data = (
passivePeriods_df, # _ibl_passivePeriods.intervalsTable.csv
passiveRFM_times, # _ibl_passiveRFM.times.npy
passiveGabor_df, # _ibl_passiveGabor.table.csv,
passiveStims_df # _ibl_passiveStims.table.csv
)
# Set save names to None if data not extracted - these will not be saved or registered
self.save_names = tuple(None if y is None else x for x, y in zip(self.save_names, data))
return data