"""
Quality control of raw Neuropixel electrophysiology data.
"""
from pathlib import Path
import logging
import numpy as np
import pandas as pd
from scipy import signal, stats
import one.alf.io as alfio
from iblutil.util import Bunch
import spikeglx
import neuropixel
from ibldsp import fourier, utils, voltage
from tqdm import tqdm
from brainbox.io.spikeglx import Streamer
from brainbox.metrics.single_units import spike_sorting_metrics
from ibllib.ephys import sync_probes, spikes
from ibllib.qc import base
from ibllib.io.extractors import ephys_fpga
from phylib.io import model
_logger = logging.getLogger(__name__)
RMS_WIN_LENGTH_SECS = 3
WELCH_WIN_LENGTH_SAMPLES = 1024
NCH_WAVEFORMS = 32 # number of channels to be saved in templates.waveforms and channels.waveforms
BATCHES_SPACING = 300
TMIN = 40
SAMPLE_LENGTH = 1
SPIKE_THRESHOLD_UV = -50 # negative, the threshold used for spike detection on pre-processed raw data
class EphysQC(base.QC):
"""
A class for computing Ephys QC metrics.
:param probe_id: An existing and registered probe insertion ID.
:param one: An ONE instance pointing to the database the probe_id is registered with. Optional, will instantiate
default database if not given.
"""
def __init__(self, probe_id, session_path=None, **kwargs):
self.use_alyx = kwargs.pop('use_alyx', True)
self.stream = kwargs.pop('stream', True)
if self.use_alyx:
super().__init__(probe_id, endpoint='insertions', **kwargs)
self._outcome = 'NOT_SET'
self.pid = probe_id
self.session_path = session_path
keys = ('ap', 'ap_meta', 'lf', 'lf_meta')
self.data = Bunch.fromkeys(keys)
self.metrics = {}
def _ensure_required_data(self):
"""
Ensures the datasets required for QC are available locally or remotely.
"""
assert self.one is not None, 'ONE instance is required to ensure required data'
eid, pname = self.one.pid2eid(self.pid)
if self.session_path is None:
self.session_path = self.one.eid2path(eid)
self.probe_path = Path(self.session_path).joinpath('raw_ephys_data', pname)
# Check if there is at least one meta file available
meta_files = list(self.probe_path.rglob('*.meta'))
assert len(meta_files) != 0, f'No meta files in {self.probe_path}'
# Check if there is no more than one meta file per type
ap_meta = [meta for meta in meta_files if 'ap.meta' in meta.name]
assert not len(ap_meta) > 1, f'More than one ap.meta file in {self.probe_path}. Remove redundant files to run QC'
lf_meta = [meta for meta in meta_files if 'lf.meta' in meta.name]
assert not len(lf_meta) > 1, f'More than one lf.meta file in {self.probe_path}. Remove redundant files to run QC'
def load_data(self, ensure=True) -> None:
"""
Load any locally available data.
"""
# First sanity check
if self.use_alyx:
self._ensure_required_data()
_logger.info('Gathering data for QC')
# Load metadata and, if locally present, bin file
for dstype in ['ap', 'lf']:
# We already checked that there is not more than one meta file per type
meta_file = next(self.probe_path.rglob(f'*{dstype}.meta'), None)
if meta_file is None:
_logger.warning(f'No {dstype}.meta file in {self.probe_path}, skipping QC for {dstype} data.')
else:
self.data[f'{dstype}_meta'] = spikeglx.read_meta_data(meta_file)
bin_file = next(meta_file.parent.glob(f'*{dstype}.*bin'), None)
if not bin_file:
# we only stream the AP file, we won't stream the full LF file...
if dstype == 'ap':
self.data[f'{dstype}'] = Streamer(pid=self.pid, one=self.one, remove_cached=True)
else:
self.data[f'{dstype}'] = None
else:
self.data[f'{dstype}'] = spikeglx.Reader(bin_file, open=True)
@staticmethod
def _compute_metrics_array(raw, fs, h):
"""
From a numpy array, computes rms on raw data, destripes, computes rms on destriped data
and performs a simple spike detection
:param raw: voltage numpy.array(ntraces, nsamples)
:param fs: sampling frequency (Hz)
:param h: dictionary containing sensor coordinates, see neuropixel.trace_header
:return: 3 numpy vectors nchannels length
"""
destripe = voltage.destripe(raw, fs=fs, h=h)
rms_raw = utils.rms(raw)
rms_pre_proc = utils.rms(destripe)
detections = spikes.detection(data=destripe.T, fs=fs, h=h, detect_threshold=SPIKE_THRESHOLD_UV * 1e-6)
spike_rate = np.bincount(detections.trace, minlength=raw.shape[0]).astype(np.float32)
channel_labels, _ = voltage.detect_bad_channels(raw, fs=fs)
_, psd = signal.welch(destripe, fs=fs, window='hann', nperseg=WELCH_WIN_LENGTH_SAMPLES,
detrend='constant', return_onesided=True, scaling='density', axis=-1)
return rms_raw, rms_pre_proc, spike_rate, channel_labels, psd
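    # A minimal doctest-style sketch exercising _compute_metrics_array on synthetic data
    # (shapes only; random noise will not contain genuine spikes). The geometry header
    # comes from neuropixel.trace_header, as in `run` below:
    # >>> h = neuropixel.trace_header(version=1)
    # >>> raw = np.random.randn(384, 30000) * 1e-6  # (ntraces, nsamples), volts
    # >>> rms_raw, rms_proc, spike_rate, labels, psd = EphysQC._compute_metrics_array(raw, 30000, h)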
    def run(self, update: bool = False, overwrite: bool = True, stream: bool = None, **kwargs) -> list:
"""
Run QC on samples of the .ap file, and on the entire file for .lf data if it is present.
:param update: bool, whether to update the qc json fields for this probe. Default is False.
        :param overwrite: bool, whether to overwrite locally existing outputs of this function. Default is True.
        :param stream: bool, whether to stream the samples of the .ap data if not locally available. Defaults to the
         value set in the class init (True if none set).
        :return: A list of QC output files. In case of a complete run that is five files for .ap and three files for .lf.
"""
# If stream is explicitly given in run, overwrite value from init
if stream is not None:
self.stream = stream
# Load data
self.load_data()
self.out_path = kwargs.get('out_path', self.probe_path)
qc_files = []
# If ap meta file present, calculate median RMS per channel before and after destriping
        # NB: ideally this should go in a separate function once we have a spikeglx.Streamer that behaves like the Reader
if self.data.ap_meta:
files = {'rms': self.out_path.joinpath("_iblqc_ephysChannels.apRMS.npy"),
'spike_rate': self.out_path.joinpath("_iblqc_ephysChannels.rawSpikeRates.npy"),
'channel_labels': self.out_path.joinpath("_iblqc_ephysChannels.labels.npy"),
'ap_freqs': self.out_path.joinpath("_iblqc_ephysSpectralDensityAP.freqs.npy"),
'ap_power': self.out_path.joinpath("_iblqc_ephysSpectralDensityAP.power.npy"),
}
if all([files[k].exists() for k in files]) and not overwrite:
_logger.warning(f'RMS map already exists for .ap data in {self.probe_path}, skipping. '
f'Use overwrite option.')
results = {k: np.load(files[k]) for k in files}
else:
sr = self.data['ap']
nc = sr.nc - sr.nsync
# verify that the channel layout is correct according to IBL layout
th = sr.geometry
if sr.meta.get('NP2.4_shank', None) is not None:
h = neuropixel.trace_header(sr.major_version, nshank=4)
h = neuropixel.split_trace_header(h, shank=int(sr.meta.get('NP2.4_shank')))
else:
h = neuropixel.trace_header(sr.major_version, nshank=np.unique(th['shank']).size)
if not (np.all(h['x'] == th['x']) and np.all(h['y'] == th['y'])):
_logger.critical("Channel geometry seems incorrect")
# raise ValueError("Wrong Neuropixel channel mapping used - ABORT")
t0s = np.arange(TMIN, sr.rl - SAMPLE_LENGTH, BATCHES_SPACING)
all_rms = np.zeros((2, nc, t0s.shape[0]))
all_srs, channel_ok = (np.zeros((nc, t0s.shape[0])) for _ in range(2))
psds = np.zeros((nc, fourier.fscale(WELCH_WIN_LENGTH_SAMPLES, 1, one_sided=True).size))
_logger.info(f'Computing RMS samples for .ap data {self.probe_path}')
for i, t0 in enumerate(t0s):
sl = slice(int(t0 * sr.fs), int((t0 + SAMPLE_LENGTH) * sr.fs))
raw = sr[sl, :-sr.nsync].T
all_rms[0, :, i], all_rms[1, :, i], all_srs[:, i], channel_ok[:, i], psd =\
self._compute_metrics_array(raw, sr.fs, h)
psds += psd
# Calculate the median RMS across all samples per channel
results = {'rms': np.median(all_rms, axis=-1),
'spike_rate': np.median(all_srs, axis=-1),
'channel_labels': stats.mode(channel_ok, axis=1)[0],
'ap_freqs': fourier.fscale(WELCH_WIN_LENGTH_SAMPLES, 1 / sr.fs, one_sided=True),
'ap_power': psds.T / len(t0s), # shape: (nfreqs, nchannels)
}
for k in files:
np.save(files[k], results[k])
qc_files.extend([files[k] for k in files])
for p in [10, 90]:
self.metrics[f'apRms_p{p}_raw'] = np.format_float_scientific(
np.percentile(results['rms'][0, :], p), precision=2)
self.metrics[f'apRms_p{p}_proc'] = np.format_float_scientific(
np.percentile(results['rms'][1, :], p), precision=2)
if update:
self.update_extended_qc(self.metrics)
# If lf meta and bin file present, run the old qc on LF data
if self.data.lf_meta and self.data.lf:
qc_files.extend(extract_rmsmap(self.data.lf, out_folder=self.out_path, overwrite=overwrite))
return qc_files
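# A minimal usage sketch for EphysQC (the probe ID and paths below are hypothetical).
# With `use_alyx=False` and a local session, no database connection is needed, but
# `probe_path` must then be set by hand (since `_ensure_required_data` is skipped)
# and the raw .ap/.lf binaries are assumed to be available locally:
# >>> qc = EphysQC('00000000-0000-0000-0000-000000000000', use_alyx=False,
# ...              session_path='/data/Subject/2023-01-01/001')
# >>> qc.probe_path = Path(qc.session_path).joinpath('raw_ephys_data', 'probe00')
# >>> qc_files = qc.run(update=False, overwrite=True)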
def rmsmap(sglx, spectra=True, nmod=1):
"""
Computes RMS map in time domain and spectra for each channel of Neuropixel probe
:param sglx: Open spikeglx reader
:param spectra: Whether to compute the spectra
:param nmod: take every nmod windows, in cases where we don't want to compute over the whole signal
:return: a dictionary with amplitudes in channeltime space, channelfrequency space, time
and frequency scales
"""
rms_win_length_samples = 2 ** np.ceil(np.log2(sglx.fs * RMS_WIN_LENGTH_SECS))
    # the window generator yields the first and last sample indices of each window
wingen = utils.WindowGenerator(ns=sglx.ns, nswin=rms_win_length_samples, overlap=0)
nwin = np.ceil(wingen.nwin / nmod).astype(int)
# pre-allocate output dictionary of numpy arrays
win = {'TRMS': np.zeros((nwin, sglx.nc)),
'nsamples': np.zeros((nwin,)),
'fscale': fourier.fscale(WELCH_WIN_LENGTH_SAMPLES, 1 / sglx.fs, one_sided=True),
'tscale': wingen.tscale(fs=sglx.fs)[::nmod]}
win['spectral_density'] = np.zeros((len(win['fscale']), sglx.nc))
# loop through the whole session
with tqdm(total=wingen.nwin) as pbar:
for iwindow, (first, last) in enumerate(wingen.firstlast):
if np.mod(iwindow, nmod) != 0:
continue
D = sglx.read_samples(first_sample=first, last_sample=last)[0].transpose()
# remove low frequency noise below 1 Hz
D = fourier.hp(D, 1 / sglx.fs, [0, 1])
iw = np.floor(wingen.iw / nmod).astype(int)
win['TRMS'][iw, :] = utils.rms(D)
win['nsamples'][iw] = D.shape[1]
if spectra:
# the last window may be smaller than what is needed for welch
if last - first < WELCH_WIN_LENGTH_SAMPLES:
continue
# compute a smoothed spectrum using welch method
_, w = signal.welch(
D, fs=sglx.fs, window='hann', nperseg=WELCH_WIN_LENGTH_SAMPLES,
detrend='constant', return_onesided=True, scaling='density', axis=-1
)
win['spectral_density'] += w.T
            # update the progress bar at least every 20 windows
            if (iw % min(20, max(int(np.floor(wingen.nwin / 75)), 1))) == 0:
                pbar.update(iw - pbar.n)  # tqdm increments by n, so set the absolute progress
sglx.close()
return win
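# A minimal sketch of calling rmsmap directly on a local LF file (the path is
# hypothetical; the companion .lf.meta is assumed to sit next to the binary).
# Note that rmsmap closes the reader before returning:
# >>> sr = spikeglx.Reader('/data/raw/probe00.lf.bin', open=True)
# >>> win = rmsmap(sr, spectra=True, nmod=1)
# >>> win['TRMS'].shape  # (nwindows, nchannels)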
def raw_qc_session(session_path, overwrite=False):
"""
Wrapper that exectutes QC from a session folder and outputs the results whithin the same folder
as the original raw data.
:param session_path: path of the session (Subject/yyyy-mm-dd/number
:param overwrite: bool (False) Force means overwriting an existing QC file
:return: None
"""
efiles = spikeglx.glob_ephys_files(session_path)
qc_files = []
for efile in efiles:
if efile.get('ap') and efile.ap.exists():
qc_files.extend(extract_rmsmap(efile.ap, out_folder=None, overwrite=overwrite))
if efile.get('lf') and efile.lf.exists():
qc_files.extend(extract_rmsmap(efile.lf, out_folder=None, overwrite=overwrite))
return qc_files
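# Example sketch: run the raw QC over a whole session folder (path hypothetical):
# >>> qc_files = raw_qc_session('/data/Subject/2023-01-01/001', overwrite=True)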
def validate_ttl_test(ses_path, display=False):
"""
For a mock session on the Ephys Choice world task, check the sync channels for all
device properly connected and perform a synchronization if dual probes to check that
all channels are recorded properly
:param ses_path: session path
:param display: show the probe synchronization plot if several probes
:return: True if tests pass, errors otherwise
"""
def _single_test(assertion, str_ok, str_ko):
if assertion:
_logger.info(str_ok)
return True
else:
_logger.error(str_ko)
return False
EXPECTED_RATES_HZ = {'left_camera': 60, 'right_camera': 150, 'body_camera': 30}
SYNC_RATE_HZ = 1
MIN_TRIALS_NB = 6
ok = True
ses_path = Path(ses_path)
if not ses_path.exists():
return False
# get the synchronization fronts (from the raw binary if necessary)
ephys_fpga.extract_sync(session_path=ses_path, overwrite=False)
rawsync, sync_map = ephys_fpga.get_main_probe_sync(ses_path)
last_time = rawsync['times'][-1]
# get upgoing fronts for each
sync = Bunch({})
for k in sync_map:
fronts = ephys_fpga.get_sync_fronts(rawsync, sync_map[k])
sync[k] = fronts['times'][fronts['polarities'] == 1]
wheel = ephys_fpga.extract_wheel_sync(rawsync, chmap=sync_map)
frame_rates = {'right_camera': np.round(1 / np.median(np.diff(sync.right_camera))),
'left_camera': np.round(1 / np.median(np.diff(sync.left_camera))),
'body_camera': np.round(1 / np.median(np.diff(sync.body_camera)))}
# check the camera frame rates
for lab in frame_rates:
expect = EXPECTED_RATES_HZ[lab]
ok &= _single_test(assertion=abs((1 - frame_rates[lab] / expect)) < 0.1,
str_ok=f'PASS: {lab} frame rate: {frame_rates[lab]} = {expect} Hz',
str_ko=f'FAILED: {lab} frame rate: {frame_rates[lab]} != {expect} Hz')
# check that the wheel has a minimum rate of activity on both channels
re_test = abs(1 - sync.rotary_encoder_1.size / sync.rotary_encoder_0.size) < 0.1
re_test &= len(wheel[1]) / last_time > 5
ok &= _single_test(assertion=re_test,
str_ok="PASS: Rotary encoder", str_ko="FAILED: Rotary encoder")
    # check that the frame2ttl TTLs have a minimum rate of activity
ok &= _single_test(assertion=len(sync.frame2ttl) / last_time > 0.2,
str_ok="PASS: Frame2TTL", str_ko="FAILED: Frame2TTL")
# the audio has to have at least one event per trial
ok &= _single_test(assertion=len(sync.bpod) > len(sync.audio) > MIN_TRIALS_NB,
str_ok="PASS: audio", str_ko="FAILED: audio")
# the bpod has to have at least twice the amount of min trial pulses
ok &= _single_test(assertion=len(sync.bpod) > MIN_TRIALS_NB * 2,
str_ok="PASS: Bpod", str_ko="FAILED: Bpod")
try:
# note: tried to depend as little as possible on the extraction code but for the valve...
extractor = ephys_fpga.FpgaTrials(ses_path)
_, bpod_intervals = extractor.get_bpod_event_times(rawsync, sync_map)
t_valve_open = bpod_intervals['valve_open'][:, 0]
res = t_valve_open.size > 1
except AssertionError:
res = False
    # check that the reward valve is actuated at least once
ok &= _single_test(assertion=res,
str_ok="PASS: Valve open", str_ko="FAILED: Valve open not detected")
_logger.info('ALL CHECKS PASSED !')
# the imec sync is for 3B Probes only
if sync.get('imec_sync') is not None:
ok &= _single_test(assertion=np.all(1 - SYNC_RATE_HZ * np.diff(sync.imec_sync) < 0.1),
str_ok="PASS: imec sync", str_ko="FAILED: imec sync")
    # second step is to test that we can perform the sync. Assertions are within the sync code
if sync.get('imec_sync') is not None:
sync_result, _ = sync_probes.version3B(ses_path, display=display)
else:
sync_result, _ = sync_probes.version3A(ses_path, display=display)
ok &= _single_test(assertion=sync_result, str_ok="PASS: synchronisation",
str_ko="FAILED: probe synchronizations threshold exceeded")
if not ok:
raise ValueError('FAILED TTL test')
return ok
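# Example sketch: validate the TTLs of a mock session (path hypothetical); the
# function logs each check and raises ValueError if any of them fails:
# >>> ok = validate_ttl_test('/data/Subject/2023-01-01/001', display=False)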
def spike_sorting_metrics_ks2(ks2_path=None, m=None, save=True, save_path=None):
"""
Given a path containing kilosort 2 output, compute quality metrics and optionally save them
to a clusters_metric.csv file
:param ks2_path:
:param save
:param save_path: If not given will save into the path given as ks2_path
:return:
"""
save_path = save_path or ks2_path
# ensure that either a ks2_path or a phylib `TemplateModel` object with unit info is given
assert not (ks2_path is None and m is None), 'Must either specify a path to a ks2 output ' \
'directory, or a phylib `TemplateModel` object'
    # create phylib `TemplateModel` if not given; this assumes the raw binary, if present, sits in the ks2 folder
    m = phy_model_from_ks2_path(ks2_path, bin_path=ks2_path) if m is None else m
c, drift = spike_sorting_metrics(m.spike_times, m.spike_clusters, m.amplitudes, m.depths,
cluster_ids=np.arange(m.clusters_channels.size))
    # include the ks2 cluster contamination if a `cluster_ContamPct` file exists
    file_contamination = ks2_path.joinpath('cluster_ContamPct.tsv') if ks2_path else None
    if file_contamination and file_contamination.exists():
contam = pd.read_csv(file_contamination, sep='\t')
contam.rename(columns={'ContamPct': 'ks2_contamination_pct'}, inplace=True)
c = c.set_index('cluster_id', drop=False).join(contam.set_index('cluster_id'))
    # include the ks2 cluster labels if a `cluster_KSLabel` file exists
    file_labels = ks2_path.joinpath('cluster_KSLabel.tsv') if ks2_path else None
    if file_labels and file_labels.exists():
ks2_labels = pd.read_csv(file_labels, sep='\t')
ks2_labels.rename(columns={'KSLabel': 'ks2_label'}, inplace=True)
c = c.set_index('cluster_id', drop=False).join(ks2_labels.set_index('cluster_id'))
if save:
Path(save_path).mkdir(exist_ok=True, parents=True)
# the file name contains the label of the probe (directory name in this case)
c.to_csv(Path(save_path).joinpath('cluster_metrics.csv'))
return c
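# Example sketch: compute and save cluster metrics from a kilosort 2 output
# directory (path hypothetical); returns a pandas DataFrame indexed by cluster:
# >>> df = spike_sorting_metrics_ks2(Path('/data/spike_sorters/ks2/probe00'), save=True)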
def phy_model_from_ks2_path(ks2_path, bin_path, bin_file=None):
    """
    Create a phylib `TemplateModel` from a kilosort 2 output directory.
    :param ks2_path: path to the kilosort 2 output directory
    :param bin_path: directory in which to look for the raw .ap binary and meta files
    :param bin_file: optional explicit path to the raw binary
    :return: phylib `TemplateModel` with a `depths` attribute added
    """
if not bin_file:
bin_file = next(bin_path.rglob('*.ap.*bin'), None)
meta_file = next(bin_path.rglob('*.ap.meta'), None)
if meta_file and meta_file.exists():
meta = spikeglx.read_meta_data(meta_file)
fs = spikeglx._get_fs_from_meta(meta)
nch = (spikeglx._get_nchannels_from_meta(meta) -
len(spikeglx._get_sync_trace_indices_from_meta(meta)))
else:
fs = 30000
nch = 384
m = model.TemplateModel(dir_path=ks2_path,
dat_path=bin_file, # this assumes the raw data is in the same folder
sample_rate=fs,
n_channels_dat=nch,
n_closest_channels=NCH_WAVEFORMS)
m.depths = m.get_depths()
return m
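# Example sketch: build a phylib TemplateModel explicitly when the raw binary
# lives in the raw ephys folder rather than next to the sorter output (paths
# hypothetical), then pass it to spike_sorting_metrics_ks2:
# >>> m = phy_model_from_ks2_path(Path('/data/ks2'), bin_path=Path('/data/raw_ephys_data/probe00'))
# >>> df = spike_sorting_metrics_ks2(ks2_path=Path('/data/ks2'), m=m)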
# Make a bunch gathering all trial QC
def qc_fpga_task(fpga_trials, alf_trials):
"""
:fpga_task is the dictionary output of
ibllib.io.extractors.ephys_fpga.extract_behaviour_sync
: bpod_trials is the dictionary output of ibllib.io.extractors.ephys_trials.extract_all
: alf_trials is the ALF _ibl_trials object after extraction (alfio.load_object)
:return: qc_session, qc_trials, True means QC passes while False indicates a failure
"""
GOCUE_STIMON_DELAY = 0.01 # -> 0.1
FEEDBACK_STIMFREEZE_DELAY = 0.01 # -> 0.1
VALVE_STIM_OFF_DELAY = 1
VALVE_STIM_OFF_JITTER = 0.1
ITI_IN_STIM_OFF_JITTER = 0.1
ERROR_STIM_OFF_DELAY = 2
ERROR_STIM_OFF_JITTER = 0.1
RESPONSE_FEEDBACK_DELAY = 0.0005
def strictly_after(t0, t1, threshold):
""" returns isafter, iswithinthreshold"""
return (t1 - t0) > 0, np.abs((t1 - t0)) <= threshold
ntrials = fpga_trials['stimOn_times'].size
qc_trials = Bunch({})
"""
First Check consistency of the dataset: whithin each trial, all events happen after trial
start should not be NaNs and increasing. This is not a QC but an assertion.
"""
status = True
    for k in ['response_times', 'stimOn_times', 'goCueTrigger_times',
              'goCue_times', 'feedback_times']:
if k.endswith('_bpod'):
tstart = alf_trials['intervals_bpod'][:, 0]
else:
tstart = alf_trials['intervals'][:, 0]
selection = ~np.isnan(alf_trials[k])
status &= np.all(alf_trials[k][selection] - tstart[selection] > 0)
status &= np.all(np.diff(alf_trials[k][selection]) > 0)
assert status
"""
This part of the function uses only fpga_trials information
"""
# check number of feedbacks: should always be one
qc_trials['n_feedback'] = (np.uint32(~np.isnan(fpga_trials['valveOpen_times'])) +
np.uint32(~np.isnan(fpga_trials['errorCue_times'])))
# check for non-Nans
qc_trials['stimOn_times_nan'] = ~np.isnan(fpga_trials['stimOn_times'])
qc_trials['goCue_times_nan'] = ~np.isnan(fpga_trials['goCue_times'])
# stimOn before goCue
qc_trials['stimOn_times_before_goCue_times'], qc_trials['stimOn_times_goCue_times_delay'] =\
strictly_after(fpga_trials['stimOn_times'], fpga_trials['goCue_times'], GOCUE_STIMON_DELAY)
# stimFreeze before feedback
qc_trials['stim_freeze_before_feedback'], qc_trials['stim_freeze_feedback_delay'] = \
strictly_after(fpga_trials['stimFreeze_times'], fpga_trials['feedback_times'],
FEEDBACK_STIMFREEZE_DELAY)
# stimOff 1 sec after valve, with 0.1 as acceptable jitter
qc_trials['stimOff_delay_valve'] = np.less(
np.abs(
fpga_trials['stimOff_times'] - fpga_trials['valveOpen_times'] - VALVE_STIM_OFF_DELAY
),
VALVE_STIM_OFF_JITTER, out=np.ones(ntrials, dtype=bool),
where=~np.isnan(fpga_trials['valveOpen_times']))
    # itiIn within 0.1 sec of stimOff
qc_trials['iti_in_delay_stim_off'] = \
np.abs(fpga_trials['stimOff_times'] - fpga_trials['itiIn_times']) < ITI_IN_STIM_OFF_JITTER
    # stimOff 2 secs after errorCue_times with jitter
    # noise off happens 2 secs after stim, with 0.1 as acceptable jitter
qc_trials['stimOff_delay_noise'] = np.less(
np.abs(
fpga_trials['stimOff_times'] - fpga_trials['errorCue_times'] - ERROR_STIM_OFF_DELAY
),
ERROR_STIM_OFF_JITTER, out=np.ones(ntrials, dtype=bool),
where=~np.isnan(fpga_trials['errorCue_times']))
"""
This part uses only alf_trials information
"""
# TEST Response times (from session start) should be increasing continuously
# Note: RT are not durations but time stamps from session start
# 1. check for non-Nans
qc_trials['response_times_nan'] = ~np.isnan(alf_trials['response_times'])
# 2. check for positive increase
qc_trials['response_times_increase'] = \
np.diff(np.append([0], alf_trials['response_times'])) > 0
# TEST Response times (from goCue) should be positive
qc_trials['response_times_goCue_times_diff'] = \
alf_trials['response_times'] - alf_trials['goCue_times'] > 0
# TEST 1. Response_times should be before feedback
qc_trials['response_before_feedback'] = \
alf_trials['feedback_times'] - alf_trials['response_times'] > 0
    # 2. Delay between the wheel reaching threshold (response time) and
    # feedback is 100 us, with 500 us of acceptable jitter
qc_trials['response_feedback_delay'] = \
alf_trials['feedback_times'] - alf_trials['response_times'] < RESPONSE_FEEDBACK_DELAY
# Test output at session level
qc_session = {k: np.all(qc_trials[k]) for k in qc_trials}
return qc_session, qc_trials
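# Example sketch: run the trial-level QC after extraction (session path is
# hypothetical). Per the docstring above, `fpga_trials` is the output of
# ibllib.io.extractors.ephys_fpga.extract_behaviour_sync; `alf_trials` can be
# loaded with alfio (imported at the top of this module):
# >>> alf_trials = alfio.load_object('/data/Subject/2023-01-01/001/alf', 'trials')
# >>> qc_session, qc_trials = qc_fpga_task(fpga_trials, alf_trials)
# >>> all(qc_session.values())  # True when every check passes at session level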