# Source code for iblutil.spacer

"""Tools to generate and identify spacers.

Spacers are sequences of up and down pulses with a specific, identifiable pattern.
They are generated with chirp coding to reduce cross-correlation sidelobes.
They are used to mark the beginning of a behaviour sequence within a session.

Example
-------
>>> spacer = Spacer()
>>> spacer.add_spacer_states(sma, next_state='first_state')
>>> for i in range(ntrials):
...     sma.add_state(
...         state_name='first_state',
...         state_timer=tup,
...         state_change_conditions={'Tup': f'spacer_low_{i:02d}'},
...         output_actions=[('BNC1', 255)],  # To FPGA
...     )

"""

import numpy as np


class Spacer:
    """Generate and detect chirp-coded spacer pulse trains.

    A spacer is a train of ``2 * n_pulses - 1`` up pulses of width ``tup``.
    The intervals between successive rising edges sweep from
    ``dt_start + tup`` to ``dt_end + tup`` and back again.
    """

    def __init__(self, dt_start=.02, dt_end=.4, n_pulses=8, tup=.05):
        """Configure the spacer pulse pattern.

        Parameters
        ----------
        dt_start : float
            First spacer up time, i.e. the shortest (or first) low gap of the sweep.
        dt_end : float
            Last spacer up time, i.e. the low gap at the turning point of the sweep.
        n_pulses : int
            Number of spacer up times, one-sided: 8 yields 2 * 8 - 1 = 15 pulses.
        tup : float
            Duration of the spacer up time.

        Raises
        ------
        AssertionError
            If consecutive rising edges are not separated by more than `tup`
            (the pulses would overlap).
        """
        self.dt_start = dt_start
        self.dt_end = dt_end
        self.n_pulses = n_pulses
        self.tup = tup
        assert np.all(np.diff(self.times) > self.tup), 'Spacers are overlapping'

    def __repr__(self):
        return f'Spacer(dt_start={self.dt_start}, dt_end={self.dt_end}, n_pulses={self.n_pulses}, tup={self.tup})'

    @property
    def times(self):
        """Compute spacer up times using a chirp up and down pattern.

        Each time corresponds to a rising edge of the BNC1 signal, measured from
        t = 0. The first rising edge is therefore at ``dt_start + tup``.

        Returns
        -------
        numpy.array
            The ``2 * n_pulses - 1`` spacer rising-edge times.
        """
        # upsweep: intervals between successive rising edges
        t = np.linspace(self.dt_start, self.dt_end, self.n_pulses) + self.tup
        # downsweep: mirror the intervals, omitting the first, then accumulate
        t = np.r_[t, np.flipud(t[1:])]
        t = np.cumsum(t)
        return t

    def generate_template(self, fs=1000):
        """
        Generate a spacer template to cross-correlate with a DAQ trace.

        The template is 1 during each up pulse and 0 otherwise. It extends
        ``10 * tup`` seconds past the last rising edge.

        Parameters
        ----------
        fs : int
            DAQ sampling frequency.

        Returns
        -------
        numpy.array
            The template spacer signal.
        """
        t = self.times
        ns = int((t[-1] + self.tup * 10) * fs)
        sig = np.zeros(ns, )
        # Mark rising (+1) and falling (-1) edges; cumsum turns them into levels.
        sig[(t * fs).astype(np.int32)] = 1
        sig[((t + self.tup) * fs).astype(np.int32)] = -1
        sig = np.cumsum(sig)
        return sig

    def add_spacer_states(self, sma=None, next_state='exit'):
        """
        Add spacer states to a state machine.

        For each pulse ``i`` two states are added:

        - ``spacer_high_{i:02d}`` sets BNC1 to 255 for ``tup`` seconds.
        - ``spacer_low_{i:02d}`` has no output until the next rising edge.

        The last low state transitions to `next_state`. If `sma` is None,
        the pulse index, rising-edge time and interval are printed instead.

        Parameters
        ----------
        sma : pybpodapi.state_machine.StateMachine, optional
            A Bpod state machine instance. If None, the schedule is printed.
        next_state : str
            The name of the state to follow the spacer states.
        """
        assert next_state is not None
        t = self.times
        # Interval to the next rising edge; the final low state lasts `tup`.
        dt = np.diff(t, append=t[-1] + self.tup * 2)
        for i, time in enumerate(t):
            if sma is None:
                print(i, time, dt[i])
                continue
            next_loop = f'spacer_high_{i + 1:02d}' if i < len(t) - 1 else next_state
            sma.add_state(
                state_name=f'spacer_high_{i:02d}',
                state_timer=self.tup,
                state_change_conditions={'Tup': f'spacer_low_{i:02d}'},
                output_actions=[('BNC1', 255)],  # To FPGA
            )
            sma.add_state(
                state_name=f'spacer_low_{i:02d}',
                state_timer=dt[i] - self.tup,
                state_change_conditions={'Tup': next_loop},
                output_actions=[],
            )

    def find_spacers_from_fronts(self, fronts, fs=1000):
        """
        Find spacers given the timestamps and polarities of a digital signal.

        The method runs in two steps:

        1. Find runs of ``2 * n_pulses - 1`` consecutive pulses of width
           ``tup`` that are less than 1 second apart.
        2. Resample each run at `fs` and cross-correlate it with the template
           (see `find_spacers`).

        A candidate window is skipped unless exactly one spacer is detected
        in it. This method may be relaxed to make it more robust to noise.

        Parameters
        ----------
        fronts : dict[str, numpy.array]
            Dictionary with keys ('times', 'polarities') containing the
            timestamps and polarities (+1 rising, -1 falling) of the signal
            fronts, respectively.
        fs : int
            The sampling frequency of the DAQ signal.

        Returns
        -------
        numpy.array
            The times of the detected spacers. Each time is the template start
            (one ``dt_start + tup`` before the first rising edge). The array
            is empty if none are found.
        """
        n_pulses = (self.n_pulses * 2) - 1
        # A front ~tup after the previous one closes a pulse of the expected width;
        # the leading False realigns the flags with the indices of `fronts`.
        is_pulse = np.isclose(np.diff(fronts['times']), self.tup, rtol=1e-2)
        is_pulse = np.insert(is_pulse, 0, False)
        ind, = np.where(is_pulse)
        if ind.size < n_pulses:
            # Too few pulses for even one spacer (also avoids indexing an empty `ind`)
            return np.array([])
        # Find consecutive pulses that are the correct length close together
        max_d = 1.  # look for fronts less than 1 second apart
        consecutive = np.logical_and(np.diff(ind) == 2, np.diff(fronts['times'][ind]) < max_d)
        consecutive = np.pad(consecutive, 1, 'constant', constant_values=False)
        # `edges` delimit the runs of consecutive pulses
        edges, = np.where(~consecutive)
        spacer_times = []
        for start, stop in zip(edges[:-1], edges[1:]):
            if stop - start != n_pulses:  # This could be relaxed to allow for noise
                continue
            # The window starts at the first pulse's falling edge; +1 to include final down
            idx = np.arange(ind[start], ind[stop - 1] + 1)
            t = fronts['times'][idx]
            ts = np.arange(t[0], t[-1], 1 / fs)  # Evenly resample at given frequency
            if ts.size == 0:
                continue  # degenerate single-front window, nothing to correlate
            # Reconstruct trace where 1 = high, 0 = low
            signal = np.zeros_like(ts)
            ii = np.searchsorted(ts, t, side='left')
            in_range = ii < len(signal)
            signal[ii[in_range]] = fronts['polarities'][idx[in_range]]
            signal = np.cumsum(signal) + 1  # {-1, 0} -> {0, 1}
            detected = self.find_spacers(signal, fs=fs)
            if detected.size != 1:
                continue  # no match, or an ambiguous one, in this window
            spacer_times.append(detected[0] + t[0])
        return np.array(spacer_times)

    def find_spacers(self, signal, threshold=0.9, fs=1000):
        """
        Find spacers in a voltage time series.

        Assumes that the signal is a digital signal between 0 and 1.

        Parameters
        ----------
        signal : numpy.ndarray
            The signal in which to find the spacer.
        threshold : float
            The cross-correlation detection threshold, as a fraction of the
            template's total area.
        fs : int
            The sampling frequency of the DAQ signal.

        Returns
        -------
        numpy.ndarray
            The times of each spacer (template start) in seconds, relative to
            the first sample. The array is empty if no spacer is detected.
        """
        template = self.generate_template(fs=fs)
        xcor = np.correlate(signal, template, mode='full') / np.sum(template)
        idetect, = np.where(xcor > threshold)
        if idetect.size == 0:
            return np.zeros(0)
        # Label contiguous runs of above-threshold samples 1..nspacers. Prepending
        # -2 guarantees the first run starts a new label even at index 0 or 1.
        iidetect = np.cumsum(np.diff(idetect, prepend=-2) > 1)
        nspacers = iidetect[-1]
        tspacer = np.zeros(nspacers)
        for i in range(nspacers):
            ispacer = idetect[iidetect == i + 1]
            imax = np.argmax(xcor[ispacer])
            # 'full' mode: output index k corresponds to lag k - (template.size - 1)
            tspacer[i] = (ispacer[imax] - template.size + 1) / fs
        return tspacer