"""
Set of functions to handle wheel data.
"""
import numpy as np
from numpy import pi
from iblutil.numerical import between_sorted
import scipy.interpolate as interpolate
import scipy.signal
from scipy.linalg import hankel
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
# from ibllib.io.extractors.ephys_fpga import WHEEL_TICKS # FIXME Circular dependencies
__all__ = ['cm_to_deg',
'cm_to_rad',
'interpolate_position',
'get_movement_onset',
'movements',
'samples_to_cm',
'traces_by_trial',
'velocity_filtered']
# Define some constants
ENC_RES = 1024 * 4 # Rotary encoder resolution, assumes X4 encoding
WHEEL_DIAMETER = 3.1 * 2 # Wheel diameter in cm
# [docs]
def interpolate_position(re_ts, re_pos, freq=1000, kind='linear', fill_gaps=None):
    """
    Return linearly interpolated wheel position.

    Parameters
    ----------
    re_ts : array_like
        Array of timestamps
    re_pos: array_like
        Array of unwrapped wheel positions
    freq : float
        frequency in Hz of the interpolation
    kind : {'linear', 'cubic'}
        Type of interpolation. Defaults to linear interpolation.
    fill_gaps : float
        Minimum gap length to fill. For gaps over this time (seconds),
        forward fill values before interpolation

    Returns
    -------
    yinterp : array
        Interpolated position
    t : array
        Timestamps of interpolated positions
    """
    t = np.arange(re_ts[0], re_ts[-1], 1 / freq)  # Evenly resample at frequency
    if t[-1] > re_ts[-1]:
        t = t[:-1]  # Occasionally due to precision errors the last sample may be outside of range.
    yinterp = interpolate.interp1d(re_ts, re_pos, kind=kind)(t)

    if fill_gaps:
        # Find large gaps and forward fill.  searchsorted locates the resampled points that
        # fall within each gap in O(log n) per gap instead of scanning all of t every time.
        gaps, = np.where(np.diff(re_ts) >= fill_gaps)
        for i in gaps:
            # Samples with re_ts[i] <= t < re_ts[i + 1] take the position at the gap start
            i0, i1 = np.searchsorted(t, (re_ts[i], re_ts[i + 1]))
            yinterp[i0:i1] = re_pos[i]
    return yinterp, t
# [docs]
def velocity_filtered(pos, fs, corner_frequency=20, order=8):
"""
Compute wheel velocity from uniformly sampled wheel data.
pos: array_like
Vector of uniformly sampled wheel positions.
fs : float
Frequency in Hz of the sampling frequency.
corner_frequency : float
Corner frequency of low-pass filter.
order : int
Order of Butterworth filter.
Returns
-------
vel : np.ndarray
Array of velocity values.
acc : np.ndarray
Array of acceleration values.
"""
sos = scipy.signal.butter(**{'N': order, 'Wn': corner_frequency / fs * 2, 'btype': 'lowpass'}, output='sos')
vel = np.insert(np.diff(scipy.signal.sosfiltfilt(sos, pos)), 0, 0) * fs
acc = np.insert(np.diff(vel), 0, 0) * fs
return vel, acc
# [docs]
def get_movement_onset(intervals, event_times):
    """
    Find the time at which movement started, given an event timestamp that occurred during the
    movement.

    Parameters
    ----------
    intervals : numpy.array
        The wheel movement intervals.
    event_times : numpy.array
        Sorted event timestamps anywhere during movement of interest, e.g. peak velocity,
        feedback time.

    Returns
    -------
    numpy.array
        An array the length of event_times containing the onset time of the movement interval
        each event falls within; NaN where an event falls within no interval.

    Examples
    --------
    Find the last movement onset before each trial response time

    >>> trials = one.load_object(eid, 'trials')
    >>> wheelMoves = one.load_object(eid, 'wheelMoves')
    >>> onsets = get_movement_onset(wheelMoves.intervals, trials.response_times)
    """
    if not np.all(np.diff(event_times) > 0):
        raise ValueError('event_times must be in ascending order.')
    # Default to NaN; events matching no interval keep this value
    onsets = np.full(event_times.size, np.nan)
    for interval in intervals:
        within = between_sorted(event_times, interval)
        if not np.any(within):
            continue
        onsets[within] = interval[0]
    return onsets
# [docs]
def movements(t, pos, freq=1000, pos_thresh=8, t_thresh=.2, min_gap=.1, pos_thresh_onset=1.5,
              min_dur=.05, make_plots=False):
    """
    Detect wheel movements.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    pos : array_like
        An array of evenly sampled wheel positions
    freq : int
        The sampling rate of the wheel data
    pos_thresh : float
        The minimum required movement during the t_thresh window to be considered part of a
        movement
    t_thresh : float
        The time window over which to check whether the pos_thresh has been crossed
    min_gap : float
        The minimum time between one movement's offset and another movement's onset in order to be
        considered separate. Movements with a gap smaller than this are 'stitched together'
    pos_thresh_onset : float
        A lower threshold for finding precise onset times. The first position of each movement
        transition that is this much bigger than the starting position is considered the onset
    min_dur : float
        The minimum duration of a valid movement. Detected movements shorter than this are ignored
    make_plots : boolean
        Plot trace of position and velocity, showing detected onsets and offsets

    Returns
    -------
    onsets : np.ndarray
        Timestamps of detected movement onsets
    offsets : np.ndarray
        Timestamps of detected movement offsets
    peak_amps : np.ndarray
        The absolute maximum amplitude of each detected movement, relative to onset position
    peak_vel_times : np.ndarray
        Timestamps of peak velocity for each detected movement
    """
    # Wheel position must be evenly sampled.
    dt = np.diff(t)
    assert np.all(np.abs(dt - dt.mean()) < 1e-10), 'Values not evenly sampled'

    # Convert the time threshold into number of samples given the sampling frequency
    t_thresh_samps = int(np.round(t_thresh * freq))
    max_disp = np.empty(t.size, dtype=float)  # initialize array of total wheel displacement

    # Calculate a Hankel matrix of size t_thresh_samps in batches. This is effectively a
    # sliding window within which we look for changes in position greater than pos_thresh.
    # Each row of the Hankel matrix is a window of t_thresh_samps consecutive positions,
    # padded with NaN beyond the end of the trace (hence the nanmax/nanmin below).
    BATCH_SIZE = 10000  # do this in batches in order to keep memory usage reasonable
    c = 0  # index of 'window' position
    while True:
        i2proc = np.arange(BATCH_SIZE) + c
        i2proc = i2proc[i2proc < t.size]
        w2e = hankel(pos[i2proc], np.full(t_thresh_samps, np.nan))
        # Below is the total change in position for each window
        max_disp[i2proc] = np.nanmax(w2e, axis=1) - np.nanmin(w2e, axis=1)
        # Batches overlap by t_thresh_samps so every full window is computed exactly once
        c += BATCH_SIZE - t_thresh_samps
        if i2proc[-1] == t.size - 1:
            break

    moving = max_disp > pos_thresh  # for each window is the change in position greater than our threshold?
    moving = np.insert(moving, 0, False)  # First sample should always be not moving to ensure we have an onset
    moving[-1] = False  # Likewise, ensure we always end on an offset

    # Onsets are 0 -> 1 transitions of the moving flag, offsets are 1 -> 0 transitions
    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    # Stitch together consecutive movements separated by less than min_gap by flagging the
    # gap samples as moving, then recompute the onsets
    too_short = np.where((onset_samps[1:] - offset_samps[:-1]) / freq < min_gap)[0]
    for p in too_short:
        moving[offset_samps[p]:onset_samps[p + 1] + 1] = True

    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    # For each onset, extract the window of displacements (relative to the window's first
    # sample) so that a more precise onset time can be found below.  Same batching scheme
    # as above; cwt tracks how many onset windows have been filled so far.
    onsets_disp_arr = np.empty((onset_samps.size, t_thresh_samps))
    c = 0
    cwt = 0
    while onset_samps.size != 0:
        i2proc = np.arange(BATCH_SIZE) + c
        # Onsets that fall within this batch (excluding the overlap region at the batch end)
        icomm = np.intersect1d(i2proc[:-t_thresh_samps - 1], onset_samps, assume_unique=True)
        # ... and their positions within the batch
        itpltz = np.intersect1d(i2proc[:-t_thresh_samps - 1], onset_samps,
                                return_indices=True, assume_unique=True)[1]
        i2proc = i2proc[i2proc < t.size]
        if icomm.size > 0:
            w2e = hankel(pos[i2proc], np.full(t_thresh_samps, np.nan))
            # Absolute displacement of each window sample from the window start
            w2e = np.abs((w2e.T - w2e[:, 0]).T)
            onsets_disp_arr[cwt + np.arange(icomm.size), :] = w2e[itpltz, :]
            cwt += icomm.size
        c += BATCH_SIZE - t_thresh_samps
        if i2proc[-1] >= onset_samps[-1]:
            break

    # Refine each onset to the first sample in its window whose displacement from the window
    # start exceeds pos_thresh_onset (found by scanning the reversed boolean rows)
    has_onset = onsets_disp_arr > pos_thresh_onset
    A = np.argmin(np.fliplr(has_onset).T, axis=0)
    onset_lags = t_thresh_samps - A
    onset_samps = onset_samps + onset_lags - 1
    onsets = t[onset_samps]

    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    offsets = t[offset_samps]

    # Drop movements shorter than min_dur
    durations = offsets - onsets
    too_short = durations < min_dur
    onset_samps = onset_samps[~too_short]
    onsets = onsets[~too_short]
    offset_samps = offset_samps[~too_short]
    offsets = offsets[~too_short]

    # After refining onsets / dropping short movements, merge again any movements whose gap
    # is below min_gap by removing the intervening offset/onset pair
    moveGaps = onsets[1:] - offsets[:-1]
    gap_too_small = moveGaps < min_gap
    if onsets.size > 0:
        onsets = onsets[np.insert(~gap_too_small, 0, True)]  # always keep first onset
        onset_samps = onset_samps[np.insert(~gap_too_small, 0, True)]
        offsets = offsets[np.append(~gap_too_small, True)]  # always keep last offset
        offset_samps = offset_samps[np.append(~gap_too_small, True)]

    # Calculate the peak amplitudes -
    # the maximum absolute value of the difference from the onset position
    peaks = (pos[m + np.abs(pos[m:n] - pos[m]).argmax()] - pos[m]
             for m, n in zip(onset_samps, offset_samps))
    peak_amps = np.fromiter(peaks, dtype=float, count=onsets.size)

    # Gaussian-smoothed first difference as a velocity estimate
    N = 10  # Number of points in the Gaussian
    STDEV = 1.8  # Equivalent to a width factor (alpha value) of 2.5
    gauss = scipy.signal.windows.gaussian(N, STDEV)  # A 10-point Gaussian window of a given s.d.
    vel = scipy.signal.convolve(np.diff(np.insert(pos, 0, 0)), gauss, mode='same')
    # For each movement period, find the timestamp where the absolute velocity was greatest
    peaks = (t[m + np.abs(vel[m:n]).argmax()] for m, n in zip(onset_samps, offset_samps))
    peak_vel_times = np.fromiter(peaks, dtype=float, count=onsets.size)

    if make_plots:
        fig, axes = plt.subplots(nrows=2, sharex='all')
        indices = np.sort(np.hstack((onset_samps, offset_samps)))  # Points to split trace
        vel, acc = velocity_filtered(pos, freq)
        # Plot the wheel position and velocity, colouring moving segments red
        for ax, y in zip(axes, (pos, vel)):
            ax.plot(onsets, y[onset_samps], 'go')
            ax.plot(offsets, y[offset_samps], 'bo')
            t_split = np.split(np.vstack((t, y)).T, indices, axis=0)
            ax.add_collection(LineCollection(t_split[1::2], colors='r'))  # Moving
            ax.add_collection(LineCollection(t_split[0::2], colors='k'))  # Not moving
        axes[1].autoscale()  # rescale after adding line collections
        axes[0].autoscale()
        axes[0].set_ylabel('position')
        axes[1].set_ylabel('velocity')
        axes[1].set_xlabel('time')
        axes[0].legend(['onsets', 'offsets', 'in movement'])
        plt.show()

    return onsets, offsets, peak_amps, peak_vel_times
# [docs]
def cm_to_deg(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert wheel position to degrees turned.  This may be useful for e.g. calculating velocity
    in revolutions per second.

    :param positions: array of wheel positions in cm
    :param wheel_diameter: the diameter of the wheel in cm
    :return: array of wheel positions in degrees turned

    # Example: Convert linear cm to degrees
    >>> cm_to_deg(3.142 * WHEEL_DIAMETER)
    360.04667846020925

    # Example: Get positions in deg from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_deg(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.61999992, 0.93000011, 1.24000007, 1.55000003])
    """
    # Fraction of a full revolution (circumference = diameter * pi), scaled to degrees
    revolutions = positions / (wheel_diameter * pi)
    return revolutions * 360
# [docs]
def cm_to_rad(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert wheel position to radians.  This may be useful for e.g. calculating angular velocity.

    :param positions: array of wheel positions in cm
    :param wheel_diameter: the diameter of the wheel in cm
    :return: array of wheel angle in radians

    # Example: Convert linear cm to radians
    >>> cm_to_rad(1)
    0.3225806451612903

    # Example: Get positions in rad from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_rad(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.01082104, 0.01623156, 0.02164208, 0.0270526 ])
    """
    # Angle = arc length / radius; with radius = diameter / 2 this is positions * 2 / diameter
    scale = 2 / wheel_diameter
    return positions * scale
# [docs]
def samples_to_cm(positions, wheel_diameter=WHEEL_DIAMETER, resolution=ENC_RES):
    """
    Convert wheel position samples to cm linear displacement.  This may be useful for
    inter-converting threshold units.

    :param positions: array of wheel positions in sample counts
    :param wheel_diameter: the diameter of the wheel in cm
    :param resolution: resolution of the rotary encoder
    :return: array of wheel positions in cm

    # Example: Get resolution in linear cm
    >>> samples_to_cm(1)
    0.004755340442445488

    # Example: Get positions in linear cm for 4X, 360 ppr encoder
    >>> import numpy as np
    >>> samples_to_cm(np.array([2, 3, 4, 5, 6, 7, 6, 5, 4]), resolution=360*4)
    array([0.0270526 , 0.04057891, 0.05410521, 0.06763151, 0.08115781,
           0.09468411, 0.08115781, 0.06763151, 0.05410521])
    """
    # Fraction of a revolution times the wheel circumference (pi * diameter)
    revolutions = positions / resolution
    return revolutions * pi * wheel_diameter
def direction_changes(t, vel, intervals):
    """
    Find the direction changes for the given movement intervals.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    vel : array_like
        An array of evenly sampled wheel velocities
    intervals : array_like
        An n-by-2 array of wheel movement intervals

    Returns
    ----------
    times : iterable
        A list of numpy arrays of direction change timestamps, one array per interval
    indices : iterable
        A list of numpy arrays containing indices of direction changes; the size of times
    """
    # A sample is a direction change when the sign of the velocity differs from that of the
    # previous sample; the first sample can never be a change
    sign_flip = np.insert(np.diff(np.sign(vel)) != 0, 0, 0)
    times, indices = [], []
    for start, stop in intervals.reshape(-1, 2):
        # Only samples strictly inside the interval are considered
        in_interval = np.logical_and(t > start, t < stop)
        idx, = np.where(np.logical_and(in_interval, sign_flip))
        times.append(t[idx])
        indices.append(idx)
    return times, indices
# [docs]
def traces_by_trial(t, *args, start=None, end=None, separate=True):
    """
    Returns list of tuples of positions and velocity for samples between stimulus onset and
    feedback.

    :param t: numpy array of timestamps
    :param args: optional numpy arrays of the same length as timestamps, such as positions,
    velocities or accelerations
    :param start: start timestamp or array thereof; defaults to the first timestamp
    :param end: end timestamp or array thereof; defaults to the last timestamp
    :param separate: when True, the output is returned as tuples list of the form [(t, args[0],
    args[1], ...), ...], when False, the output is a list of n-by-m ndarrays where n = number of
    positional args + 1 and m = the number of samples within the interval
    :return: list of sliced arrays where length == len(start)
    """
    if start is None:
        start = t[0]
    if end is None:
        end = t[-1]
    # Accept scalar start/end values; previously the scalar defaults raised a TypeError on len()
    start = np.atleast_1d(start)
    end = np.atleast_1d(end)
    assert len(start) == len(end), 'number of start timestamps must equal end timestamps'
    traces = np.stack((t, *args))

    def to_mask(a, b):
        # Samples strictly within the open interval (a, b)
        return np.logical_and(t > a, t < b)

    cuts = [traces[:, to_mask(s, e)] for s, e in zip(start, end)]
    # Return all rows per cut; previously only the first two were returned, silently dropping
    # any positional args beyond the first
    return [tuple(c) for c in cuts] if separate else cuts
if __name__ == '__main__':
    # Run this module's doctests (e.g. the examples in cm_to_deg, cm_to_rad and samples_to_cm)
    import doctest
    doctest.testmod()