Source code for ibllib.dsp.voltage

"""
Module to work with raw voltage traces. Spike sorting pre-processing functions.
"""
from pathlib import Path

import numpy as np
import scipy.signal
import scipy.stats
from joblib import Parallel, delayed, cpu_count

from ibllib.io import spikeglx
import ibllib.dsp.fourier as fdsp
from ibllib.dsp import fshift, rms
from ibllib.ephys import neuropixel


def agc(x, wl=.5, si=.002, epsilon=1e-8):
    """
    Automatic gain control
    w_agc, gain = agc(w, wl=.5, si=.002, epsilon=1e-8)
    such that w_agc / gain = w
    :param x: seismic array (sample last dimension)
    :param wl: window length (secs)
    :param si: sampling interval (secs)
    :param epsilon: whitening (useful mainly for synthetic data)
    :return: AGC data array, gain applied to data
    """
    ns_win = np.round(wl / si / 2) * 2 + 1
    w = np.hanning(ns_win)
    w /= np.sum(w)
    gain = fdsp.convolve(np.abs(x), w, mode='same')
    gain += (np.sum(gain, axis=1) * epsilon / x.shape[-1])[:, np.newaxis]
    gain = 1 / gain
    return x * gain, gain

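# Usage sketch (illustrative, not part of the original module): the AGC output divided by the
# returned gain recovers the input, as stated in the docstring above. `w` is an assumed
# synthetic (ntraces, ns) array sampled at 500 Hz (si=0.002).
#   w = np.random.randn(32, 2000)
#   w_agc, gain = agc(w, wl=0.5, si=0.002)
#   assert np.allclose(w_agc / gain, w)
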
def fk(x, si=.002, dx=1, vbounds=None, btype='highpass', ntr_pad=0, ntr_tap=None, lagc=.5,
       collection=None, kfilt=None):
    """Frequency-wavenumber filter: filters apparent plane-wave velocities
    :param x: the input array to be filtered, 2D array (ntraces, ns). The filtering considers
     axis=0 the spatial dimension and axis=1 the temporal dimension
    :param si: sampling interval (secs)
    :param dx: spatial interval (usually meters)
    :param vbounds: velocity high pass [v1, v2], cosine taper from 0 to 1 between v1 and v2
    :param btype: {'lowpass', 'highpass'}, velocity filter: defaults to highpass
    :param ntr_pad: padding will add ntr_pad mirrored traces to each side
    :param ntr_tap: taper (if None, set to ntr_pad)
    :param lagc: length of agc in seconds. If set to None or 0, no agc
    :param kfilt: optional (None) if a k-filter is applied, parameters as dict (bounds are in
     m-1 according to the dx parameter) kfilt = {'bounds': [0.05, 0.1], 'btype': 'highpass'}
    :param collection: vector length ntraces. Each unique value set of traces is a collection
     on which the FK filter will run separately (shot gathers, receiver gathers)
    :return:
    """
    if collection is not None:
        xout = np.zeros_like(x)
        for c in np.unique(collection):
            sel = collection == c
            xout[sel, :] = fk(x[sel, :], si=si, dx=dx, vbounds=vbounds, ntr_pad=ntr_pad,
                              ntr_tap=ntr_tap, lagc=lagc, collection=None)
        return xout
    assert vbounds
    nx, nt = x.shape

    # lateral padding left and right
    ntr_pad = int(ntr_pad)
    ntr_tap = ntr_pad if ntr_tap is None else ntr_tap
    nxp = nx + ntr_pad * 2

    # compute frequency wavenumber scales and deduce the velocity filter
    fscale = fdsp.fscale(nt, si)
    kscale = fdsp.fscale(nxp, dx)
    kscale[0] = 1e-6
    v = fscale[np.newaxis, :] / kscale[:, np.newaxis]
    if btype.lower() in ['highpass', 'hp']:
        fk_att = fdsp.fcn_cosine(vbounds)(np.abs(v))
    elif btype.lower() in ['lowpass', 'lp']:
        fk_att = (1 - fdsp.fcn_cosine(vbounds)(np.abs(v)))

    # if a k-filter is also provided, apply it
    if kfilt is not None:
        katt = fdsp._freq_vector(np.abs(kscale), kfilt['bounds'], typ=kfilt['btype'])
        fk_att *= katt[:, np.newaxis]

    # import matplotlib.pyplot as plt
    # plt.imshow(np.fft.fftshift(np.abs(v), axes=0).T, aspect='auto', vmin=0, vmax=1e5,
    #            extent=[np.min(kscale), np.max(kscale), 0, np.max(fscale) * 2])
    # plt.imshow(np.fft.fftshift(np.abs(fk_att), axes=0).T, aspect='auto', vmin=0, vmax=1,
    #            extent=[np.min(kscale), np.max(kscale), 0, np.max(fscale) * 2])

    # apply the attenuation in fk-domain
    if not lagc:
        xf = np.copy(x)
        gain = 1
    else:
        xf, gain = agc(x, wl=lagc, si=si)
    if ntr_pad > 0:
        # pad the array with a mirrored version of itself and apply a cosine taper
        xf = np.r_[np.flipud(xf[:ntr_pad]), xf, np.flipud(xf[-ntr_pad:])]
    if ntr_tap > 0:
        taper = fdsp.fcn_cosine([0, ntr_tap])(np.arange(nxp))  # taper up
        taper *= 1 - fdsp.fcn_cosine([nxp - ntr_tap, nxp])(np.arange(nxp))  # taper down
        xf = xf * taper[:, np.newaxis]
    xf = np.real(np.fft.ifft2(fk_att * np.fft.fft2(xf)))

    if ntr_pad > 0:
        xf = xf[ntr_pad:-ntr_pad, :]
    return xf / gain

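# Usage sketch (illustrative, assumed parameter values): velocity high-pass an (ntraces, ns)
# gather sampled at 500 Hz with traces spaced 10 m apart, attenuating apparent velocities below
# the cosine taper between 2000 and 3000 m/s.
#   data = np.random.randn(64, 1000)
#   data_fk = fk(data, si=0.002, dx=10, vbounds=[2000, 3000], ntr_pad=10, ntr_tap=10, lagc=0.5)
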
def car(x, collection=None, lagc=300, butter_kwargs=None, **kwargs):
    """
    Applies common average referencing with optional automatic gain control
    :param x: the input array to be filtered, 2D array (ntraces, ns). The filtering considers
     axis=0 the spatial dimension and axis=1 the temporal dimension
    :param collection:
    :param lagc: window size for time domain automatic gain control (no agc otherwise)
    :param butter_kwargs: filtering parameters: defaults: {'N': 3, 'Wn': 0.1, 'btype': 'highpass'}
    :return:
    """
    if butter_kwargs is None:
        butter_kwargs = {'N': 3, 'Wn': 0.1, 'btype': 'highpass'}
    if collection is not None:
        xout = np.zeros_like(x)
        for c in np.unique(collection):
            sel = collection == c
            xout[sel, :] = kfilt(x=x[sel, :], ntr_pad=0, ntr_tap=None, collection=None,
                                 butter_kwargs=butter_kwargs)
        return xout

    # apply agc and keep the gain in handy
    if not lagc:
        xf = np.copy(x)
        gain = 1
    else:
        xf, gain = agc(x, wl=lagc, si=1.0)
    # apply CAR and then un-apply the gain
    xf = xf - np.median(xf, axis=0)
    return xf / gain

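# Usage sketch (illustrative): common average referencing of a (nc, ns) chunk. Since agc is
# called with si=1.0, lagc is expressed in samples here (default 300).
#   chunk = np.random.randn(384, 30000)
#   chunk_car = car(chunk, lagc=300)
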
def kfilt(x, collection=None, ntr_pad=0, ntr_tap=None, lagc=300, butter_kwargs=None):
    """
    Applies a butterworth filter on the 0-axis with tapering / padding
    :param x: the input array to be filtered, 2D array (ntraces, ns). The filtering considers
     axis=0 the spatial dimension and axis=1 the temporal dimension
    :param collection:
    :param ntr_pad: traces added to each side (mirrored)
    :param ntr_tap: n traces for apodization on each side
    :param lagc: window size for time domain automatic gain control (no agc otherwise)
    :param butter_kwargs: filtering parameters: defaults: {'N': 3, 'Wn': 0.1, 'btype': 'highpass'}
    :return:
    """
    if butter_kwargs is None:
        butter_kwargs = {'N': 3, 'Wn': 0.1, 'btype': 'highpass'}
    if collection is not None:
        xout = np.zeros_like(x)
        for c in np.unique(collection):
            sel = collection == c
            xout[sel, :] = kfilt(x=x[sel, :], ntr_pad=0, ntr_tap=None, collection=None,
                                 butter_kwargs=butter_kwargs)
        return xout
    nx, nt = x.shape

    # lateral padding left and right
    ntr_pad = int(ntr_pad)
    ntr_tap = ntr_pad if ntr_tap is None else ntr_tap
    nxp = nx + ntr_pad * 2

    # apply agc and keep the gain in handy
    if not lagc:
        xf = np.copy(x)
        gain = 1
    else:
        xf, gain = agc(x, wl=lagc, si=1.0)
    if ntr_pad > 0:
        # pad the array with a mirrored version of itself and apply a cosine taper
        xf = np.r_[np.flipud(xf[:ntr_pad]), xf, np.flipud(xf[-ntr_pad:])]
    if ntr_tap > 0:
        taper = fdsp.fcn_cosine([0, ntr_tap])(np.arange(nxp))  # taper up
        taper *= 1 - fdsp.fcn_cosine([nxp - ntr_tap, nxp])(np.arange(nxp))  # taper down
        xf = xf * taper[:, np.newaxis]
    sos = scipy.signal.butter(**butter_kwargs, output='sos')
    xf = scipy.signal.sosfiltfilt(sos, xf, axis=0)

    if ntr_pad > 0:
        xf = xf[ntr_pad:-ntr_pad, :]
    return xf / gain

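# Usage sketch (illustrative, parameters borrowed from the destripe defaults further down):
# spatial high-pass across channels with mirrored padding of 60 traces.
#   chunk = np.random.randn(384, 30000)
#   chunk_k = kfilt(chunk, ntr_pad=60, ntr_tap=0, lagc=3000,
#                   butter_kwargs={'N': 3, 'Wn': 0.01, 'btype': 'highpass'})
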
def interpolate_bad_channels(data, channel_labels=None, h=None, p=1.3, kriging_distance_um=20):
    """
    Interpolate the channels labeled as bad using a weighted sum of neighbouring channels.
    The weights applied to neighbouring channels come from an exponential decay function
    :param data: (nc, ns) np.ndarray
    :param channel_labels: (nc) np.ndarray: 0: channel is good, 1: dead, 2: noisy, 3: out of the brain
    :param h: dict with fields 'x' and 'y', np.ndarrays
    :param p:
    :param kriging_distance_um:
    :return:
    """
    # from ibllib.plots.figures import ephys_bad_channels
    # ephys_bad_channels(x, 30000, channel_labels[0], channel_labels[1])
    x = h['x']
    y = h['y']
    # we interpolate only noisy or dead channels (0: good); out of the brain channels are left untouched
    bad_channels = np.where(np.logical_or(channel_labels == 1, channel_labels == 2))[0]
    for i in bad_channels:
        # compute the weights to apply to neighbouring traces
        offset = np.abs(x - x[i] + 1j * (y - y[i]))
        weights = np.exp(-(offset / kriging_distance_um) ** p)
        weights[bad_channels] = 0
        weights[weights < 0.005] = 0
        weights = weights / np.sum(weights)
        imult = np.where(weights > 0.005)[0]
        if imult.size == 0:
            data[i, :] = 0
            continue
        data[i, :] = np.matmul(weights[imult], data[imult, :])
    # from viewephys.gui import viewephys
    # f = viewephys(data.T, fs=1/30, h=h, title='interp2')
    return data

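# Usage sketch (illustrative): `h` is a trace header with 'x' and 'y' site coordinates, e.g.
# from neuropixel.trace_header(), and the labels come from detect_bad_channels defined below.
#   data = np.random.randn(384, 30000)
#   h = neuropixel.trace_header(version=1)
#   labels, _ = detect_bad_channels(data, fs=30000)
#   data = interpolate_bad_channels(data, labels, h=h)
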
def _get_destripe_parameters(fs, butter_kwargs, k_kwargs, k_filter):
    """Gets the default parameters for destripe. This is used both for the destripe function
    operating on a numpy array and for the function operating on a cbin file."""
    if butter_kwargs is None:
        butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
    if k_kwargs is None:
        k_kwargs = {'ntr_pad': 60, 'ntr_tap': 0, 'lagc': 3000,
                    'butter_kwargs': {'N': 3, 'Wn': 0.01, 'btype': 'highpass'}}
    if k_filter:
        spatial_fcn = lambda dat: kfilt(dat, **k_kwargs)  # noqa
    else:
        spatial_fcn = lambda dat: car(dat, **k_kwargs)  # noqa
    return butter_kwargs, k_kwargs, spatial_fcn

def destripe(x, fs, neuropixel_version=1, butter_kwargs=None, k_kwargs=None, channel_labels=None,
             k_filter=True):
    """Super Car (super slow also...) - far from being set in stone but a good workflow example
    :param x: demultiplexed array (nc, ns)
    :param fs: sampling frequency
    :param neuropixel_version (optional): 1 or 2. Useful for the ADC shift correction.
     If None, no correction is applied
    :param channel_labels:
     None: (default) keep all channels
     OR (recommended to pre-compute) a (nc) vector of channel labels as returned by detect_bad_channels.
      On a full workflow, one should scan the full file sparsely to get a robust estimate of the labels;
      estimating from only the current batch is provided for convenience but should be avoided in production
     OR True: (only for quick display or as an example) deduces the bad channels from the data provided
    :param butter_kwargs: (optional, None) butterworth params, see the code for the defaults dict
    :param k_kwargs: (optional, None) K-filter params, see the code for the defaults dict
    :param k_filter: (True) applies the k-filter by default; if False, applies CAR
     (the median across channels is subtracted)
    :return: x, filtered array
    """
    butter_kwargs, k_kwargs, spatial_fcn = _get_destripe_parameters(fs, butter_kwargs, k_kwargs, k_filter)
    h = neuropixel.trace_header(version=neuropixel_version)
    if channel_labels is True:
        channel_labels, _ = detect_bad_channels(x, fs)
    # butterworth
    sos = scipy.signal.butter(**butter_kwargs, output='sos')
    x = scipy.signal.sosfiltfilt(sos, x)
    # apply ADC shift
    if neuropixel_version is not None:
        x = fshift(x, h['sample_shift'], axis=1)
    # channel interpolation, then apply the spatial filter only on channels that are inside of the brain
    if channel_labels is not None:
        x = interpolate_bad_channels(x, channel_labels, h)
        inside_brain = np.where(channel_labels != 3)[0]
        x[inside_brain, :] = spatial_fcn(x[inside_brain, :])  # apply the k-filter
    else:
        x = spatial_fcn(x)
    return x

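# Usage sketch (illustrative, assumed file path): destripe a (nc, ns) AP-band chunk read from a
# spikeglx file, pre-computing the bad channel labels over the whole file as recommended above.
#   sr = spikeglx.Reader('probe00.ap.cbin')
#   labels = detect_bad_channels_cbin(sr)
#   nc = sr.nc - sr.nsync
#   raw = sr[slice(0, 30000), :nc].T
#   destriped = destripe(raw, fs=sr.fs, channel_labels=labels)
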
def destripe_lfp(x, fs, channel_labels=None, **kwargs):
    """
    Wrapper around the destripe function with some default parameters to destripe the LFP band
    See the destripe function for documentation
    :param x:
    :param fs:
    :return:
    """
    kwargs['butter_kwargs'] = {'N': 3, 'Wn': 2 / fs * 2, 'btype': 'highpass'}
    kwargs['k_filter'] = False
    if channel_labels is True:
        kwargs['channel_labels'], _ = detect_bad_channels(x, fs=fs, psd_hf_threshold=1.4)
    return destripe(x, fs, **kwargs)

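# Usage sketch (illustrative): same call pattern as destripe, with a 2 Hz high-pass and CAR,
# applied to an assumed (nc, ns) LFP-band array sampled at 2500 Hz.
#   lfp = np.random.randn(384, 2500)
#   lfp_destriped = destripe_lfp(lfp, fs=2500, channel_labels=True)
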
def decompress_destripe_cbin(sr_file, output_file=None, h=None, wrot=None, append=False,
                             nc_out=None, butter_kwargs=None, dtype=np.int16, ns2add=0,
                             nbatch=None, nprocesses=None, compute_rms=True, reject_channels=True,
                             k_kwargs=None, k_filter=True, reader_kwargs=None):
    """
    From a spikeglx file, decompresses and applies the ADC shift correction.
    Saves output as a flat binary file in int16.
    Production version with optimized FFTs - requires pyfftw.
    :param sr_file: path to the compressed or flat binary spikeglx file
    :param output_file: (optional, defaults to .bin extension of the compressed bin file)
    :param h: (optional) neuropixel trace header. Dictionary with key 'sample_shift'
    :param wrot: (optional) whitening matrix [nc x nc] or amplitude scalar to apply to the output
    :param append: (optional, False) for chronic recordings, append to end of file
    :param nc_out: (optional, True) saves non selected channels (synchronisation trace) in output
    :param butter_kwargs: (optional) butterworth filter parameters, defaults to
     {'N': 3, 'Wn': 300 / sr.fs * 2, 'btype': 'highpass'}
    :param dtype: (optional, np.int16) output sample format
    :param ns2add: (optional) for kilosort, adds padding samples at the end of the file so the
     total number of samples is a multiple of the batchsize
    :param nbatch: (optional) batch size
    :param nprocesses: (optional) number of parallel processes to run, defaults to the number of
     processes detected with joblib
    :param reject_channels: (True) detects noisy or bad channels and interpolates them.
     Channels outside of the brain are left untouched
    :param k_kwargs: (None) arguments for the kfilt function
    :param reader_kwargs: (None) optional arguments for the spikeglx Reader instance
    :param k_filter: (True) performs a k-filter - if False will do median common average referencing
    :return:
    """
    import pyfftw

    SAMPLES_TAPER = 1024
    NBATCH = nbatch or 65536
    # handles input parameters
    reader_kwargs = {} if reader_kwargs is None else reader_kwargs
    sr = spikeglx.Reader(sr_file, open=True, **reader_kwargs)
    if reject_channels:  # get bad channels if option is on
        channel_labels = detect_bad_channels_cbin(sr)
    assert isinstance(sr_file, str) or isinstance(sr_file, Path)
    butter_kwargs, k_kwargs, spatial_fcn = _get_destripe_parameters(sr.fs, butter_kwargs, k_kwargs, k_filter)
    h = sr.geometry if h is None else h
    ncv = h['sample_shift'].size  # number of channels
    output_file = sr.file_bin.with_suffix('.bin') if output_file is None else Path(output_file)
    assert output_file != sr.file_bin
    taper = np.r_[0, scipy.signal.windows.cosine((SAMPLES_TAPER - 1) * 2), 0]
    # create the FFT stencils
    nc_out = nc_out or sr.nc
    # compute LP filter coefficients
    sos = scipy.signal.butter(**butter_kwargs, output='sos')
    nbytes = dtype(1).nbytes
    nprocesses = nprocesses or int(cpu_count() - cpu_count() / 4)
    win = pyfftw.empty_aligned((ncv, NBATCH), dtype='float32')
    WIN = pyfftw.empty_aligned((ncv, int(NBATCH / 2 + 1)), dtype='complex64')
    fft_object = pyfftw.FFTW(win, WIN, axes=(1,), direction='FFTW_FORWARD', threads=4)
    dephas = np.zeros((ncv, NBATCH), dtype=np.float32)
    dephas[:, 1] = 1.
    DEPHAS = np.exp(1j * np.angle(fft_object(dephas)) * h['sample_shift'][:, np.newaxis])

    # if we want to compute the rms ap across the session
    if compute_rms:
        ap_rms_file = output_file.parent.joinpath('ap_rms.bin')
        ap_time_file = output_file.parent.joinpath('ap_time.bin')
        rms_nbytes = np.float32(1).nbytes
        if append:
            rms_offset = Path(ap_rms_file).stat().st_size
            time_offset = Path(ap_time_file).stat().st_size
            with open(ap_time_file, 'rb') as tid:
                t = tid.read()
            time_data = np.frombuffer(t, dtype=np.float32)
            t0 = time_data[-1]
        else:
            rms_offset = 0
            time_offset = 0
            t0 = 0
            open(ap_rms_file, 'wb').close()
            open(ap_time_file, 'wb').close()

    if append:
        # need to find the end of the file and the offset
        offset = Path(output_file).stat().st_size
    else:
        offset = 0
        open(output_file, 'wb').close()

    # chunks to split the file into, dependent on number of parallel processes
    CHUNK_SIZE = int(sr.ns / nprocesses)

    def my_function(i_chunk, n_chunk):
        _sr = spikeglx.Reader(sr_file, **reader_kwargs)
        n_batch = int(np.ceil(i_chunk * CHUNK_SIZE / NBATCH))
        first_s = (NBATCH - SAMPLES_TAPER * 2) * n_batch
        # Find the maximum sample for each chunk
        max_s = _sr.ns if i_chunk == n_chunk - 1 else (i_chunk + 1) * CHUNK_SIZE
        # need to redefine this here to avoid 4 byte boundary error
        win = pyfftw.empty_aligned((ncv, NBATCH), dtype='float32')
        WIN = pyfftw.empty_aligned((ncv, int(NBATCH / 2 + 1)), dtype='complex64')
        fft_object = pyfftw.FFTW(win, WIN, axes=(1,), direction='FFTW_FORWARD', threads=4)
        ifft_object = pyfftw.FFTW(WIN, win, axes=(1,), direction='FFTW_BACKWARD', threads=4)
        fid = open(output_file, 'r+b')
        if i_chunk == 0:
            fid.seek(offset)
        else:
            fid.seek(offset + ((first_s + SAMPLES_TAPER) * nc_out * nbytes))
        if compute_rms:
            aid = open(ap_rms_file, 'r+b')
            tid = open(ap_time_file, 'r+b')
            if i_chunk == 0:
                aid.seek(rms_offset)
                tid.seek(time_offset)
            else:
                aid.seek(rms_offset + (n_batch * ncv * rms_nbytes))
                tid.seek(time_offset + (n_batch * rms_nbytes))

        while True:
            last_s = np.minimum(NBATCH + first_s, _sr.ns)
            # Apply tapers
            chunk = _sr[first_s:last_s, :ncv].T
            chunk[:, :SAMPLES_TAPER] *= taper[:SAMPLES_TAPER]
            chunk[:, -SAMPLES_TAPER:] *= taper[SAMPLES_TAPER:]
            # Apply filters
            chunk = scipy.signal.sosfiltfilt(sos, chunk)
            # Find the indices to save
            ind2save = [SAMPLES_TAPER, NBATCH - SAMPLES_TAPER]
            if last_s == _sr.ns:
                # for the last batch just use the normal fft as the stencil doesn't fit
                chunk = fshift(chunk, s=h['sample_shift'])
                ind2save[1] = NBATCH
            else:
                # apply precomputed fshift of the proper length
                chunk = ifft_object(fft_object(chunk) * DEPHAS)
            if first_s == 0:
                # for the first batch save the start with taper applied
                ind2save[0] = 0
            # interpolate missing traces after the low-cut filter; it's important to leave the
            # channels outside of the brain outside of the computation
            if reject_channels:
                chunk = interpolate_bad_channels(chunk, channel_labels, h=h)
                inside_brain = np.where(channel_labels != 3)[0]
                chunk[inside_brain, :] = spatial_fcn(chunk[inside_brain, :])  # apply the k-filter / CAR
            else:
                chunk = spatial_fcn(chunk)  # apply the k-filter / CAR
            # add back sync trace and save
            chunk = np.r_[chunk, _sr[first_s:last_s, ncv:].T].T

            # Compute rms - we get it before applying the whitening
            if compute_rms:
                ap_rms = rms(chunk[:, :ncv], axis=0)
                ap_t = t0 + (first_s + (last_s - first_s - 1) / 2) / _sr.fs
                ap_rms.astype(np.float32).tofile(aid)
                ap_t.astype(np.float32).tofile(tid)

            # convert to normalised int
            intnorm = 1 / _sr.sample2volts
            chunk = chunk[slice(*ind2save), :] * intnorm
            # apply the whitening matrix if necessary
            if wrot is not None:
                chunk[:, :ncv] = np.dot(chunk[:, :ncv], wrot)
            chunk[:, :nc_out].astype(dtype).tofile(fid)
            first_s += NBATCH - SAMPLES_TAPER * 2
            if last_s >= max_s:
                if last_s == _sr.ns:
                    if ns2add > 0:
                        np.tile(chunk[-1, :nc_out].astype(dtype), (ns2add, 1)).tofile(fid)
                fid.close()
                if compute_rms:
                    aid.close()
                    tid.close()
                break

    _ = Parallel(n_jobs=nprocesses)(delayed(my_function)(i, nprocesses) for i in range(nprocesses))
    sr.close()

    # Here convert the ap_rms bin files to the ibl format and save
    if compute_rms:
        with open(ap_rms_file, 'rb') as aid, open(ap_time_file, 'rb') as tid:
            rms_data = aid.read()
            time_data = tid.read()
        time_data = np.frombuffer(time_data, dtype=np.float32)
        rms_data = np.frombuffer(rms_data, dtype=np.float32)
        assert rms_data.shape[0] == time_data.shape[0] * ncv
        rms_data = rms_data.reshape(time_data.shape[0], ncv)
        np.save(output_file.parent.joinpath('_iblqc_ephysTimeRmsAP.rms.npy'), rms_data)
        np.save(output_file.parent.joinpath('_iblqc_ephysTimeRmsAP.timestamps.npy'), time_data)

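# Usage sketch (illustrative, assumed file paths): decompress and destripe a compressed spikeglx
# file to a flat int16 binary, writing the QC rms/timestamps npy files alongside the output.
#   decompress_destripe_cbin('probe00.ap.cbin', output_file='probe00.ap.bin',
#                            nprocesses=4, compute_rms=True)
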
def rcoeff(x, y):
    """
    Computes pairwise Pearson correlation coefficients for matrices.
    That is, for 2 matrices of the same size, computes the row to row coefficients and outputs
    a vector corresponding to the number of rows of the first matrix.
    If the second array is a vector then computes the correlation coefficient for all rows.
    :param x: np array [nc, ns]
    :param y: np array [nc, ns] or [ns]
    :return: r [nc]
    """
    def normalize(z):
        mean = np.mean(z, axis=-1)
        return z - mean if mean.size == 1 else z - mean[:, np.newaxis]

    xnorm = normalize(x)
    ynorm = normalize(y)
    rcor = np.sum(xnorm * ynorm, axis=-1) / np.sqrt(
        np.sum(np.square(xnorm), axis=-1) * np.sum(np.square(ynorm), axis=-1))
    return rcor

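# Usage sketch (illustrative): row-wise Pearson correlation of each channel with a reference
# trace, e.g. the median across channels.
#   x = np.random.randn(384, 30000)
#   r = rcoeff(x, np.median(x, axis=0))  # r has shape (384,)
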
def detect_bad_channels(raw, fs, similarity_threshold=(-0.5, 1), psd_hf_threshold=None):
    """
    Bad channels detection for Neuropixel probes
    Labels channels
     0: all clear
     1: dead low coherence / amplitude
     2: noisy
     3: outside of the brain
    :param raw: [nc, ns]
    :param fs: sampling frequency
    :param similarity_threshold:
    :param psd_hf_threshold:
    :return: labels (numpy vector [nc]), xfeats: dictionary of features [nc]
    """
    def rneighbours(raw, n=1):  # noqa
        """
        Computes Pearson correlation with the sum of neighbouring traces
        :param raw: nc, ns
        :param n:
        :return:
        """
        nc = raw.shape[0]
        mixer = np.triu(np.ones((nc, nc)), 1) - np.triu(np.ones((nc, nc)), 1 + n)
        mixer += np.tril(np.ones((nc, nc)), -1) - np.tril(np.ones((nc, nc)), - n - 1)
        r = rcoeff(raw, np.matmul(raw.T, mixer).T)
        r[np.isnan(r)] = 0
        return r

    def detrend(x, nmed):
        ntap = int(np.ceil(nmed / 2))
        xf = np.r_[np.zeros(ntap) + x[0], x, np.zeros(ntap) + x[-1]]
        # assert np.all(xcorf[ntap:-ntap] == xcor)
        xf = scipy.signal.medfilt(xf, nmed)[ntap:-ntap]
        return x - xf

    def channels_similarity(raw, nmed=0):
        """
        Computes the similarity based on zero-lag crosscorrelation of each channel with the
        median trace referencing
        :param raw: [nc, ns]
        :param nmed:
        :return:
        """
        def fxcor(x, y):
            return scipy.fft.irfft(scipy.fft.rfft(x) * np.conj(scipy.fft.rfft(y)), n=raw.shape[-1])

        def nxcor(x, ref):
            ref = ref - np.mean(ref)
            apeak = fxcor(ref, ref)[0]
            x = x - np.mean(x, axis=-1)[:, np.newaxis]
            return fxcor(x, ref)[:, 0] / apeak

        ref = np.median(raw, axis=0)
        xcor = nxcor(raw, ref)
        if nmed > 0:
            xcor = detrend(xcor, nmed) + 1
        return xcor

    nc, _ = raw.shape
    raw = raw - np.mean(raw, axis=-1)[:, np.newaxis]  # removes DC offset
    xcor = channels_similarity(raw)
    fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs)  # units: uV ** 2 / Hz
    if psd_hf_threshold is None:
        # the LFP band data is obviously much stronger so auto-adjust the default threshold
        psd_hf_threshold = 1.4 if fs < 5000 else 0.02
    sos_hp = scipy.signal.butter(**{'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}, output='sos')
    hf = scipy.signal.sosfiltfilt(sos_hp, raw)
    xcorf = channels_similarity(hf)

    xfeats = ({
        'ind': np.arange(nc),
        'rms_raw': rms(raw),  # very similar to the rms after the butterworth filter
        'xcor_hf': detrend(xcor, 11),
        'xcor_lf': xcorf - detrend(xcorf, 11) - 1,
        'psd_hf': np.mean(psd[:, fscale > (fs / 2 * 0.8)], axis=-1),  # 80% of Nyquist
    })

    # make recommendation
    ichannels = np.zeros(nc)
    idead = np.where(similarity_threshold[0] > xfeats['xcor_hf'])[0]
    inoisy = np.where(np.logical_or(xfeats['psd_hf'] > psd_hf_threshold,
                                    xfeats['xcor_hf'] > similarity_threshold[1]))[0]
    # the channels outside of the brain are the contiguous channels below the threshold on the trend coherency
    ioutside = np.where(xfeats['xcor_lf'] < -0.75)[0]
    if ioutside.size > 0 and ioutside[-1] == (nc - 1):
        a = np.cumsum(np.r_[0, np.diff(ioutside) - 1])
        ioutside = ioutside[a == np.max(a)]
        ichannels[ioutside] = 3

    # indices
    ichannels[idead] = 1
    ichannels[inoisy] = 2
    # from ibllib.plots.figures import ephys_bad_channels
    # ephys_bad_channels(x, 30000, ichannels, xfeats)
    return ichannels, xfeats

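# Usage sketch (illustrative): label channels on a short raw AP-band snippet (nc, ns) sampled
# at 30 kHz; real recordings rather than synthetic noise are needed for meaningful labels.
#   raw = np.random.randn(384, 30000) * 1e-5
#   labels, xfeats = detect_bad_channels(raw, fs=30000)
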
def detect_bad_channels_cbin(bin_file, n_batches=10, batch_duration=0.3, display=False):
    """
    Runs an ap-binary file scan to automatically detect faulty channels
    :param bin_file: full file path to the binary or compressed binary file from spikeglx
    :param n_batches: number of batches throughout the file (defaults to 10)
    :param batch_duration: batch length in seconds, defaults to 0.3
    :param display: if True will return a figure with features and an excerpt of the raw data
    :return: channel_labels: nc int array with 0: ok, 1: dead, 2: high noise, 3: outside of the brain
    """
    sr = bin_file if isinstance(bin_file, spikeglx.Reader) else spikeglx.Reader(bin_file)
    nc = sr.nc - sr.nsync
    channel_labels = np.zeros((nc, n_batches))
    # loop over the file and take the mode of detections
    for i, t0 in enumerate(np.linspace(0, sr.rl - batch_duration, n_batches)):
        sl = slice(int(t0 * sr.fs), int((t0 + batch_duration) * sr.fs))
        channel_labels[:, i], _xfeats = detect_bad_channels(sr[sl, :nc].T, fs=sr.fs)
        if i == 0:  # init the features dictionary if necessary
            xfeats = {k: np.zeros((nc, n_batches)) for k in _xfeats}
        for k in xfeats:
            xfeats[k][:, i] = _xfeats[k]
    # the features are averaged so there may be a discrepancy between the mode and applying
    # the thresholds to the average of the features - the goal of those features is for display only
    xfeats_med = {k: np.median(xfeats[k], axis=-1) for k in xfeats}
    channel_flags, _ = scipy.stats.mode(channel_labels, axis=1)
    if display:
        raw = sr[sl, :nc].T
        from ibllib.plots.figures import ephys_bad_channels
        ephys_bad_channels(raw, sr.fs, channel_flags, xfeats_med)
    return channel_flags
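# Usage sketch (illustrative, assumed file path): scan a compressed spikeglx file in 10 batches
# of 0.3 s each and take the mode of the per-batch labels.
#   channel_flags = detect_bad_channels_cbin('probe00.ap.cbin', n_batches=10, batch_duration=0.3)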