Source code for ibllib.io.extractors.video_motion

"""
A module for aligning the wheel motion with the rotary encoder.  Currently used by the camera QC
in order to check timestamp alignment.
"""
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import RectangleSelector
import numpy as np
from scipy import signal, ndimage, interpolate
import cv2
from itertools import cycle
import matplotlib.animation as animation
import logging
from pathlib import Path
from joblib import Parallel, delayed, cpu_count

from ibldsp.utils import WindowGenerator
from one.api import ONE
import ibllib.io.video as vidio
from iblutil.util import Bunch
from ibllib.io.extractors.ephys_fpga import get_sync_fronts, get_sync_and_chn_map
import ibllib.io.raw_data_loaders as raw
import ibllib.io.extractors.camera as cam
from ibllib.plots.snapshot import ReportSnapshot
import brainbox.video as video
import brainbox.behavior.wheel as wh
from brainbox.singlecell import bin_spikes
from brainbox.behavior.dlc import likelihood_threshold, get_speed
from brainbox.task.trials import find_trial_ids
import one.alf.io as alfio
from one.alf.exceptions import ALFObjectNotFound
from one.alf.spec import is_session_path, is_uuid_string


[docs] def find_nearest(array, value): array = np.asarray(array) idx = (np.abs(array - value)).argmin() return idx
[docs] class MotionAlignment: roi = {'left': ((800, 1020), (233, 1096)), 'right': ((426, 510), (104, 545)), 'body': ((402, 481), (31, 103))} def __init__(self, eid=None, one=None, log=logging.getLogger(__name__), **kwargs): self.one = one or ONE() self.eid = eid self.session_path = kwargs.pop('session_path', None) or self.one.eid2path(eid) self.ref = self.one.dict2ref(self.one.path2ref(self.session_path)) self.log = log self.trials = self.wheel = self.camera_times = None raw_cam_path = self.session_path.joinpath('raw_video_data') camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*')) self.video_paths = {vidio.label_from_path(x): x for x in camera_path} self.data = Bunch() self.alignment = Bunch()
[docs] def align_all_trials(self, side='all'): """Align all wheel motion for all trials""" if self.trials is None: self.load_data() if side == 'all': side = self.video_paths.keys() if not isinstance(side, str): # Try to iterate over sides [self.align_all_trials(s) for s in side] if side not in self.video_paths: raise ValueError(f'{side} camera video file not found') # Align each trial sequentially for i in np.arange(self.trials['intervals'].shape[0]): self.align_motion(i, display=False)
[docs] @staticmethod def set_roi(video_path): """Manually set the ROIs for a given set of videos TODO Improve docstring TODO A method for setting ROIs by label """ frame = vidio.get_video_frame(str(video_path), 0) def line_select_callback(eclick, erelease): """ Callback for line selection. *eclick* and *erelease* are the press and release events. """ x1, y1 = eclick.xdata, eclick.ydata x2, y2 = erelease.xdata, erelease.ydata print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2)) return np.array([[x1, x2], [y1, y2]]) plt.imshow(frame) roi = RectangleSelector(plt.gca(), line_select_callback, drawtype='box', useblit=True, button=[1, 3], # don't use middle button minspanx=5, minspany=5, spancoords='pixels', interactive=True) plt.show() ((x1, x2, *_), (y1, *_, y2)) = roi.corners col = np.arange(round(x1), round(x2), dtype=int) row = np.arange(round(y1), round(y2), dtype=int) return col, row
[docs] def load_data(self, download=False): """ Load wheel, trial and camera timestamp data :return: wheel, trials """ if download: self.data.wheel = self.one.load_object(self.eid, 'wheel') self.data.trials = self.one.load_object(self.eid, 'trials') cam = self.one.load(self.eid, ['camera.times'], dclass_output=True) self.data.camera_times = {vidio.label_from_path(url): ts for ts, url in zip(cam.data, cam.url)} else: alf_path = self.session_path / 'alf' wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent self.data.wheel = alfio.load_object(wheel_path, 'wheel', short_keys=True) trials_path = next(alf_path.rglob('*trials.table*')).parent self.data.trials = alfio.load_object(trials_path, 'trials') self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x) for x in alf_path.rglob('*Camera.times*')} assert all(x is not None for x in self.data.values())
def _set_eid_or_path(self, session_path_or_eid): """Parse a given eID or session path If a session UUID is given, resolves and stores the local path and vice versa :param session_path_or_eid: A session eid or path :return: """ self.eid = None if is_uuid_string(str(session_path_or_eid)): self.eid = session_path_or_eid # Try to set session_path if data is found locally self.session_path = self.one.eid2path(self.eid) elif is_session_path(session_path_or_eid): self.session_path = Path(session_path_or_eid) if self.one is not None: self.eid = self.one.path2eid(self.session_path) if not self.eid: self.log.warning('Failed to determine eID from session path') else: self.log.error('Cannot run alignment: an experiment uuid or session path is required') raise ValueError("'session' must be a valid session path or uuid")
[docs] def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False): """ Align video to the wheel using cross-correlation of the video motion signal and the rotary encoder. Parameters ---------- period : (float, float) The time period over which to do the alignment. side : {'left', 'right'} With which camera to perform the alignment. sd_thresh : float For plotting where the motion energy goes above this standard deviation threshold. display : bool When true, displays the aligned wheel motion energy along with the rotary encoder signal. Returns ------- int Frame offset, i.e. by how many frames the video was shifted to match the rotary encoder signal. Negative values mean the video was shifted backwards with respect to the wheel timestamps. float The peak cross-correlation. numpy.ndarray The motion energy used in the cross-correlation, i.e. the frame difference for the period given. """ # Get data samples within period wheel = self.data['wheel'] self.alignment.label = side self.alignment.to_mask = lambda ts: np.logical_and(ts >= period[0], ts <= period[1]) camera_times = self.data['camera_times'][side] cam_mask = self.alignment.to_mask(camera_times) frame_numbers, = np.where(cam_mask) if frame_numbers.size == 0: raise ValueError('No frames during given period') # Motion Energy camera_path = self.video_paths[side] roi = (*[slice(*r) for r in self.roi[side]], 0) try: # TODO Add function arg to make grayscale self.alignment.frames = vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi) assert self.alignment.frames.size != 0 except AssertionError: self.log.error('Failed to open video') return None, None, None self.alignment.df, stDev = video.motion_energy(self.alignment.frames, 2) self.alignment.period = period # For plotting # Calculate rotary encoder velocity trace x = camera_times[cam_mask] Fs = 1000 pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs) v, _ = wh.velocity_filtered(pos, Fs) interp_mask = self.alignment.to_mask(t) # Convert to normalized speed xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x]) vs = np.abs(v[interp_mask][xs]) vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs)) # FIXME This can be used as a goodness of fit measure USE_CV2 = False if USE_CV2: # convert from numpy format to openCV format dfCV = np.float32(self.alignment.df.reshape((-1, 1))) reCV = np.float32(vs.reshape((-1, 1))) # perform cross correlation resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED) # convert result back to numpy array xcorr = np.asarray(resultCv) else: xcorr = signal.correlate(self.alignment.df, vs) # Cross correlate wheel speed trace with the motion energy CORRECTION = 2 self.alignment.c = max(xcorr) self.alignment.xcorr = np.argmax(xcorr) self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION self.log.info(f'{side} camera, adjusted by {self.alignment.dt_i} frames') if display: # Plot the motion energy fig, ax = plt.subplots(2, 1, sharex='all') y = np.pad(self.alignment.df, 1, 'edge') ax[0].plot(x, y, '-x', label='wheel motion energy') thresh = stDev > sd_thresh ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1, linewidth=0.5, linestyle=':', label=f'>{sd_thresh} s.d. diff') ax[1].plot(t[interp_mask], np.abs(v[interp_mask])) # Plot other stuff dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]]) fps = 1 / np.diff(camera_times).mean() ax[0].plot(t[interp_mask][xs] - dt, vs, 'r-x', label='velocity (shifted)') ax[0].set_title('normalized motion energy, %s camera, %.0f fps' % (side, fps)) ax[0].set_ylabel('rate of change (a.u.)') ax[0].legend() ax[1].set_ylabel('wheel speed (rad / s)') ax[1].set_xlabel('Time (s)') title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s' fig.suptitle(title, fontsize=16) fig.set_size_inches(19.2, 9.89) return self.alignment.dt_i, self.alignment.c, self.alignment.df
[docs] def plot_alignment(self, energy=True, save=False): if not self.alignment: self.log.error('No alignment data, run `align_motion` first') return # Change backend based on save flag backend = matplotlib.get_backend().lower() if (save and backend != 'agg') or (not save and backend == 'agg'): new_backend = 'Agg' if save else 'Qt5Agg' self.log.warning('Switching backend from %s to %s', backend, new_backend) matplotlib.use(new_backend) from matplotlib import pyplot as plt # Main animated plots fig, axes = plt.subplots(nrows=2) title = f'{self.ref}' # ', from {period[0]:.1f}s - {period[1]:.1f}s' fig.suptitle(title, fontsize=16) wheel = self.data['wheel'] wheel_mask = self.alignment['to_mask'](wheel.timestamps) ts = self.data['camera_times'][self.alignment['label']] frame_numbers, = np.where(self.alignment['to_mask'](ts)) if energy: self.alignment['frames'] = video.frame_diffs(self.alignment['frames'], 2) frame_numbers = frame_numbers[1:-1] data = {'frame_ids': frame_numbers} def init_plot(): """ Plot the wheel data for the current trial :return: None """ data['im'] = axes[0].imshow(self.alignment['frames'][0]) axes[0].axis('off') axes[0].set_title(f'adjusted by {self.alignment["dt_i"]} frames') # Plot the wheel position ax = axes[1] ax.clear() ax.plot(wheel.timestamps[wheel_mask], wheel.position[wheel_mask], '-x') ts_0 = frame_numbers[0] data['idx_0'] = ts_0 - self.alignment['dt_i'] ts_0 = ts[ts_0 + self.alignment['dt_i']] data['ln'] = ax.axvline(x=ts_0, color='k') ax.set_xlim([ts_0 - (3 / 2), ts_0 + (3 / 2)]) data['frame_num'] = 0 mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0) data['marker'], = ax.plot(wheel.timestamps[wheel_mask][mkr], wheel.position[wheel_mask][mkr], 'r-x') ax.set_ylabel('Wheel position (rad))') ax.set_xlabel('Time (s))') return def animate(i): """ Callback for figure animation. Sets image data for current frame and moves pointer along axis :param i: unused; the current time step of the calling method :return: None """ if i < 0: data['frame_num'] -= 1 if data['frame_num'] < 0: data['frame_num'] = len(self.alignment['frames']) - 1 else: data['frame_num'] += 1 if data['frame_num'] >= len(self.alignment['frames']): data['frame_num'] = 0 i = data['frame_num'] # NB: This is index for current trial's frame list frame = self.alignment['frames'][i] t_x = ts[data['idx_0'] + i] data['ln'].set_xdata([t_x, t_x]) axes[1].set_xlim([t_x - (3 / 2), t_x + (3 / 2)]) data['im'].set_data(frame) mkr = find_nearest(wheel.timestamps[wheel_mask], t_x) data['marker'].set_data(wheel.timestamps[wheel_mask][mkr], wheel.position[wheel_mask][mkr]) return data['im'], data['ln'], data['marker'] anim = animation.FuncAnimation(fig, animate, init_func=init_plot, frames=(range(len(self.alignment.df)) if save else cycle(range(60))), interval=20, blit=False, repeat=not save, cache_frame_data=False) anim.running = False def process_key(event): """ Callback for key presses. :param event: a figure key_press_event :return: None """ if event.key.isspace(): if anim.running: anim.event_source.stop() else: anim.event_source.start() anim.running = ~anim.running elif event.key == 'right': if anim.running: anim.event_source.stop() anim.running = False animate(1) fig.canvas.draw() elif event.key == 'left': if anim.running: anim.event_source.stop() anim.running = False animate(-1) fig.canvas.draw() fig.canvas.mpl_connect('key_press_event', process_key) # init_plot() # while True: # animate(0) if save: filename = '%s_%c.mp4' % (self.ref, self.alignment['label'][0]) if isinstance(save, (str, Path)): filename = Path(save).joinpath(filename) self.log.info(f'Saving to {filename}') # Set up formatting for the movie files Writer = animation.writers['ffmpeg'] writer = Writer(fps=24, metadata=dict(artist='Miles Wells'), bitrate=1800) anim.save(str(filename), writer=writer) else: plt.show()
[docs] class MotionAlignmentFullSession: def __init__(self, session_path, label, **kwargs): """ Class to extract camera times using video motion energy wheel alignment :param session_path: path of the session :param label: video label, only 'left' and 'right' videos are supported :param kwargs: threshold - the threshold to apply when identifying frames with artefacts (default 20) upload - whether to upload summary figure to alyx (default False) twin - the window length used when computing the shifts between the wheel and video nprocesses - the number of CPU processes to use sync - the type of sync scheme used (options 'nidq' or 'bpod') location - whether the code is being run on SDSC or not (options 'SDSC' or None) """ self.session_path = session_path self.label = label self.threshold = kwargs.get('threshold', 20) self.upload = kwargs.get('upload', False) self.twin = kwargs.get('twin', 150) self.nprocess = kwargs.get('nprocess', int(cpu_count() - cpu_count() / 4)) self.load_data(sync=kwargs.get('sync', 'nidq'), location=kwargs.get('location', None)) self.roi, self.mask = self.get_roi_mask() if self.upload: self.one = ONE(mode='remote') self.one.alyx.authenticate() self.eid = self.one.path2eid(self.session_path)
[docs] def load_data(self, sync='nidq', location=None): """ Loads relevant data from disk to perform motion alignment :param sync: type of sync used, 'nidq' or 'bpod' :param location: where the code is being run, if location='SDSC', the dataset uuids are removed when loading the data :return: """ def fix_keys(alf_object): """ Given an alf object removes the dataset uuid from the keys :param alf_object: :return: """ ob = Bunch() for key in alf_object.keys(): vals = alf_object[key] ob[key.split('.')[0]] = vals return ob alf_path = self.session_path.joinpath('alf') wheel_path = next(alf_path.rglob('*wheel.timestamps*')).parent wheel = (fix_keys(alfio.load_object(wheel_path, 'wheel')) if location == 'SDSC' else alfio.load_object(wheel_path, 'wheel')) self.wheel_timestamps = wheel.timestamps # Compute interpolated wheel position and wheel times wheel_pos, self.wheel_time = wh.interpolate_position(wheel.timestamps, wheel.position, freq=1000) # Compute wheel velocity self.wheel_vel, _ = wh.velocity_filtered(wheel_pos, 1000) # Load in original camera times self.camera_times = alfio.load_file_content(next(alf_path.rglob(f'_ibl_{self.label}Camera.times*.npy'))) self.camera_path = str(next(self.session_path.joinpath('raw_video_data').glob(f'_iblrig_{self.label}Camera.raw*.mp4'))) self.camera_meta = vidio.get_video_meta(self.camera_path) # TODO should read in the description file to get the correct sync location if sync == 'nidq': # If the sync is 'nidq' we read in the camera ttls from the spikeglx sync object sync, chmap = get_sync_and_chn_map(self.session_path, sync_collection='raw_ephys_data') sr = get_sync_fronts(sync, chmap[f'{self.label}_camera']) self.ttls = sr.times[::2] else: # Otherwise we assume the sync is 'bpod' and we read in the camera ttls from the raw bpod data cam_extractor = cam.CameraTimestampsBpod(session_path=self.session_path) cam_extractor.bpod_trials = raw.load_data(self.session_path, task_collection='raw_behavior_data') self.ttls = cam_extractor._times_from_bpod() # Check if the ttl and video sizes match up self.tdiff = self.ttls.size - self.camera_meta['length'] if self.tdiff < 0: # In this case there are fewer ttls than camera frames. This is not ideal, for now we pad the ttls with # nans but if this is too many we reject the wheel alignment based on the qc self.ttl_times = self.ttls self.times = np.r_[self.ttl_times, np.full((np.abs(self.tdiff)), np.nan)] self.short_flag = True elif self.tdiff > 0: # In this case there are more ttls than camera frames. This happens often, for now we remove the first # tdiff ttls from the ttls self.ttl_times = self.ttls[self.tdiff:] self.times = self.ttls[self.tdiff:] self.short_flag = False # Compute the frame rate of the camera self.frate = round(1 / np.nanmedian(np.diff(self.ttl_times))) # We attempt to load in some behavior data (trials and dlc). This is only needed for the summary plots, having # trial aligned paw velocity (from the dlc) is a nice sanity check to make sure the alignment went well try: self.trials = alfio.load_file_content(next(alf_path.rglob('_ibl_trials.table*.pqt'))) self.dlc = alfio.load_file_content(next(alf_path.rglob(f'_ibl_{self.label}Camera.dlc*.pqt'))) self.dlc = likelihood_threshold(self.dlc) self.behavior = True except (ALFObjectNotFound, StopIteration): self.behavior = False # Load in a single frame that we will use for the summary plot self.frame_example = vidio.get_video_frames_preload(self.camera_path, np.arange(10, 11), mask=np.s_[:, :, 0])
[docs] def get_roi_mask(self): """ Compute the region of interest mask for a given camera. This corresponds to a box in the video that we will use to compute the wheel motion energy :return: """ if self.label == 'right': roi = ((450, 512), (120, 200)) else: roi = ((900, 1024), (850, 1010)) roi_mask = (*[slice(*r) for r in roi], 0) return roi, roi_mask
[docs] def find_contaminated_frames(self, video_frames, thresold=20, normalise=True): """ Finds frames in the video that have artefacts such as the mouse's paw or a human hand. In order to determine frames with contamination an Otsu thresholding is applied to each frame to detect the artefact from the background image :param video_frames: np array of video frames (nframes, nwidth, nheight) :param thresold: threshold to differentiate artefact from background :param normalise: whether to normalise the threshold values for each frame to the baseline :return: mask of frames that are contaminated """ high = np.zeros((video_frames.shape[0])) # Iterate through each frame and compute and store the otsu threshold value for each frame for idx, frame in enumerate(video_frames): ret, _ = cv2.threshold(cv2.GaussianBlur(frame, (5, 5), 0), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) high[idx] = ret # If normalise is True, we divide the threshold values for each frame by the minimum value if normalise: high -= np.min(high) # Identify the frames that have a threshold value greater than the specified threshold cutoff contaminated_frames = np.where(high > thresold)[0] return contaminated_frames
[docs] def compute_motion_energy(self, first, last, wg, iw): """ Computes the video motion energy for frame indexes between first and last. This function is written to be run in a parallel fashion jusing joblib.parallel :param first: first frame index of frame interval to consider :param last: last frame index of frame interval to consider :param wg: WindowGenerator :param iw: iteration of the WindowGenerator :return: """ if iw == wg.nwin - 1: return # Open the video and read in the relvant video frames between first idx and last idx cap = cv2.VideoCapture(self.camera_path) frames = vidio.get_video_frames_preload(cap, np.arange(first, last), mask=self.mask) # Identify if any of the frames have artefacts in them idx = self.find_contaminated_frames(frames, self.threshold) # If some of the frames are contaminated we find all the continuous intervals of contamination # and set the value for contaminated pixels for these frames to the average of the first frame before and after # this contamination interval if len(idx) != 0: before_status = False after_status = False counter = 0 n_frames = 200 # If it is the first frame that is contaminated, we need to read in a bit more of the video to find a # frame prior to contamination. We attempt this 20 times, after that we just take the value for the first # frame while np.any(idx == 0) and counter < 20 and iw != 0: n_before_offset = (counter + 1) * n_frames first -= n_frames extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(first - n_frames, first), mask=self.mask) frames = np.concatenate([extra_frames, frames], axis=0) idx = self.find_contaminated_frames(frames, self.threshold) before_status = True counter += 1 if counter > 0: print(f'In before: {counter}') counter = 0 # If it is the last frame that is contaminated, we need to read in a bit more of the video to find a # frame after the contamination. We attempt this 20 times, after that we just take the value for the last # frame while np.any(idx == frames.shape[0] - 1) and counter < 20 and iw != wg.nwin - 1: n_after_offset = (counter + 1) * n_frames last += n_frames extra_frames = vidio.get_video_frames_preload(cap, frame_numbers=np.arange(last, last + n_frames), mask=self.mask) frames = np.concatenate([frames, extra_frames], axis=0) idx = self.find_contaminated_frames(frames, self.threshold) after_status = True counter += 1 if counter > 0: print(f'In after: {counter}') # We find all the continuous intervals that contain contamination and fix the affected pixels # by taking the average value of the frame prior and after contamination intervals = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1) for ints in intervals: if len(ints) > 0 and ints[0] == 0: ints = ints[1:] if len(ints) > 0 and ints[-1] == frames.shape[0] - 1: ints = ints[:-1] th_all = np.zeros_like(frames[0]) # We find all affected pixels for idx in ints: img = np.copy(frames[idx]) blur = cv2.GaussianBlur(img, (5, 5), 0) ret, th = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) th = cv2.GaussianBlur(th, (5, 5), 10) th_all += th # Compute the average image of the frame prior and after the interval vals = np.mean(np.dstack([frames[ints[0] - 1], frames[ints[-1] + 1]]), axis=-1) # For each frame set the affected pixels to the value of the clean average image for idx in ints: img = frames[idx] img[th_all > 0] = vals[th_all > 0] # If we have read in extra video frames we need to cut these off and make sure we only # consider the frames between the interval first and last given as args if before_status: frames = frames[n_before_offset:] if after_status: frames = frames[:(-1 * n_after_offset)] # Once the frames have been cleaned we compute the motion energy between frames frame_me, _ = video.motion_energy(frames, diff=2, normalize=False) cap.release() return frame_me[2:]
[docs] def compute_shifts(self, times, me, first, last, iw, wg): """ Compute the cross-correlation between the video motion energy and the wheel velocity to find the mismatch between the camera ttls and the video frames. This function is written to run in a parallel manner using joblib.parallel :param times: the times of the video frames across the whole session (ttls) :param me: the video motion energy computed across the whole session :param first: first time idx to consider :param last: last time idx to consider :param wg: WindowGenerator :param iw: iteration of the WindowGenerator :return: """ # If we are in the last window we exit if iw == wg.nwin - 1: return np.nan, np.nan # Find the time interval we are interested in t_first = times[first] t_last = times[last] # If both times during this interval are nan exit if np.isnan(t_last) and np.isnan(t_first): return np.nan, np.nan # If only the last time is nan, we find the last non nan time value elif np.isnan(t_last): t_last = times[np.where(~np.isnan(times))[0][-1]] # Find the mask of timepoints that fall in this interval mask = np.logical_and(times >= t_first, times <= t_last) # Restrict the video motion energy to this interval and normalise the values align_me = me[np.where(mask)[0]] align_me = (align_me - np.nanmin(align_me)) / (np.nanmax(align_me) - np.nanmin(align_me)) # Find closest timepoints in wheel that match the time interval wh_mask = np.logical_and(self.wheel_time >= t_first, self.wheel_time <= t_last) if np.sum(wh_mask) == 0: return np.nan, np.nan # Find the mask for the wheel times xs = np.searchsorted(self.wheel_time[wh_mask], times[mask]) xs[xs == np.sum(wh_mask)] = np.sum(wh_mask) - 1 # Convert to normalized speed vs = np.abs(self.wheel_vel[wh_mask][xs]) vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs)) # Account for nan values in the video motion energy isnan = np.isnan(align_me) if np.sum(isnan) > 0: where_nan = np.where(isnan)[0] assert where_nan[0] == 0 assert where_nan[-1] == np.sum(isnan) - 1 if np.all(isnan): return np.nan, np.nan # Compute the cross correlation between the video motion energy and the wheel speed xcorr = signal.correlate(align_me[~isnan], vs[~isnan]) # The max value of the cross correlation indicates the shift that needs to be applied # The +2 comes from the fact that the video motion energy was computed from the difference between frames shift = np.nanargmax(xcorr) - align_me[~isnan].size + 2 return shift, t_first + (t_last - t_first) / 2
[docs] def clean_shifts(self, x, n=1): """ Removes artefacts from the computed shifts across time. We assume that the shifts should never increase over time and that the jump between consecutive shifts shouldn't be greater than 1 :param x: computed shifts :param n: condition to apply :return: """ y = x.copy() dy = np.diff(y, prepend=y[0]) while True: pos = np.where(dy == 1)[0] if n == 1 else np.where(dy > 2)[0] # added frames: this doesn't make sense and this is noise if pos.size == 0: break neg = np.where(dy == -1)[0] if n == 1 else np.where(dy < -2)[0] if len(pos) > len(neg): neg = np.append(neg, dy.size - 1) iss = np.minimum(np.searchsorted(neg, pos), neg.size - 1) imin = np.argmin(np.minimum(np.abs(pos - neg[iss - 1]), np.abs(pos - neg[iss]))) idx = np.max([0, iss[imin] - 1]) ineg = neg[idx:iss[imin] + 1] ineg = ineg[np.argmin(np.abs(pos[imin] - ineg))] dy[pos[imin]] = 0 dy[ineg] = 0 return np.cumsum(dy) + y[0]
[docs] def qc_shifts(self, shifts, shifts_filt): """ Compute qc values for the wheel alignment. We consider 4 things 1. The number of camera ttl values that are missing (when we have less ttls than video frames) 2. The number of shifts that have nan values, this means the video motion energy computation 3. The number of large jumps (>10) between the computed shifts 4. The number of jumps (>1) between the shifts after they have been cleaned :param shifts: np.array of shifts over session :param shifts_filt: np.array of shifts after being cleaned over session :return: """ ttl_per = (np.abs(self.tdiff) / self.camera_meta['length']) * 100 if self.tdiff < 0 else 0 nan_per = (np.sum(np.isnan(shifts_filt)) / shifts_filt.size) * 100 shifts_sum = np.where(np.abs(np.diff(shifts)) > 10)[0].size shifts_filt_sum = np.where(np.abs(np.diff(shifts_filt)) > 1)[0].size qc = dict() qc['ttl_per'] = ttl_per qc['nan_per'] = nan_per qc['shifts_sum'] = shifts_sum qc['shifts_filt_sum'] = shifts_filt_sum qc_outcome = True # If more than 10% of ttls are missing we don't get new times if ttl_per > 10: qc_outcome = False # If too many of the shifts are nans it means the alignment is not accurate if nan_per > 40: qc_outcome = False # If there are too many artefacts could be errors if shifts_sum > 60: qc_outcome = False # If there are jumps > 1 in the filtered shifts then there is a problem if shifts_filt_sum > 0: qc_outcome = False return qc, qc_outcome
[docs] def extract_times(self, shifts_filt, t_shifts): """ Extracts new camera times after applying the computed shifts across the session :param shifts_filt: filtered shifts computed across session :param t_shifts: time point of computed shifts :return: """ # Compute the interpolation function to apply to the ttl times t_new = t_shifts - (shifts_filt * 1 / self.frate) fcn = interpolate.interp1d(t_shifts, t_new, fill_value="extrapolate") # Apply the function and get out new times new_times = fcn(self.ttl_times) # If we are missing ttls then interpolate and append the correct number at the end if self.tdiff < 0: to_app = (np.arange(np.abs(self.tdiff), ) + 1) / self.frate + new_times[-1] new_times = np.r_[new_times, to_app] return new_times
[docs] @staticmethod def single_cluster_raster(spike_times, events, trial_idx, dividers, colors, labels, weights=None, fr=True, norm=False, axs=None): """ Compute and plot trial aligned spike rasters and psth :param spike_times: times of variable :param events: trial times to align to :param trial_idx: trial idx to sort by :param dividers: :param colors: :param labels: :param weights: :param fr: :param norm: :param axs: :return: """ pre_time = 0.4 post_time = 1 raster_bin = 0.01 psth_bin = 0.05 raster, t_raster = bin_spikes( spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=raster_bin, weights=weights) psth, t_psth = bin_spikes( spike_times, events, pre_time=pre_time, post_time=post_time, bin_size=psth_bin, weights=weights) if fr: psth = psth / psth_bin if norm: psth = psth - np.repeat(psth[:, 0][:, np.newaxis], psth.shape[1], axis=1) raster = raster - np.repeat(raster[:, 0][:, np.newaxis], raster.shape[1], axis=1) dividers = [0] + dividers + [len(trial_idx)] if axs is None: fig, axs = plt.subplots(2, 1, figsize=(4, 6), gridspec_kw={'height_ratios': [1, 3], 'hspace': 0}, sharex=True) else: fig = axs[0].get_figure() label, lidx = np.unique(labels, return_index=True) label_pos = [] for lab, lid in zip(label, lidx): idx = np.where(np.array(labels) == lab)[0] for iD in range(len(idx)): if iD == 0: t_ids = trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1] t_ints = dividers[idx[iD] + 1] - dividers[idx[iD]] else: t_ids = np.r_[t_ids, trial_idx[dividers[idx[iD]] + 1:dividers[idx[iD] + 1] + 1]] t_ints = np.r_[t_ints, dividers[idx[iD] + 1] - dividers[idx[iD]]] psth_div = np.nanmean(psth[t_ids], axis=0) std_div = np.nanstd(psth[t_ids], axis=0) / np.sqrt(len(t_ids)) axs[0].fill_between(t_psth, psth_div - std_div, psth_div + std_div, alpha=0.4, color=colors[lid]) axs[0].plot(t_psth, psth_div, alpha=1, color=colors[lid]) lab_max = idx[np.argmax(t_ints)] label_pos.append((dividers[lab_max + 1] - dividers[lab_max]) / 2 + dividers[lab_max]) axs[1].imshow(raster[trial_idx], cmap='binary', origin='lower', extent=[np.min(t_raster), np.max(t_raster), 0, len(trial_idx)], aspect='auto') width = raster_bin * 4 for iD in range(len(dividers) - 1): axs[1].fill_between([post_time + raster_bin / 2, post_time + raster_bin / 2 + width], [dividers[iD + 1], dividers[iD + 1]], [dividers[iD], dividers[iD]], color=colors[iD]) axs[1].set_xlim([-1 * pre_time, post_time + raster_bin / 2 + width]) secax = axs[1].secondary_yaxis('right') secax.set_yticks(label_pos) secax.set_yticklabels(label, rotation=90, rotation_mode='anchor', ha='center') for ic, c in enumerate(np.array(colors)[lidx]): secax.get_yticklabels()[ic].set_color(c) axs[0].axvline(0, *axs[0].get_ylim(), c='k', ls='--', zorder=10) # TODO this doesn't always work axs[1].axvline(0, *axs[1].get_ylim(), c='k', ls='--', zorder=10) return fig, axs
[docs] def plot_with_behavior(self): """ Makes a summary figure of the alignment when behaviour data is available :return: """ self.dlc = likelihood_threshold(self.dlc) trial_idx, dividers = find_trial_ids(self.trials, sort='side') feature_ext = get_speed(self.dlc, self.camera_times, self.label, feature='paw_r') feature_new = get_speed(self.dlc, self.new_times, self.label, feature='paw_r') fig = plt.figure() fig.set_size_inches(15, 9) gs = gridspec.GridSpec(1, 5, figure=fig, width_ratios=[4, 1, 1, 1, 3], wspace=0.3, hspace=0.5) gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0]) ax01 = fig.add_subplot(gs0[0, 0]) ax02 = fig.add_subplot(gs0[1, 0]) ax03 = fig.add_subplot(gs0[2, 0]) gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1], height_ratios=[1, 3]) ax11 = fig.add_subplot(gs1[0, 0]) ax12 = fig.add_subplot(gs1[1, 0]) gs2 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 2], height_ratios=[1, 3]) ax21 = fig.add_subplot(gs2[0, 0]) ax22 = fig.add_subplot(gs2[1, 0]) gs3 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 3], height_ratios=[1, 3]) ax31 = fig.add_subplot(gs3[0, 0]) ax32 = fig.add_subplot(gs3[1, 0]) gs4 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 4]) ax41 = fig.add_subplot(gs4[0, 0]) ax42 = fig.add_subplot(gs4[1, 0]) ax01.plot(self.t_shifts, self.shifts, label='shifts') ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt') ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10) ax01.legend() ax01.set_ylabel('Frames') ax01.set_xlabel('Time in session') xs = np.searchsorted(self.ttl_times, self.t_shifts) ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps'] ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl') ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10) ax02.legend() ax02.set_ylabel('Frames') ax02.set_xlabel('Time in session') ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new') ax03.legend() ax03.set_ylim(-5, 5) ax03.set_ylabel('Frames') ax03.set_xlabel('Time in session') self.single_cluster_raster(self.wheel_timestamps, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'], ['left', 'right'], weights=self.wheel_vel, fr=False, axs=[ax11, ax12]) ax11.sharex(ax12) ax11.set_ylabel('Wheel velocity') ax11.set_title('Wheel') ax12.set_xlabel('Time from first move') self.single_cluster_raster(self.camera_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'], ['left', 'right'], weights=feature_ext, fr=False, axs=[ax21, ax22]) ax21.sharex(ax22) ax21.set_ylabel('Paw r velocity') ax21.set_title('Extracted times') ax22.set_xlabel('Time from first move') self.single_cluster_raster(self.new_times, self.trials['firstMovement_times'].values, trial_idx, dividers, ['g', 'y'], ['left', 'right'], weights=feature_new, fr=False, axs=[ax31, ax32]) ax31.sharex(ax32) ax31.set_ylabel('Paw r velocity') ax31.set_title('New times') ax32.set_xlabel('Time from first move') ax41.imshow(self.frame_example[0]) rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1], self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none') ax41.add_patch(rect) ax42.plot(self.all_me) return fig
[docs] def plot_without_behavior(self): """ Makes a summary figure of the alignment when behaviour data is not available :return: """ fig = plt.figure() fig.set_size_inches(7, 7) gs = gridspec.GridSpec(1, 2, figure=fig) gs0 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0]) ax01 = fig.add_subplot(gs0[0, 0]) ax02 = fig.add_subplot(gs0[1, 0]) ax03 = fig.add_subplot(gs0[2, 0]) gs1 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 1]) ax04 = fig.add_subplot(gs1[0, 0]) ax05 = fig.add_subplot(gs1[1, 0]) ax01.plot(self.t_shifts, self.shifts, label='shifts') ax01.plot(self.t_shifts, self.shifts_filt, label='shifts_filt') ax01.set_ylim(np.min(self.shifts_filt) - 10, np.max(self.shifts_filt) + 10) ax01.legend() ax01.set_ylabel('Frames') ax01.set_xlabel('Time in session') xs = np.searchsorted(self.ttl_times, self.t_shifts) ttl_diff = (self.times - self.camera_times)[xs] * self.camera_meta['fps'] ax02.plot(self.t_shifts, ttl_diff, label='extracted - ttl') ax02.set_ylim(np.min(ttl_diff) - 10, np.max(ttl_diff) + 10) ax02.legend() ax02.set_ylabel('Frames') ax02.set_xlabel('Time in session') ax03.plot(self.camera_times, (self.camera_times - self.new_times) * self.camera_meta['fps'], 'k', label='extracted - new') ax03.legend() ax03.set_ylim(-5, 5) ax03.set_ylabel('Frames') ax03.set_xlabel('Time in session') ax04.imshow(self.frame_example[0]) rect = matplotlib.patches.Rectangle((self.roi[1][1], self.roi[0][0]), self.roi[1][0] - self.roi[1][1], self.roi[0][1] - self.roi[0][0], linewidth=4, edgecolor='g', facecolor='none') ax04.add_patch(rect) ax05.plot(self.all_me) return fig
[docs] def process(self): """ Main function used to apply the video motion wheel alignment to the camera times. This function does the following 1. Computes the video motion energy across the whole session (computed in windows and parallelised) 2. Computes the shift that should be applied to the camera times across the whole session by computing the cross correlation between the video motion energy and the wheel speed (computed in overlapping windows and parallelised) 3. Removes artefacts from the computed shifts 4. Computes the qc for the wheel alignment 5. Extracts the new camera times using the shifts computed from the video wheel alignment 6. If upload is True, creates a summary plot of the alignment and uploads the figure to the relevant session on alyx :return: """ # Compute the motion energy of the wheel for the whole video wg = WindowGenerator(self.camera_meta['length'], 5000, 4) out = Parallel(n_jobs=self.nprocess)( delayed(self.compute_motion_energy)(first, last, wg, iw) for iw, (first, last) in enumerate(wg.firstlast)) # Concatenate the motion energy into one big array self.all_me = np.array([]) for vals in out[:-1]: self.all_me = np.r_[self.all_me, vals] toverlap = self.twin - 1 all_me = np.r_[np.full((int(self.camera_meta['fps'] * toverlap)), np.nan), self.all_me] to_app = self.times[0] - ((np.arange(int(self.camera_meta['fps'] * toverlap), ) + 1) / self.frate)[::-1] times = np.r_[to_app, self.times] wg = WindowGenerator(all_me.size - 1, int(self.camera_meta['fps'] * self.twin), int(self.camera_meta['fps'] * toverlap)) out = Parallel(n_jobs=1)(delayed(self.compute_shifts)(times, all_me, first, last, iw, wg) for iw, (first, last) in enumerate(wg.firstlast)) self.shifts = np.array([]) self.t_shifts = np.array([]) for vals in out[:-1]: self.shifts = np.r_[self.shifts, vals[0]] self.t_shifts = np.r_[self.t_shifts, vals[1]] idx = np.bitwise_and(self.t_shifts >= self.ttl_times[0], self.t_shifts < self.ttl_times[-1]) self.shifts = self.shifts[idx] self.t_shifts = self.t_shifts[idx] shifts_filt = ndimage.percentile_filter(self.shifts, 80, 120) shifts_filt = self.clean_shifts(shifts_filt, n=1) self.shifts_filt = self.clean_shifts(shifts_filt, n=2) self.qc, self.qc_outcome = self.qc_shifts(self.shifts, self.shifts_filt) self.new_times = self.extract_times(self.shifts_filt, self.t_shifts) if self.upload: fig = self.plot_with_behavior() if self.behavior else self.plot_without_behavior() save_fig_path = Path(self.session_path.joinpath('snapshot', 'video', f'video_wheel_alignment_{self.label}.png')) save_fig_path.parent.mkdir(exist_ok=True, parents=True) fig.savefig(save_fig_path) snp = ReportSnapshot(self.session_path, self.eid, content_type='session', one=self.one) snp.outputs = [save_fig_path] snp.register_images(widths=['orig']) plt.close(fig) return self.new_times