"""
A module for aligning the wheel motion with the rotary encoder. Currently used by the camera QC
in order to check timestamp alignment.
"""
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import RectangleSelector
import numpy as np
from scipy import signal, ndimage, interpolate
import cv2
from itertools import cycle
import matplotlib.animation as animation
import logging
from pathlib import Path
from joblib import Parallel, delayed, cpu_count
from ibldsp.utils import WindowGenerator
from one.api import ONE
import ibllib.io.video as vidio
from iblutil.util import Bunch
from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map
import ibllib.io.raw_data_loaders as raw
import ibllib.io.extractors.camera as cam
from ibllib.plots.snapshot import ReportSnapshot
import brainbox.video as video
import brainbox.behavior.wheel as wh
from brainbox.singlecell import bin_spikes
from brainbox.behavior.dlc import likelihood_threshold, get_speed
from brainbox.task.trials import find_trial_ids
import one.alf.io as alfio
from one.alf.exceptions import ALFObjectNotFound
from one.alf.spec import is_session_path, is_uuid_string
[docs]
def find_nearest(array, value):
array = np.asarray(array)
idx = (np.abs(array - value)).argmin()
return idx
[docs]
class MotionAlignment:
roi = {'left': ((800, 1020), (233, 1096)), 'right': ((426, 510), (104, 545)), 'body': ((402, 481), (31, 103))}
def __init__(self, eid=None, one=None, log=logging.getLogger(__name__), stream=False, **kwargs):
self.one = one or ONE()
self.eid = eid
self.session_path = kwargs.pop('session_path', None) or self.one.eid2path(eid)
self.ref = self.one.dict2ref(self.one.path2ref(self.session_path))
self.log = log
self.trials = self.wheel = self.camera_times = None
raw_cam_path = self.session_path.joinpath('raw_video_data')
camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*'))
if stream:
self.video_paths = vidio.url_from_eid(self.eid)
else:
self.video_paths = {vidio.label_from_path(x): x for x in camera_path}
self.data = Bunch()
self.alignment = Bunch()
[docs]
def align_all_trials(self, side='all'):
"""Align all wheel motion for all trials"""
if self.trials is None:
self.load_data()
if side == 'all':
side = self.video_paths.keys()
if not isinstance(side, str):
# Try to iterate over sides
[self.align_all_trials(s) for s in side]
if side not in self.video_paths:
raise ValueError(f'{side} camera video file not found')
# Align each trial sequentially
for i in np.arange(self.trials['intervals'].shape[0]):
self.align_motion(i, display=False)
[docs]
@staticmethod
def set_roi(video_path):
"""Manually set the ROIs for a given set of videos
TODO Improve docstring
TODO A method for setting ROIs by label
"""
frame = vidio.get_video_frame(str(video_path), 0)
def line_select_callback(eclick, erelease):
"""
Callback for line selection.
*eclick* and *erelease* are the press and release events.
"""
x1, y1 = eclick.xdata, eclick.ydata
x2, y2 = erelease.xdata, erelease.ydata
print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
return np.array([[x1, x2], [y1, y2]])
plt.imshow(frame)
roi = RectangleSelector(plt.gca(), line_select_callback, drawtype='box', useblit=True, button=[1, 3],
# don't use middle button
minspanx=5, minspany=5, spancoords='pixels', interactive=True)
plt.show()
((x1, x2, *_), (y1, *_, y2)) = roi.corners
col = np.arange(round(x1), round(x2), dtype=int)
row = np.arange(round(y1), round(y2), dtype=int)
return col, row
[docs]
def load_data(self, download=False):
"""
Load wheel, trial and camera timestamp data
:return: wheel, trials
"""
if download:
self.data.wheel = self.one.load_object(self.eid, 'wheel')
self.data.trials = self.one.load_object(self.eid, 'trials')
cam, det = self.one.load_datasets(self.eid, ['*Camera.times*'])
self.data.camera_times = {vidio.label_from_path(d['rel_path']): ts for ts, d in zip(cam, det)}
else:
alf_path = self.session_path / 'alf'
wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent
self.data.wheel = alfio.load_object(wheel_path, 'wheel', short_keys=True)
trials_path = next(alf_path.rglob('*trials.table*')).parent
self.data.trials = alfio.load_object(trials_path, 'trials')
self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x) for x in
alf_path.rglob('*Camera.times*')}
assert all(x is not None for x in self.data.values())
def _set_eid_or_path(self, session_path_or_eid):
"""Parse a given eID or session path
If a session UUID is given, resolves and stores the local path and vice versa
:param session_path_or_eid: A session eid or path
:return:
"""
self.eid = None
if is_uuid_string(str(session_path_or_eid)):
self.eid = session_path_or_eid
# Try to set session_path if data is found locally
self.session_path = self.one.eid2path(self.eid)
elif is_session_path(session_path_or_eid):
self.session_path = Path(session_path_or_eid)
if self.one is not None:
self.eid = self.one.path2eid(self.session_path)
if not self.eid:
self.log.warning('Failed to determine eID from session path')
else:
self.log.error('Cannot run alignment: an experiment uuid or session path is required')
raise ValueError("'session' must be a valid session path or uuid")
[docs]
def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False):
"""
Align video to the wheel using cross-correlation of the video motion signal and the rotary
encoder.
Parameters
----------
period : (float, float)
The time period over which to do the alignment.
side : {'left', 'right'}
With which camera to perform the alignment.
sd_thresh : float
For plotting where the motion energy goes above this standard deviation threshold.
display : bool
When true, displays the aligned wheel motion energy along with the rotary encoder
signal.
Returns
-------
int
Frame offset, i.e. by how many frames the video was shifted to match the rotary encoder
signal. Negative values mean the video was shifted backwards with respect to the wheel
timestamps.
float
The peak cross-correlation.
numpy.ndarray
The motion energy used in the cross-correlation, i.e. the frame difference for the
period given.
"""
# Get data samples within period
wheel = self.data['wheel']
self.alignment.label = side
self.alignment.to_mask = lambda ts: np.logical_and(ts >= period[0], ts <= period[1])
camera_times = self.data['camera_times'][side]
cam_mask = self.alignment.to_mask(camera_times)
frame_numbers, = np.where(cam_mask)
if frame_numbers.size == 0:
raise ValueError('No frames during given period')
# Motion Energy
camera_path = self.video_paths[side]
roi = (*[slice(*r) for r in self.roi[side]], 0)
try:
# TODO Add function arg to make grayscale
self.alignment.frames = vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi)
assert self.alignment.frames.size != 0
except AssertionError:
self.log.error('Failed to open video')
return None, None, None
self.alignment.df, stDev = video.motion_energy(self.alignment.frames, 2)
self.alignment.period = period # For plotting
# Calculate rotary encoder velocity trace
x = camera_times[cam_mask]
Fs = 1000
pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs)
v, _ = wh.velocity_filtered(pos, Fs)
interp_mask = self.alignment.to_mask(t)
# Convert to normalized speed
xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x])
vs = np.abs(v[interp_mask][xs])
vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs))
# FIXME This can be used as a goodness of fit measure
USE_CV2 = False
if USE_CV2:
# convert from numpy format to openCV format
dfCV = np.float32(self.alignment.df.reshape((-1, 1)))
reCV = np.float32(vs.reshape((-1, 1)))
# perform cross correlation
resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED)
# convert result back to numpy array
xcorr = np.asarray(resultCv)
else:
xcorr = signal.correlate(self.alignment.df, vs)
# Cross correlate wheel speed trace with the motion energy
CORRECTION = 2
self.alignment.c = max(xcorr)
self.alignment.xcorr = np.argmax(xcorr)
self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION
self.log.info(f'{side} camera, adjusted by {self.alignment.dt_i} frames')
if display:
# Plot the motion energy
fig, ax = plt.subplots(2, 1, sharex='all')
y = np.pad(self.alignment.df, 1, 'edge')
ax[0].plot(x, y, '-x', label='wheel motion energy')
thresh = stDev > sd_thresh
ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1, linewidth=0.5, linestyle=':',
label=f'>{sd_thresh} s.d. diff')
ax[1].plot(t[interp_mask], np.abs(v[interp_mask]))
# Plot other stuff
dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]])
fps = 1 / np.diff(camera_times).mean()
ax[0].plot(t[interp_mask][xs] - dt, vs, 'r-x', label='velocity (shifted)')
ax[0].set_title('normalized motion energy, %s camera, %.0f fps' % (side, fps))
ax[0].set_ylabel('rate of change (a.u.)')
ax[0].legend()
ax[1].set_ylabel('wheel speed (rad / s)')
ax[1].set_xlabel('Time (s)')
title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s'
fig.suptitle(title, fontsize=16)
fig.set_size_inches(19.2, 9.89)
return self.alignment.dt_i, self.alignment.c, self.alignment.df
[docs]
def plot_alignment(self, energy=True, save=False):
if not self.alignment:
self.log.error('No alignment data, run `align_motion` first')
return
# Change backend based on save flag
backend = matplotlib.get_backend().lower()
if (save and backend != 'agg') or (not save and backend == 'agg'):
new_backend = 'Agg' if save else 'Qt5Agg'
self.log.warning('Switching backend from %s to %s', backend, new_backend)
matplotlib.use(new_backend)
from matplotlib import pyplot as plt
# Main animated plots
fig, axes = plt.subplots(nrows=2)
title = f'{self.ref}' # ', from {period[0]:.1f}s - {period[1]:.1f}s'
fig.suptitle(title, fontsize=16)
wheel = self.data['wheel']
wheel_mask = self.alignment['to_mask'](wheel.timestamps)
ts = self.data['camera_times'][self.alignment['label']]
frame_numbers, = np.where(self.alignment['to_mask'](ts))
if energy:
self.alignment['frames'] = video.frame_diffs(self.alignment['frames'], 2)
frame_numbers = frame_numbers[1:-1]
data = {'frame_ids': frame_numbers}
def init_plot():
"""
Plot the wheel data for the current trial
:return: None
"""
data['im'] = axes[0].imshow(self.alignment['frames'][0])
axes[0].axis('off')
axes[0].set_title(f'adjusted by {self.alignment["dt_i"]} frames')
# Plot the wheel position
ax = axes[1]
ax.clear()
ax.plot(wheel.timestamps[wheel_mask], wheel.position[wheel_mask], '-x')
ts_0 = frame_numbers[0]
data['idx_0'] = ts_0 - self.alignment['dt_i']
ts_0 = ts[ts_0 + self.alignment['dt_i']]
data['ln'] = ax.axvline(x=ts_0, color='k')
ax.set_xlim([ts_0 - (3 / 2), ts_0 + (3 / 2)])
data['frame_num'] = 0
mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0)
data['marker'], = ax.plot(wheel.timestamps[wheel_mask][mkr], wheel.position[wheel_mask][mkr], 'r-x')
ax.set_ylabel('Wheel position (rad))')
ax.set_xlabel('Time (s))')
return
def animate(i):
"""
Callback for figure animation. Sets image data for current frame and moves pointer
along axis
:param i: unused; the current time step of the calling method
:return: None
"""
if i < 0:
data['frame_num'] -= 1
if data['frame_num'] < 0:
data['frame_num'] = len(self.alignment['frames']) - 1
else:
data['frame_num'] += 1
if data['frame_num'] >= len(self.alignment['frames']):
data['frame_num'] = 0
i = data['frame_num'] # NB: This is index for current trial's frame list
frame = self.alignment['frames'][i]
t_x = ts[data['idx_0'] + i]
data['ln'].set_xdata([t_x, t_x])
axes[1].set_xlim([t_x - (3 / 2), t_x + (3 / 2)])
data['im'].set_data(frame)
mkr = find_nearest(wheel.timestamps[wheel_mask], t_x)
data['marker'].set_data([wheel.timestamps[wheel_mask][mkr]], [wheel.position[wheel_mask][mkr]])
return data['im'], data['ln'], data['marker']
anim = animation.FuncAnimation(fig, animate, init_func=init_plot,
frames=(range(len(self.alignment.df)) if save else cycle(range(60))), interval=20,
blit=False, repeat=not save, cache_frame_data=False)
anim.running = False
def process_key(event):
"""
Callback for key presses.
:param event: a figure key_press_event
:return: None
"""
if event.key.isspace():
if anim.running:
anim.event_source.stop()
else:
anim.event_source.start()
anim.running = ~anim.running
elif event.key == 'right':
if anim.running:
anim.event_source.stop()
anim.running = False
animate(1)
fig.canvas.draw()
elif event.key == 'left':
if anim.running:
anim.event_source.stop()
anim.running = False
animate(-1)
fig.canvas.draw()
fig.canvas.mpl_connect('key_press_event', process_key)
# init_plot()
# while True:
# animate(0)
if save:
filename = '%s_%c.mp4' % (self.ref, self.alignment['label'][0])
if isinstance(save, (str, Path)):
filename = Path(save).joinpath(filename)
self.log.info(f'Saving to {filename}')
# Set up formatting for the movie files
Writer = animation.writers['ffmpeg']
writer = Writer(fps=24, metadata=dict(artist='Miles Wells'), bitrate=1800)
anim.save(str(filename), writer=writer)
else:
plt.show()
[docs]
class MotionAlignmentFullSession:
def __init__(self, session_path, label, **kwargs):
"""
Class to extract camera times using video motion energy wheel alignment
:param session_path: path of the session
:param label: video label, only 'left' and 'right' videos are supported
:param kwargs: threshold - the threshold to apply when identifying frames with artefacts (default 20)
upload - whether to upload summary figure to alyx (default False)
twin - the window length used when computing the shifts between the wheel and video
nprocesses - the number of CPU processes to use
sync - the type of sync scheme used (options 'nidq' or 'bpod')
location - whether the code is being run on SDSC or not (options 'SDSC' or None)
"""
self.session_path = session_path
self.label = label
self.threshold = kwargs.get('threshold', 20)
self.upload = kwargs.get('upload', False)
self.twin = kwargs.get('twin', 150)
self.nprocess = kwargs.get('nprocess', int(cpu_count() - cpu_count() / 4))
self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None))
self.roi, self.mask = self.get_roi_mask()
if self.upload:
self.one = ONE(mode='remote')
self.one.alyx.authenticate()
self.eid = self.one.path2eid(self.session_path)
[docs]
def load_data(self, sync='nidq', location=None):
"""
Loads relevant data from disk to perform motion alignment
:param sync: type of sync used, 'nidq' or 'bpod'
:param location: where the code is being run, if location='SDSC', the dataset uuids are removed
when loading the data
:return:
"""
def fix_keys(alf_object):
"""
Given an alf object removes the dataset uuid from the keys
:param alf_object:
:return:
"""
ob = Bunch()
for key in alf_object.keys():
vals = alf_object[key]
ob[key.split('.')[0]] = vals
return ob
alf_path = self.session_path.joinpath('alf')
wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent
wheel = (fix_keys(alfio.load_object(wheel_path, 'wheel')) if location == 'SDSC'
else alfio.load_object(wheel_path, 'wheel'))
self.wheel_timestamps = wheel.timestamps
# Compute interpolated wheel position and wheel times
wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1000)
# Compute wheel velocity
self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000)
# Load in original camera times
self.camera_times = alfio.load_file_content(next(alf_path.rglob(f'_ibl_{self.label}Camera.times*.npy')))
self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob(f'_iblrig_{self.label}Camera.raw*.mp4')))
self.camera_meta = vidio.get_video_meta(self.camera_path)
# TODO should read in the description file to get the correct sync location
if sync == 'nidq':
# If the sync is 'nidq' we read in the camera ttls from the spikeglx sync object
sync, chmap = get_sync_and_chn_map(self.session_path, sync_collection='raw_ephys_data')
sr = get_sync_fronts(sync, chmap[f'{self.label}_camera'])
self.ttls = sr.times[::2]
else:
# Otherwise we assume the sync is 'bpod' and we read in the camera ttls from the raw bpod data
cam_extractor = cam.CameraTimestampsBpod(session_path=self.session_path)
cam_extractor.bpod_trials = raw.load_data(self.session_path, task_collection='raw_behavior_data')
self.ttls = cam_extractor._times_from_bpod()
# Check if the ttl and video sizes match up
self.tdiff = self.ttls.size - self.camera_meta['length']
if self.tdiff < 0:
# In this case there are fewer ttls than camera frames. This is not ideal, for now we pad the ttls with
# nans but if this is too many we reject the wheel alignment based on the qc
self.ttl_times = self.ttls
self.times = np.r_[self.ttl_times, np.full((np.abs(self.tdiff)), np.nan)]
self.short_flag = True
elif self.tdiff > 0:
# In this case there are more ttls than camera frames. This happens often, for now we remove the first
# tdiff ttls from the ttls
self.ttl_times = self.ttls[self.tdiff:]
self.times = self.ttls[self.tdiff:]
self.short_flag = False
# Compute the frame rate of the camera
self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times)))
# We attempt to load in some behavior data (trials and dlc). This is only needed for the summary plots, having
# trial aligned paw velocity (from the dlc) is a nice sanity check to make sure the alignment went well
try:
self.trials = alfio.load_file_content(next(alf_path.rglob('_ibl_trials.table*.pqt')))
self.dlc = alfio.load_file_content(next(alf_path.rglob(f'_ibl_{self.label}Camera.dlc*.pqt')))
self.dlc = likelihood_threshold(self.dlc)
self.behavior = True
except (ALFObjectNotFound, StopIteration):
self.behavior = False
# Load in a single frame that we will use for the summary plot
self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0])
[docs]
def get_roi_mask(self):
"""
Compute the region of interest mask for a given camera. This corresponds to a box in the video that we will
use to compute the wheel motion energy
:return:
"""
if self.label == 'right':
roi = ((450, 512), (120, 200))
else:
roi = ((900, 1024), (850, 1010))
roi_mask = (*[slice(*r) for r in roi], 0)
return roi, roi_mask
[docs]
def find_contaminated_frames(self, video_frames, thresold=20, normalise=True):
"""
Finds frames in the video that have artefacts such as the mouse's paw or a human hand. In order to determine
frames with contamination an Otsu thresholding is applied to each frame to detect the artefact from the
background image
:param video_frames: np array of video frames (nframes, nwidth, nheight)
:param thresold: threshold to differentiate artefact from background
:param normalise: whether to normalise the threshold values for each frame to the baseline
:return: mask of frames that are contaminated
"""
high = np.zeros((video_frames.shape[0]))
# Iterate through each frame and compute and store the otsu threshold value for each frame
for idx, frame in enumerate(video_frames):
ret, _ = cv2.threshold(cv2.GaussianBlur(frame, (5, 5), 0), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
high[idx] = ret
# If normalise is True, we divide the threshold values for each frame by the minimum value
if normalise:
high -= np.min(high)
# Identify the frames that have a threshold value greater than the specified threshold cutoff
contaminated_frames = np.where(high > thresold)[0]
return contaminated_frames
[docs]
def compute_motion_energy(self, first, last, wg, iw):
"""
Computes the video motion energy for frame indexes between first and last. This function is written to be run
in a parallel fashion jusing joblib.parallel
:param first: first frame index of frame interval to consider
:param last: last frame index of frame interval to consider
:param wg: WindowGenerator
:param iw: iteration of the WindowGenerator
:return:
"""
if iw == wg.nwin - 1:
return
# Open the video and read in the relvant video frames between first idx and last idx
cap = cv2.VideoCapture(self.camera_path)
frames = vidio.get_video_frames_preload(cap, np.arange(first, last), mask=self.mask)
# Identify if any of the frames have artefacts in them
idx = self.find_contaminated_frames(frames, self.threshold)
# If some of the frames are contaminated we find all the continuous intervals of contamination
# and set the value for contaminated pixels for these frames to the average of the first frame before and after
# this contamination interval
if len(idx) != 0:
before_status = False
after_status = False
counter = 0
n_frames = 200
# If it is the first frame that is contaminated, we need to read in a bit more of the video to find a
# frame prior to contamination. We attempt this 20 times, after that we just take the value for the first
# frame
while np.any(idx == 0) and counter < 20 and iw != 0:
n_before_offset = (counter + 1) * n_frames
first -= n_frames
extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(first - n_frames, first),
mask=self.mask)
frames = np.concatenate([extra_frames, frames], axis=0)
idx = self.find_contaminated_frames(frames, self.threshold)
before_status = True
counter += 1
if counter > 0:
print(f'In before: {counter}')
counter = 0
# If it is the last frame that is contaminated, we need to read in a bit more of the video to find a
# frame after the contamination. We attempt this 20 times, after that we just take the value for the last
# frame
while np.any(idx == frames.shape[0] - 1) and counter < 20 and iw != wg.nwin - 1:
n_after_offset = (counter + 1) * n_frames
last += n_frames
extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(last, last + n_frames), mask=self.mask)
frames = np.concatenate([frames, extra_frames], axis=0)
idx = self.find_contaminated_frames(frames, self.threshold)
after_status = True
counter += 1
if counter > 0:
print(f'In after: {counter}')
# We find all the continuous intervals that contain contamination and fix the affected pixels
# by taking the average value of the frame prior and after contamination
intervals = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1)
for ints in intervals:
if len(ints) > 0 and ints[0] == 0:
ints = ints[1:]
if len(ints) > 0 and ints[-1] == frames.shape[0] - 1:
ints = ints[:-1]
th_all = np.zeros_like(frames[0])
# We find all affected pixels
for idx in ints:
img = np.copy(frames[idx])
blur = cv2.GaussianBlur(img, (5, 5), 0)
ret, th = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
th = cv2.GaussianBlur(th, (5, 5), 10)
th_all += th
# Compute the average image of the frame prior and after the interval
vals = np.mean(np.dstack([frames[ints[0] - 1], frames[ints[-1] + 1]]), axis=-1)
# For each frame set the affected pixels to the value of the clean average image
for idx in ints:
img = frames[idx]
img[th_all > 0] = vals[th_all > 0]
# If we have read in extra video frames we need to cut these off and make sure we only
# consider the frames between the interval first and last given as args
if before_status:
frames = frames[n_before_offset:]
if after_status:
frames = frames[:(-1 * n_after_offset)]
# Once the frames have been cleaned we compute the motion energy between frames
frame_me, _ = video.motion_energy(frames, diff=2, normalize=False)
cap.release()
return frame_me[2:]
[docs]
def compute_shifts(self, times, me, first, last, iw, wg):
"""
Compute the cross-correlation between the video motion energy and the wheel velocity to find the mismatch
between the camera ttls and the video frames. This function is written to run in a parallel manner using
joblib.parallel
:param times: the times of the video frames across the whole session (ttls)
:param me: the video motion energy computed across the whole session
:param first: first time idx to consider
:param last: last time idx to consider
:param wg: WindowGenerator
:param iw: iteration of the WindowGenerator
:return:
"""
# If we are in the last window we exit
if iw == wg.nwin - 1:
return np.nan, np.nan
# Find the time interval we are interested in
t_first = times[first]
t_last = times[last]
# If both times during this interval are nan exit
if np.isnan(t_last) and np.isnan(t_first):
return np.nan, np.nan
# If only the last time is nan, we find the last non nan time value
elif np.isnan(t_last):
t_last = times[np.where(~np.isnan(times))[0][-1]]
# Find the mask of timepoints that fall in this interval
mask = np.logical_and(times >= t_first, times <= t_last)
# Restrict the video motion energy to this interval and normalise the values
align_me = me[np.where(mask)[0]]
align_me = (align_me - np.nanmin(align_me)) / (np.nanmax(align_me) - np.nanmin(align_me))
# Find closest timepoints in wheel that match the time interval
wh_mask = np.logical_and(self.wheel_time >= t_first, self.wheel_time <= t_last)
if np.sum(wh_mask) == 0:
return np.nan, np.nan
# Find the mask for the wheel times
xs = np.searchsorted(self.wheel_time[wh_mask], times[mask])
xs[xs == np.sum(wh_mask)] = np.sum(wh_mask) - 1
# Convert to normalized speed
vs = np.abs(self.wheel_vel[wh_mask][xs])
vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs))
# Account for nan values in the video motion energy
isnan = np.isnan(align_me)
if np.sum(isnan) > 0:
where_nan = np.where(isnan)[0]
assert where_nan[0] == 0
assert where_nan[-1] == np.sum(isnan) - 1
if np.all(isnan):
return np.nan, np.nan
# Compute the cross correlation between the video motion energy and the wheel speed
xcorr = signal.correlate(align_me[~isnan], vs[~isnan])
# The max value of the cross correlation indicates the shift that needs to be applied
# The +2 comes from the fact that the video motion energy was computed from the difference between frames
shift = np.nanargmax(xcorr) - align_me[~isnan].size + 2
return shift, t_first + (t_last - t_first) / 2
[docs]
def clean_shifts(self, x, n=1):
"""
Removes artefacts from the computed shifts across time. We assume that the shifts should never increase
over time and that the jump between consecutive shifts shouldn't be greater than 1
:param x: computed shifts
:param n: condition to apply
:return:
"""
y = x.copy()
dy = np.diff(y, prepend=y[0])
while True:
pos = np.where(dy == 1)[0] if n == 1 else np.where(dy > 2)[0]
# added frames: this doesn't make sense and this is noise
if pos.size == 0:
break
neg = np.where(dy == -1)[0] if n == 1 else np.where(dy < -2)[0]
if len(pos) > len(neg):
neg = np.append(neg, dy.size - 1)
iss = np.minimum(np.searchsorted(neg, pos), neg.size - 1)
imin = np.argmin(np.minimum(np.abs(pos - neg[iss - 1]), np.abs(pos - neg[iss])))
idx = np.max([0, iss[imin] - 1])
ineg = neg[idx:iss[imin] + 1]
ineg = ineg[np.argmin(np.abs(pos[imin] - ineg))]
dy[pos[imin]] = 0
dy[ineg] = 0
return np.cumsum(dy) + y[0]
[docs]
def qc_shifts(self, shifts, shifts_filt):
"""
Compute qc values for the wheel alignment. We consider 4 things
1. The number of camera ttl values that are missing (when we have less ttls than video frames)
2. The number of shifts that have nan values, this means the video motion energy computation
3. The number of large jumps (>10) between the computed shifts
4. The number of jumps (>1) between the shifts after they have been cleaned
:param shifts: np.array of shifts over session
:param shifts_filt: np.array of shifts after being cleaned over session
:return:
"""
ttl_per = (np.abs(self.tdiff) / self.camera_meta['length']) * 100 if self.tdiff < 0 else 0
nan_per = (np.sum(np.isnan(shifts_filt)) / shifts_filt.size) * 100
shifts_sum = np.where(np.abs(np.diff(shifts)) > 10)[0].size
shifts_filt_sum = np.where(np.abs(np.diff(shifts_filt)) > 1)[0].size
qc = dict()
qc['ttl_per'] = ttl_per
qc['nan_per'] = nan_per
qc['shifts_sum'] = shifts_sum
qc['shifts_filt_sum'] = shifts_filt_sum
qc_outcome = True
# If more than 10% of ttls are missing we don't get new times
if ttl_per > 10:
qc_outcome = False
# If too many of the shifts are nans it means the alignment is not accurate
if nan_per > 40:
qc_outcome = False
# If there are too many artefacts could be errors
if shifts_sum > 60:
qc_outcome = False
# If there are jumps > 1 in the filtered shifts then there is a problem
if shifts_filt_sum > 0:
qc_outcome = False
return qc, qc_outcome
[docs]
@staticmethod
def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True,
norm=False, axs=None):
"""
Compute and plot trial aligned spike rasters and psth
:param spike_times: times of variable
:param events: trial times to align to
:param trial_idx: trial idx to sort by
:param dividers:
:param colors:
:param labels:
:param weights:
:param fr:
:param norm:
:param axs:
:return:
"""
pre_time = 0.4
post_time = 1
raster_bin = 0.01
psth_bin = 0.05
raster, t_raster = bin_spikes(
spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=raster_bin, weights=weights)
psth, t_psth = bin_spikes(
spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=psth_bin, weights=weights)
if fr:
psth = psth / psth_bin
if norm:
psth = psth - np.repeat(psth[:, 0][:, np.newaxis], psth.shape[1], axis=1)
raster = raster - np.repeat(raster[:, 0][:, np.newaxis], raster.shape[1], axis=1)
dividers = [0] + dividers + [len(trial_idx)]
if axs is None:
fig, axs = plt.subplots(2, 1, figsize=(4, 6), gridspec_kw={'height_ratios': [1, 3], 'hspace': 0}, sharex=True)
else:
fig = axs[0].get_figure()
label, lidx = np.unique(labels, return_index=True)
label_pos = []
for lab, lid in zip(label, lidx):
idx = np.where(np.array(labels) == lab)[0]
for iD in range(len(idx)):
if iD == 0:
t_ids = trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1]
t_ints = dividers[idx[iD] + 1] - dividers[idx[iD]]
else:
t_ids = np.r_[t_ids, trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1]]
t_ints = np.r_[t_ints, dividers[idx[iD] + 1] - dividers[idx[iD]]]
psth_div = np.nanmean(psth[t_ids], axis=0)
std_div = np.nanstd(psth[t_ids], axis=0) / np.sqrt(len(t_ids))
axs[0].fill_between(t_psth, psth_div - std_div, psth_div + std_div, alpha=0.4, color=colors[lid])
axs[0].plot(t_psth, psth_div, alpha=1, color=colors[lid])
lab_max = idx[np.argmax(t_ints)]
label_pos.append((dividers[lab_max + 1] - dividers[lab_max]) / 2 + dividers[lab_max])
axs[1].imshow(raster[trial_idx], cmap='binary', origin='lower',
extent=[np.min(t_raster), np.max(t_raster), 0, len(trial_idx)], aspect='auto')
width = raster_bin * 4
for iD in range(len(dividers) - 1):
axs[1].fill_between([post_time + raster_bin / 2, post_time + raster_bin / 2 + width],
[dividers[iD + 1], dividers[iD + 1]], [dividers[iD], dividers[iD]], color=colors[iD])
axs[1].set_xlim([-1 * pre_time, post_time + raster_bin / 2 + width])
secax = axs[1].secondary_yaxis('right')
secax.set_yticks(label_pos)
secax.set_yticklabels(label, rotation=90, rotation_mode='anchor', ha='center')
for ic, c in enumerate(np.array(colors)[lidx]):
secax.get_yticklabels()[ic].set_color(c)
axs[0].axvline(0, *axs[0].get_ylim(), c='k', ls='--', zorder=10) # TODO this doesn't always work
axs[1].axvline(0, *axs[1].get_ylim(), c='k', ls='--', zorder=10)
return fig, axs
[docs]
def plot_with_behavior(self):
"""
Makes a summary figure of the alignment when behaviour data is available
:return:
"""
self.dlc = likelihood_threshold(self.dlc)
trial_idx, dividers = find_trial_ids(self.trials, sort='side')
feature_ext = get_speed(self.dlc, self.camera_times, self.label, feature='paw_r')
feature_new = get_speed(self.dlc, self.new_times, self.label, feature='paw_r')
fig = plt.figure()
fig.set_size_inches(15, 9)
gs = gridspec.GridSpec(1, 5, figure=fig, width_ratios=[4, 1, 1, 1, 3], wspace=0.3, hspace=0.5)
gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0])
ax01 = fig.add_subplot(gs0[0, 0])
ax02 = fig.add_subplot(gs0[1, 0])
ax03 = fig.add_subplot(gs0[2, 0])
gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1], height_ratios=[1, 3])
ax11 = fig.add_subplot(gs1[0, 0])
ax12 = fig.add_subplot(gs1[1, 0])
gs2 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 2], height_ratios=[1, 3])
ax21 = fig.add_subplot(gs2[0, 0])
ax22 = fig.add_subplot(gs2[1, 0])
gs3 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 3], height_ratios=[1, 3])
ax31 = fig.add_subplot(gs3[0, 0])
ax32 = fig.add_subplot(gs3[1, 0])
gs4 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 4])
ax41 = fig.add_subplot(gs4[0, 0])
ax42 = fig.add_subplot(gs4[1, 0])
ax01.plot(self.t_shifts, self.shifts, label='shifts')
ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt')
ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10)
ax01.legend()
ax01.set_ylabel('Frames')
ax01.set_xlabel('Time in session')
xs = np.searchsorted(self.ttl_times, self.t_shifts)
ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps']
ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl')
ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10)
ax02.legend()
ax02.set_ylabel('Frames')
ax02.set_xlabel('Time in session')
ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new')
ax03.legend()
ax03.set_ylim(-5, 5)
ax03.set_ylabel('Frames')
ax03.set_xlabel('Time in session')
self.single_cluster_raster(self.wheel_timestamps, self.trials['firstMovement_times'].values, trial_idx, dividers,
['g', 'y'], ['left', 'right'], weights=self.wheel_vel, fr=False, axs=[ax11, ax12])
ax11.sharex(ax12)
ax11.set_ylabel('Wheel velocity')
ax11.set_title('Wheel')
ax12.set_xlabel('Time from first move')
self.single_cluster_raster(self.camera_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'],
['left', 'right'], weights=feature_ext, fr=False, axs=[ax21, ax22])
ax21.sharex(ax22)
ax21.set_ylabel('Paw r velocity')
ax21.set_title('Extracted times')
ax22.set_xlabel('Time from first move')
self.single_cluster_raster(self.new_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'],
['left', 'right'], weights=feature_new, fr=False, axs=[ax31, ax32])
ax31.sharex(ax32)
ax31.set_ylabel('Paw r velocity')
ax31.set_title('New times')
ax32.set_xlabel('Time from first move')
ax41.imshow(self.frame_example[0])
rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1],
self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none')
ax41.add_patch(rect)
ax42.plot(self.all_me)
return fig
[docs]
def plot_without_behavior(self):
"""
Makes a summary figure of the alignment when behaviour data is not available
:return:
"""
fig = plt.figure()
fig.set_size_inches(7, 7)
gs = gridspec.GridSpec(1, 2, figure=fig)
gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0])
ax01 = fig.add_subplot(gs0[0, 0])
ax02 = fig.add_subplot(gs0[1, 0])
ax03 = fig.add_subplot(gs0[2, 0])
gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1])
ax04 = fig.add_subplot(gs1[0, 0])
ax05 = fig.add_subplot(gs1[1, 0])
ax01.plot(self.t_shifts, self.shifts, label='shifts')
ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt')
ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10)
ax01.legend()
ax01.set_ylabel('Frames')
ax01.set_xlabel('Time in session')
xs = np.searchsorted(self.ttl_times, self.t_shifts)
ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps']
ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl')
ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10)
ax02.legend()
ax02.set_ylabel('Frames')
ax02.set_xlabel('Time in session')
ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new')
ax03.legend()
ax03.set_ylim(-5, 5)
ax03.set_ylabel('Frames')
ax03.set_xlabel('Time in session')
ax04.imshow(self.frame_example[0])
rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1],
self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none')
ax04.add_patch(rect)
ax05.plot(self.all_me)
return fig
[docs]
def process(self):
"""
Main function used to apply the video motion wheel alignment to the camera times. This function does the
following
1. Computes the video motion energy across the whole session (computed in windows and parallelised)
2. Computes the shift that should be applied to the camera times across the whole session by computing
the cross correlation between the video motion energy and the wheel speed (computed in
overlapping windows and parallelised)
3. Removes artefacts from the computed shifts
4. Computes the qc for the wheel alignment
5. Extracts the new camera times using the shifts computed from the video wheel alignment
6. If upload is True, creates a summary plot of the alignment and uploads the figure to the relevant session
on alyx
:return:
"""
# Compute the motion energy of the wheel for the whole video
wg = WindowGenerator(self.camera_meta['length'], 5000, 4)
out = Parallel(n_jobs=self.nprocess)(
delayed(self.compute_motion_energy)(first, last, wg, iw) for iw, (first, last) in enumerate(wg.firstlast))
# Concatenate the motion energy into one big array
self.all_me = np.array([])
for vals in out[:-1]:
self.all_me = np.r_[self.all_me, vals]
toverlap = self.twin - 1
all_me = np.r_[np.full((int(self.camera_meta['fps'] * toverlap)), np.nan), self.all_me]
to_app = self.times[0] - ((np.arange(int(self.camera_meta['fps'] * toverlap), ) + 1) / self.frate)[::-1]
times = np.r_[to_app, self.times]
wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), int(self.camera_meta['fps'] * toverlap))
out = Parallel(n_jobs=1)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg)
for iw, (first, last) in enumerate(wg.firstlast))
self.shifts = np.array([])
self.t_shifts = np.array([])
for vals in out[:-1]:
self.shifts = np.r_[self.shifts, vals[0]]
self.t_shifts = np.r_[self.t_shifts, vals[1]]
idx = np.bitwise_and(self.t_shifts >= self.ttl_times[0], self.t_shifts < self.ttl_times[-1])
self.shifts = self.shifts[idx]
self.t_shifts = self.t_shifts[idx]
shifts_filt = ndimage.percentile_filter(self.shifts, 80, 120)
shifts_filt = self.clean_shifts(shifts_filt, n=1)
self.shifts_filt = self.clean_shifts(shifts_filt, n=2)
self.qc, self.qc_outcome = self.qc_shifts(self.shifts, self.shifts_filt)
self.new_times = self.extract_times(self.shifts_filt, self.t_shifts)
if self.upload:
fig = self.plot_with_behavior() if self.behavior else self.plot_without_behavior()
save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', f'video_wheel_alignment_{self.label}.png'))
save_fig_path.parent.mkdir(exist_ok=True, parents=True)
fig.savefig(save_fig_path)
snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one)
snp.outputs = [save_fig_path]
snp.register_images(widths=['orig'])
plt.close(fig)
return self.new_times