"""Data extraction from fibrephotometry DAQ files.
Below is the expected folder structure for a fibrephotometry session:
subject/
├─ 2021-06-30/
│ ├─ 001/
│ │ ├─ raw_photometry_data/
│ │ │ │ ├─ _neurophotometrics_fpData.raw.pqt
│ │ │ │ ├─ _neurophotometrics_fpData.channels.csv
│ │ │ │ ├─ _mcc_DAQdata.raw.tdms
fpData.raw.pqt is a copy of the 'FPdata' file, the output of the Neuophotometrics Bonsai workflow.
fpData.channels.csv is table of frame flags for deciphering LED and GPIO states. The default table,
copied from the Neurophotometrics manual can be found in iblscripts/deploy/fppc/
_mcc_DAQdata.raw.tdms is the DAQ tdms file, containing the pulses from bpod and from the neurophotometrics system
"""
import logging
import pandas as pd
import numpy as np
import scipy.interpolate
import one.alf.io as alfio
from ibllib.io.extractors.base import BaseExtractor
from ibllib.io.raw_daq_loaders import load_channels_tdms, load_raw_daq_tdms
from ibllib.io.extractors.training_trials import GoCueTriggerTimes
from ibldsp.utils import rises, sync_timestamps
_logger = logging.getLogger(__name__)
DAQ_CHMAP = {"photometry": 'AI0', 'bpod': 'AI1'}
V_THRESHOLD = 3
"""
Neurophotometrics FP3002 specific information.
The light source map refers to the available LEDs on the system.
The flags refers to the byte encoding of led states in the system.
"""
LIGHT_SOURCE_MAP = {
'color': ['None', 'Violet', 'Blue', 'Green'],
'wavelength': [0, 415, 470, 560],
'name': ['None', 'Isosbestic', 'GCaMP', 'RCaMP'],
}
NEUROPHOTOMETRICS_LED_STATES = {
'Condition': {
0: 'No additional signal',
1: 'Output 1 signal HIGH',
2: 'Output 0 signal HIGH',
3: 'Stimulation ON',
4: 'GPIO Line 2 HIGH',
5: 'GPIO Line 3 HIGH',
6: 'Input 1 HIGH',
7: 'Input 0 HIGH',
8: 'Output 0 signal HIGH + Stimulation',
9: 'Output 0 signal HIGH + Input 0 signal HIGH',
10: 'Input 0 signal HIGH + Stimulation',
11: 'Output 0 HIGH + Input 0 HIGH + Stimulation',
},
'No LED ON': {0: 0, 1: 8, 2: 16, 3: 32, 4: 64, 5: 128, 6: 256, 7: 512, 8: 48, 9: 528, 10: 544, 11: 560},
'L415': {0: 1, 1: 9, 2: 17, 3: 33, 4: 65, 5: 129, 6: 257, 7: 513, 8: 49, 9: 529, 10: 545, 11: 561},
'L470': {0: 2, 1: 10, 2: 18, 3: 34, 4: 66, 5: 130, 6: 258, 7: 514, 8: 50, 9: 530, 10: 546, 11: 562},
'L560': {0: 4, 1: 12, 2: 20, 3: 36, 4: 68, 5: 132, 6: 260, 7: 516, 8: 52, 9: 532, 10: 548, 11: 564}
}
[docs]
def sync_photometry_to_daq(vdaq, fs, df_photometry, chmap=DAQ_CHMAP, v_threshold=V_THRESHOLD):
"""
:param vdaq: dictionary of daq traces.
:param fs: sampling frequency
:param df_photometry:
:param chmap:
:param v_threshold:
:return:
"""
# here we take the flag that is the most common
daq_frames, tag_daq_frames = read_daq_timestamps(vdaq=vdaq, v_threshold=v_threshold)
nf = np.minimum(tag_daq_frames.size, df_photometry['Input0'].size)
# we compute the framecounter for the DAQ, and match the bpod up state frame by frame for different shifts
# the shift that minimizes the mismatch is usually good
df = np.median(np.diff(df_photometry['Timestamp']))
fc = np.cumsum(np.round(np.diff(daq_frames) / fs / df).astype(np.int32)) - 1 # this is a daq frame counter
fc = fc[fc < (nf - 1)]
max_shift = 300
error = np.zeros(max_shift * 2 + 1)
shifts = np.arange(-max_shift, max_shift + 1)
for i, shift in enumerate(shifts):
rolled_fp = np.roll(df_photometry['Input0'].values[fc], shift)
error[i] = np.sum(np.abs(rolled_fp - tag_daq_frames[:fc.size]))
# a negative shift means that the DAQ is ahead of the photometry and that the DAQ misses frame at the beginning
frame_shift = shifts[np.argmax(-error)]
if np.sign(frame_shift) == -1:
ifp = fc[np.abs(frame_shift):]
elif np.sign(frame_shift) == 0:
ifp = fc
elif np.sign(frame_shift) == 1:
ifp = fc[:-np.abs(frame_shift)]
t_photometry = df_photometry['Timestamp'].values[ifp]
t_daq = daq_frames[:ifp.size] / fs
# import matplotlib.pyplot as plt
# plt.plot(shifts, -error)
fcn_fp2daq = scipy.interpolate.interp1d(t_photometry, t_daq, fill_value='extrapolate')
drift_ppm = (np.polyfit(t_daq, t_photometry, 1)[0] - 1) * 1e6
if drift_ppm > 120:
_logger.warning(f"drift photometry to DAQ PPM: {drift_ppm}")
else:
_logger.info(f"drift photometry to DAQ PPM: {drift_ppm}")
# here is a bunch of safeguards
assert np.unique(np.diff(df_photometry['FrameCounter'])).size == 1 # checks that there are no missed frames on photo
assert np.abs(frame_shift) <= 5 # it's always the end frames that are missing
assert np.abs(drift_ppm) < 60
ts_daq = fcn_fp2daq(df_photometry['Timestamp'].values) # those are the timestamps in daq time
return ts_daq, fcn_fp2daq, drift_ppm
[docs]
def read_daq_voltage(daq_file, chmap=DAQ_CHMAP):
channel_names = [c.name for c in load_raw_daq_tdms(daq_file)['Analog'].channels()]
assert all([v in channel_names for v in chmap.values()]), "Missing channel"
vdaq, fs = load_channels_tdms(daq_file, chmap=chmap)
vdaq = {k: v - np.median(v) for k, v in vdaq.items()}
return vdaq, fs
[docs]
def read_daq_timestamps(vdaq, v_threshold=V_THRESHOLD):
"""
From a tdms daq file, extracts the photometry frames and their tagging.
:param vsaq: dictionary of the voltage traces from the DAQ. Each item has a key describing
the channel as per the channel map, and contains a single voltage trace.
:param v_threshold:
:return:
"""
daq_frames = rises(vdaq['photometry'], step=v_threshold, analog=True)
if daq_frames.size == 0:
daq_frames = rises(-vdaq['photometry'], step=v_threshold, analog=True)
_logger.warning(f'No photometry pulses detected, attempting to reverse voltage and detect again,'
f'found {daq_frames.size} in reverse voltage. CHECK YOUR FP WIRING TO THE DAQ !!')
tagged_frames = vdaq['bpod'][daq_frames] > v_threshold
return daq_frames, tagged_frames
[docs]
def check_timestamps(daq_file, photometry_file, tolerance=20, chmap=DAQ_CHMAP, v_threshold=V_THRESHOLD):
"""
Reads data file and checks that the number of timestamps check out with a tolerance of n_frames
:param daq_file:
:param photometry_file:
:param tolerance: number of acceptable missing frames between the daq and the photometry file
:param chmap:
:param v_threshold:
:return: None
"""
df_photometry = pd.read_csv(photometry_file)
v, fs = read_daq_voltage(daq_file=daq_file, chmap=chmap)
daq_frames, _ = read_daq_timestamps(vdaq=v, v_threshold=v_threshold)
assert (daq_frames.shape[0] - df_photometry.shape[0]) < tolerance
_logger.info(f"{daq_frames.shape[0] - df_photometry.shape[0]} frames difference, "
f"{'/'.join(daq_file.parts[-2:])}: {daq_frames.shape[0]} frames, "
f"{'/'.join(photometry_file.parts[-2:])}: {df_photometry.shape[0]}")
[docs]
class FibrePhotometry(BaseExtractor):
"""
FibrePhotometry(self.session_path, collection=self.collection)
"""
save_names = ('photometry.signal.pqt')
var_names = ('df_out')
def __init__(self, *args, collection='raw_photometry_data', **kwargs):
"""An extractor for all Neurophotometrics fibrephotometry data"""
self.collection = collection
super().__init__(*args, **kwargs)
@staticmethod
def _channel_meta(light_source_map=None):
"""
Return table of light source wavelengths and corresponding colour labels.
Parameters
----------
light_source_map : dict
An optional map of light source wavelengths (nm) used and their corresponding colour name.
Returns
-------
pandas.DataFrame
A sorted table of wavelength and colour name.
"""
light_source_map = light_source_map or LIGHT_SOURCE_MAP
meta = pd.DataFrame.from_dict(light_source_map)
meta.index.rename('channel_id', inplace=True)
return meta
def _extract(self, light_source_map=None, collection=None, regions=None, **kwargs):
"""
Parameters
----------
regions: list of str
The list of regions to extract. If None extracts all columns containing "Region". Defaults to None.
light_source_map : dict
An optional map of light source wavelengths (nm) used and their corresponding colour name.
collection: str / pathlib.Path
An optional relative path from the session root folder to find the raw photometry data.
Defaults to `raw_photometry_data`
Returns
-------
numpy.ndarray
A 1D array of signal values.
numpy.ndarray
A 1D array of ints corresponding to the active light source during a given frame.
pandas.DataFrame
A table of intensity for each region, with associated times, wavelengths, names and colors
"""
collection = collection or self.collection
fp_data = alfio.load_object(self.session_path / collection, 'fpData')
ts = self.extract_timestamps(fp_data['raw'], **kwargs)
# Load channels and
channel_meta_map = self._channel_meta(kwargs.get('light_source_map'))
led_states = fp_data.get('channels', pd.DataFrame(NEUROPHOTOMETRICS_LED_STATES))
led_states = led_states.set_index('Condition')
# Extract signal columns into 2D array
regions = regions or [k for k in fp_data['raw'].keys() if 'Region' in k]
out_df = fp_data['raw'].filter(items=regions, axis=1).sort_index(axis=1)
out_df['times'] = ts
out_df['wavelength'] = np.NaN
out_df['name'] = ''
out_df['color'] = ''
# Extract channel index
states = fp_data['raw'].get('LedState', fp_data['raw'].get('Flags', None))
for state in states.unique():
ir, ic = np.where(led_states == state)
if ic.size == 0:
continue
for cn in ['name', 'color', 'wavelength']:
out_df.loc[states == state, cn] = channel_meta_map.iloc[ic[0]][cn]
return out_df