Source code for brainbox.behavior.wheel

"""
Set of functions to handle wheel data.
"""
import numpy as np
from numpy import pi
from iblutil.numerical import between_sorted
import scipy.interpolate as interpolate
import scipy.signal
from scipy.linalg import hankel
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
# from ibllib.io.extractors.ephys_fpga import WHEEL_TICKS  # FIXME Circular dependencies

# Public API of this module
__all__ = [
    'cm_to_deg',
    'cm_to_rad',
    'interpolate_position',
    'get_movement_onset',
    'movements',
    'samples_to_cm',
    'traces_by_trial',
    'velocity_filtered',
]

# Hardware constants used as defaults by the unit-conversion functions
ENC_RES = 4 * 1024  # Rotary encoder resolution (counts per wheel revolution), assumes X4 encoding
WHEEL_DIAMETER = 2 * 3.1  # Wheel diameter in cm (radius of 3.1 cm)


def interpolate_position(re_ts, re_pos, freq=1000, kind='linear', fill_gaps=None):
    """
    Return wheel position resampled onto an evenly spaced time base.

    Parameters
    ----------
    re_ts : array_like
        Array of timestamps
    re_pos: array_like
        Array of unwrapped wheel positions
    freq : float
        frequency in Hz of the interpolation
    kind : {'linear', 'cubic'}
        Type of interpolation (passed to scipy.interpolate.interp1d). Defaults to linear
        interpolation.
    fill_gaps : float
        Minimum gap length to fill. Interpolated samples falling within a gap between
        consecutive timestamps of at least this duration (seconds) are held at the position
        recorded at the start of the gap (forward fill), instead of being interpolated.

    Returns
    -------
    yinterp : array
        Interpolated position
    t : array
        Timestamps of interpolated positions
    """
    re_ts = np.asarray(re_ts)
    re_pos = np.asarray(re_pos)
    t = np.arange(re_ts[0], re_ts[-1], 1 / freq)  # Evenly resample at frequency
    if t.size and t[-1] > re_ts[-1]:
        # Occasionally due to precision errors the last sample may be outside of range.
        t = t[:-1]
    yinterp = interpolate.interp1d(re_ts, re_pos, kind=kind)(t)

    if fill_gaps:
        # Gap i spans [re_ts[i], re_ts[i + 1]) and is held at re_pos[i]
        gaps, = np.where(np.diff(re_ts) >= fill_gaps)
        # t is sorted, so the samples inside each gap form one contiguous slice; locating the
        # slice bounds with a binary search avoids building a full-length mask per gap.
        starts = np.searchsorted(t, re_ts[gaps], side='left')  # first sample with t >= gap start
        stops = np.searchsorted(t, re_ts[gaps + 1], side='left')  # first sample with t >= gap end
        for i0, i1, value in zip(starts, stops, re_pos[gaps]):
            yinterp[i0:i1] = value
    return yinterp, t
def velocity_filtered(pos, fs, corner_frequency=20, order=8):
    """
    Compute wheel velocity and acceleration from uniformly sampled wheel positions.

    The positions are smoothed with a zero-phase (forward-backward) low-pass Butterworth
    filter, then differentiated by first differences scaled by the sampling rate.

    Parameters
    ----------
    pos : array_like
        Vector of uniformly sampled wheel positions.
    fs : float
        Frequency in Hz of the sampling frequency.
    corner_frequency : float
        Corner frequency of low-pass filter.
    order : int
        Order of Butterworth filter.

    Returns
    -------
    vel : np.ndarray
        Array of velocity values (first element is always 0).
    acc : np.ndarray
        Array of acceleration values (first element is always 0).
    """
    def _rate_of_change(values):
        # First difference per second, with a leading 0 so the output keeps the input length
        return np.insert(np.diff(values), 0, 0) * fs

    # Digital Butterworth cut-off is expressed as a fraction of the Nyquist frequency
    normalized_cutoff = corner_frequency / fs * 2
    sos = scipy.signal.butter(N=order, Wn=normalized_cutoff, btype='lowpass', output='sos')
    smoothed = scipy.signal.sosfiltfilt(sos, pos)

    vel = _rate_of_change(smoothed)
    acc = _rate_of_change(vel)
    return vel, acc
def get_movement_onset(intervals, event_times):
    """
    Find the time at which movement started, given an event timestamp that occurred during the
    movement.

    Parameters
    ----------
    intervals : array_like
        An n-by-2 array of wheel movement intervals; the first column holds the onset times.
    event_times : array_like
        Strictly ascending event timestamps anywhere during movement of interest, e.g. peak
        velocity, feedback time.

    Returns
    -------
    numpy.array
        An array the same length as `event_times` holding, for each event, the onset time of
        the movement interval containing it (as determined by
        `iblutil.numerical.between_sorted`), or NaN if no interval contains it. If several
        intervals contain the same event, the one in the later row wins.

    Raises
    ------
    ValueError
        If `event_times` is not strictly ascending.

    Examples
    --------
    Find the last movement onset before each trial response time (requires an ONE instance)

    >>> trials = one.load_object(eid, 'trials')  # doctest: +SKIP
    >>> wheelMoves = one.load_object(eid, 'wheelMoves')  # doctest: +SKIP
    >>> onsets = get_movement_onset(wheelMoves.intervals, trials.response_times)  # doctest: +SKIP
    """
    intervals = np.asarray(intervals)
    event_times = np.asarray(event_times)
    # between_sorted relies on sorted input, so validate ordering up front
    if not np.all(np.diff(event_times) > 0):
        raise ValueError('event_times must be in ascending order.')
    onsets = np.full(event_times.size, np.nan)
    for interval in intervals:  # one row (onset, offset) per movement
        onset = between_sorted(event_times, interval)
        if np.any(onset):
            onsets[onset] = interval[0]
    return onsets
def _max_displacement(pos, n, window, batch_size=10000):
    """Return, for each of the first `n` samples, max - min of `pos` over the `window` samples starting there."""
    max_disp = np.empty(n, dtype=float)  # initialize array of total wheel displacement
    c = 0  # index of the first sample of the current batch
    # Processed in batches to keep the Hankel matrices (batch_size x window) small in memory
    while True:
        i2proc = np.arange(batch_size) + c
        i2proc = i2proc[i2proc < n]
        # Each Hankel row is a sliding window of `window` samples, NaN-padded past the end of the
        # batch. Batches overlap by `window` samples, so rows truncated by the batch end are
        # recomputed with full windows by the next batch.
        w2e = hankel(pos[i2proc], np.full(window, np.nan))
        max_disp[i2proc] = np.nanmax(w2e, axis=1) - np.nanmin(w2e, axis=1)
        c += batch_size - window
        if i2proc[-1] == n - 1:
            break
    return max_disp


def _onset_displacements(pos, onset_samps, window):
    """Return |pos[s + j] - pos[s]| for j in range(window) for each onset sample s (NaN past the end of pos)."""
    # Same NaN padding scheme as the Hankel windows used to detect movement
    padded = np.concatenate((pos, np.full(window, np.nan)))
    idx = onset_samps[:, np.newaxis] + np.arange(window)  # one row of sample indices per onset
    windows = padded[idx]
    return np.abs(windows - windows[:, :1])


def movements(t, pos, freq=1000, pos_thresh=8, t_thresh=.2, min_gap=.1, pos_thresh_onset=1.5,
              min_dur=.05, make_plots=False):
    """
    Detect wheel movements.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    pos : array_like
        An array of evenly sampled wheel positions
    freq : int
        The sampling rate of the wheel data
    pos_thresh : float
        The minimum required movement during the t_thresh window to be considered part of a
        movement
    t_thresh : float
        The time window over which to check whether the pos_thresh has been crossed
    min_gap : float
        The minimum time between one movement's offset and another movement's onset in order to
        be considered separate.  Movements with a gap smaller than this are 'stitched together'
    pos_thresh_onset : float
        A lower threshold for finding precise onset times.  The first position of each movement
        transition that is this much bigger than the starting position is considered the onset
    min_dur : float
        The minimum duration of a valid movement.  Detected movements shorter than this are
        ignored
    make_plots : boolean
        Plot trace of position and velocity, showing detected onsets and offsets

    Returns
    -------
    onsets : np.ndarray
        Timestamps of detected movement onsets
    offsets : np.ndarray
        Timestamps of detected movement offsets
    peak_amps : np.ndarray
        The absolute maximum amplitude of each detected movement, relative to onset position
    peak_vel_times : np.ndarray
        Timestamps of peak velocity for each detected movement
    """
    t = np.asarray(t)
    pos = np.asarray(pos)
    # Wheel position must be evenly sampled.
    dt = np.diff(t)
    assert np.all(np.abs(dt - dt.mean()) < 1e-10), 'Values not evenly sampled'

    # Convert the time threshold into number of samples given the sampling frequency
    t_thresh_samps = int(np.round(t_thresh * freq))
    # Total change in position within the t_thresh window starting at each sample
    max_disp = _max_displacement(pos, t.size, t_thresh_samps)

    # For each window is the change in position greater than our threshold?
    moving = max_disp > pos_thresh
    moving = np.insert(moving, 0, False)  # First sample should always be not moving to ensure we have an onset
    moving[-1] = False  # Likewise, ensure we always end on an offset

    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    # Stitch together movements separated by less than min_gap
    too_short = np.where((onset_samps[1:] - offset_samps[:-1]) / freq < min_gap)[0]
    for p in too_short:
        moving[offset_samps[p]:onset_samps[p + 1] + 1] = True

    # Refine each onset with the lower pos_thresh_onset threshold. Every onset gets its own
    # displacement window; the previous batched implementation skipped one sample per batch and
    # could stop before processing the last onsets, leaving uninitialised rows.
    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    onsets_disp = _onset_displacements(pos, onset_samps, t_thresh_samps)
    has_onset = onsets_disp > pos_thresh_onset
    # Length of the run of supra-threshold samples at the end of each window (0 when the
    # window ends below threshold or is supra-threshold throughout)
    n_trailing = np.argmin(np.fliplr(has_onset).T, axis=0)
    onset_lags = t_thresh_samps - n_trailing
    # The onset becomes the sample just before that final supra-threshold run.
    # NOTE(review): for an onset within t_thresh of the end of the recording this can point past
    # the last sample (IndexError below) — confirm whether it should be clamped.
    onset_samps = onset_samps + onset_lags - 1
    onsets = t[onset_samps]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    offsets = t[offset_samps]

    # Drop movements shorter than min_dur
    durations = offsets - onsets
    too_short = durations < min_dur
    onset_samps = onset_samps[~too_short]
    onsets = onsets[~too_short]
    offset_samps = offset_samps[~too_short]
    offsets = offsets[~too_short]

    # Merge movements whose refined onset falls within min_gap of the previous offset
    move_gaps = onsets[1:] - offsets[:-1]
    gap_too_small = move_gaps < min_gap
    if onsets.size > 0:
        onsets = onsets[np.insert(~gap_too_small, 0, True)]  # always keep first onset
        onset_samps = onset_samps[np.insert(~gap_too_small, 0, True)]
        offsets = offsets[np.append(~gap_too_small, True)]  # always keep last offset
        offset_samps = offset_samps[np.append(~gap_too_small, True)]

    # Calculate the peak amplitudes -
    # the maximum absolute value of the difference from the onset position
    peaks = (pos[m + np.abs(pos[m:n] - pos[m]).argmax()] - pos[m]
             for m, n in zip(onset_samps, offset_samps))
    peak_amps = np.fromiter(peaks, dtype=float, count=onsets.size)
    N = 10  # Number of points in the Gaussian
    STDEV = 1.8  # Equivalent to a width factor (alpha value) of 2.5
    gauss = scipy.signal.windows.gaussian(N, STDEV)  # A 10-point Gaussian window of a given s.d.
    vel = scipy.signal.convolve(np.diff(np.insert(pos, 0, 0)), gauss, mode='same')
    # For each movement period, find the timestamp where the absolute velocity was greatest
    peaks = (t[m + np.abs(vel[m:n]).argmax()] for m, n in zip(onset_samps, offset_samps))
    peak_vel_times = np.fromiter(peaks, dtype=float, count=onsets.size)

    if make_plots:
        fig, axes = plt.subplots(nrows=2, sharex='all')
        indices = np.sort(np.hstack((onset_samps, offset_samps)))  # Points to split trace
        vel, acc = velocity_filtered(pos, freq)

        # Plot the wheel position and velocity
        for ax, y in zip(axes, (pos, vel)):
            ax.plot(onsets, y[onset_samps], 'go')
            ax.plot(offsets, y[offset_samps], 'bo')

            t_split = np.split(np.vstack((t, y)).T, indices, axis=0)
            ax.add_collection(LineCollection(t_split[1::2], colors='r'))  # Moving
            ax.add_collection(LineCollection(t_split[0::2], colors='k'))  # Not moving

        axes[1].autoscale()  # rescale after adding line collections
        axes[0].autoscale()
        axes[0].set_ylabel('position')
        axes[1].set_ylabel('velocity')
        axes[1].set_xlabel('time')
        axes[0].legend(['onsets', 'offsets', 'in movement'])
        plt.show()

    return onsets, offsets, peak_amps, peak_vel_times
def cm_to_deg(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert linear wheel displacement to degrees of wheel rotation.

    This may be useful for e.g. calculating velocity in revolutions per second.

    :param positions: array of wheel positions in cm
    :param wheel_diameter: the diameter of the wheel in cm
    :return: array of wheel positions in degrees turned

    # Example: Convert linear cm to degrees
    >>> cm_to_deg(3.142 * WHEEL_DIAMETER)
    360.04667846020925

    # Example: Get positions in deg from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_deg(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.61999992, 0.93000011, 1.24000007, 1.55000003])
    """
    circumference = pi * wheel_diameter  # cm travelled per full revolution
    revolutions = positions / circumference
    return revolutions * 360
def cm_to_rad(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert linear wheel displacement to wheel rotation in radians.

    This may be useful for e.g. calculating angular velocity.

    :param positions: array of wheel positions in cm
    :param wheel_diameter: the diameter of the wheel in cm
    :return: array of wheel angle in radians

    # Example: Convert linear cm to radians
    >>> cm_to_rad(1)
    0.3225806451612903

    # Example: Get positions in rad from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_rad(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.01082104, 0.01623156, 0.02164208, 0.0270526 ])
    """
    radians_per_cm = 2 / wheel_diameter  # arc length / radius
    return positions * radians_per_cm
def samples_to_cm(positions, wheel_diameter=WHEEL_DIAMETER, resolution=ENC_RES):
    """
    Convert wheel position samples to cm linear displacement.

    This may be useful for inter-converting threshold units.

    :param positions: array of wheel positions in sample counts
    :param wheel_diameter: the diameter of the wheel in cm
    :param resolution: resolution of the rotary encoder (counts per wheel revolution)
    :return: array of wheel positions as linear displacement in cm

    # Example: Get resolution in linear cm
    >>> samples_to_cm(1)
    0.004755340442445488

    # Example: Get positions in linear cm for 4X, 360 ppr encoder
    >>> import numpy as np
    >>> samples_to_cm(np.array([2, 3, 4, 5, 6, 7, 6, 5, 4]), resolution=360*4)
    array([0.0270526 , 0.04057891, 0.05410521, 0.06763151, 0.08115781,
           0.09468411, 0.08115781, 0.06763151, 0.05410521])
    """
    # counts -> revolutions -> distance travelled along the circumference (pi * diameter)
    return positions / resolution * pi * wheel_diameter
def direction_changes(t, vel, intervals):
    """
    Find the direction changes for the given movement intervals.

    A direction change is any sample whose velocity sign differs from that of the previous
    sample; transitions to or from exactly zero velocity therefore also count.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    vel : array_like
        An array of evenly sampled wheel velocities, the same length as `t`
    intervals : array_like
        An n-by-2 array of wheel movement intervals (a flat array of on/off pairs also works)

    Returns
    -------
    times : list of numpy.ndarray
        Direction change timestamps, one array per interval
    indices : list of numpy.ndarray
        Indices into `t` of the direction changes; one array per interval, the same sizes as
        `times`
    """
    t = np.asarray(t)
    vel = np.asarray(vel)
    intervals = np.asarray(intervals)
    # True where the sign differs from the previous sample; the first sample can never be a change
    chg = np.insert(np.diff(np.sign(vel)) != 0, 0, False)

    times = []
    indices = []
    for on, off in intervals.reshape(-1, 2):
        # Only changes strictly inside the interval are reported
        ind, = np.where((t > on) & (t < off) & chg)
        times.append(t[ind])
        indices.append(ind)
    return times, indices
def traces_by_trial(t, *args, start=None, end=None, separate=True):
    """
    Return the samples of `t` and each of `args` falling strictly between each start and end
    time, e.g. positions and velocity between stimulus onset and feedback.

    :param t: numpy array of timestamps
    :param args: optional numpy arrays of the same length as timestamps, such as positions,
        velocities or accelerations
    :param start: start timestamp or array thereof; defaults to t[0]. A single value is used
        for every trial.
    :param end: end timestamp or array thereof; defaults to t[-1]. A single value is used for
        every trial.
    :param separate: when True, the output is a list of tuples of the form
        [(t, args[0], args[1], ...), ...], one tuple per trial; when False, the output is a list
        of n-by-m ndarrays where n = 1 + number of positional args and m = number of samples
        within that trial
    :return: list of sliced arrays where length == len(start)
    """
    t = np.asarray(t)
    # Accept scalars (including the t[0] / t[-1] defaults) as well as arrays
    start = np.atleast_1d(t[0] if start is None else start)
    end = np.atleast_1d(t[-1] if end is None else end)
    # A single start or end time is shared by all trials
    if start.size == 1 and end.size > 1:
        start = np.full(end.shape, start[0])
    elif end.size == 1 and start.size > 1:
        end = np.full(start.shape, end[0])
    assert len(start) == len(end), 'number of start timestamps must equal end timestamps'

    traces = np.stack((t, *args))  # row 0 is t, then one row per extra array
    cuts = [traces[:, (t > s) & (t < e)] for s, e in zip(start, end)]
    # Iterating a 2-D cut yields its rows: (t, args[0], args[1], ...)
    return [tuple(cut) for cut in cuts] if separate else cuts
if __name__ == '__main__':
    # Running the module directly executes the docstring examples as doctests
    from doctest import testmod
    testmod()