"""Set of functions to deal with dlc data."""
import logging
import pandas as pd
import warnings
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy.interpolate as interpolate
from scipy.stats import zscore
from ibldsp.smooth import smooth_interpolate_savgol
from iblutil.numerical import bincount2D
import brainbox.behavior.wheel as bbox_wheel
logger = logging.getLogger('ibllib')
SAMPLING = {'left': 60,
'right': 150,
'body': 30}
RESOLUTION = {'left': 2,
'right': 1,
'body': 1}
T_BIN = 0.02 # sec
WINDOW_LEN = 2 # sec
WINDOW_LAG = -0.5 # sec
# For plotting we use a window around the event the data is aligned to WINDOW_LAG before and WINDOW_LEN after the event
[docs]
def plt_window(x):
return x + WINDOW_LAG, x + WINDOW_LEN
[docs]
def insert_idx(array, values):
idx = np.searchsorted(array, values, side="left")
# Choose lower index if insertion would be after last index or if lower index is closer
idx[idx == len(array)] -= 1
idx[np.where(abs(values - array[idx - 1]) < abs(values - array[idx]))] -= 1
# If 0 index was reduced, revert
idx[idx == -1] = 0
if np.all(idx == 0):
raise ValueError('Something is wrong, all values to insert are outside of the array.')
return idx
[docs]
def likelihood_threshold(dlc, threshold=0.9):
"""
Set dlc points with likelihood less than threshold to nan.
FIXME Add unit test.
:param dlc: dlc pqt object
:param threshold: likelihood threshold
:return:
"""
features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc.keys()])
for feat in features:
nan_fill = dlc[f'{feat}_likelihood'] < threshold
dlc.loc[nan_fill, (f'{feat}_x', f'{feat}_y')] = np.nan
return dlc
[docs]
def get_speed(dlc, dlc_t, camera, feature='paw_r'):
"""
FIXME Document and add unit test!
:param dlc: dlc pqt table
:param dlc_t: dlc time points
:param camera: camera type e.g 'left', 'right', 'body'
:param feature: dlc feature to compute speed over
:return:
"""
x = dlc[f'{feature}_x'] / RESOLUTION[camera]
y = dlc[f'{feature}_y'] / RESOLUTION[camera]
# get speed in px/sec [half res]
s = ((np.diff(x) ** 2 + np.diff(y) ** 2) ** .5) * SAMPLING[camera]
dt = np.diff(dlc_t)
tv = dlc_t[:-1] + dt / 2
# interpolate over original time scale
if tv.size > 1:
ifcn = interpolate.interp1d(tv, s, fill_value="extrapolate")
return ifcn(dlc_t)
[docs]
def get_speed_for_features(dlc, dlc_t, camera, features=['paw_r', 'paw_l', 'nose_tip']):
"""
Wrapper to compute speed for a number of dlc features and add them to dlc table
:param dlc: dlc pqt table
:param dlc_t: dlc time points
:param camera: camera type e.g 'left', 'right', 'body'
:param features: dlc features to compute speed for
:return:
"""
for feat in features:
dlc[f'{feat}_speed'] = get_speed(dlc, dlc_t, camera, feat)
return dlc
[docs]
def get_feature_event_times(dlc, dlc_t, features):
"""
Detect events from the dlc traces. Based on the standard deviation between frames
:param dlc: dlc pqt table
:param dlc_t: dlc times
:param features: features to consider
:return:
"""
for i, feat in enumerate(features):
f = dlc[feat]
threshold = np.nanstd(np.diff(f)) / 4
if i == 0:
events = np.where(np.abs(np.diff(f)) > threshold)[0]
else:
events = np.r_[events, np.where(np.abs(np.diff(f)) > threshold)[0]]
return dlc_t[np.unique(events)]
[docs]
def get_licks(dlc, dlc_t):
"""
Compute lick times from the tongue dlc points
:param dlc: dlc pqt table
:param dlc_t: dlc times
:return:
"""
lick_times = get_feature_event_times(dlc, dlc_t, ['tongue_end_l_x', 'tongue_end_l_y',
'tongue_end_r_x', 'tongue_end_r_y'])
return lick_times
[docs]
def get_sniffs(dlc, dlc_t):
"""
Compute sniff times from the nose tip
:param dlc: dlc pqt table
:param dlc_t: dlc times
:return:
"""
sniff_times = get_feature_event_times(dlc, dlc_t, ['nose_tip_y'])
return sniff_times
[docs]
def get_dlc_everything(dlc_cam, camera):
"""
Get out features of interest for dlc
:param dlc_cam: dlc object
:param camera: camera type e.g 'left', 'right'
:return:
"""
aligned = True
if dlc_cam.times.shape[0] != dlc_cam.dlc.shape[0]:
# logger warning and print out status of the qc, specific serializer django!
logger.warning('Dimension mismatch between dlc points and timestamps')
min_samps = min(dlc_cam.times.shape[0], dlc_cam.dlc.shape[0])
dlc_cam.times = dlc_cam.times[:min_samps]
dlc_cam.dlc = dlc_cam.dlc[:min_samps]
aligned = False
dlc_cam.dlc = likelihood_threshold(dlc_cam.dlc)
dlc_cam.dlc = get_speed_for_features(dlc_cam.dlc, dlc_cam.times, camera)
dlc_cam['licks'] = get_licks(dlc_cam.dlc, dlc_cam.times)
dlc_cam['sniffs'] = get_sniffs(dlc_cam.dlc, dlc_cam.times)
dlc_cam['aligned'] = aligned
return dlc_cam
[docs]
def get_pupil_diameter(dlc):
"""
Estimates pupil diameter by taking median of different computations.
The two most straightforward estimates: d1 = top - bottom, d2 = left - right
In addition, assume the pupil is a circle and estimate diameter from other pairs of points
:param dlc: dlc pqt table with pupil estimates, should be likelihood thresholded (e.g. at 0.9)
:return: np.array, pupil diameter estimate for each time point, shape (n_frames,)
"""
diameters = []
# Get the x,y coordinates of the four pupil points
top, bottom, left, right = [np.vstack((dlc[f'pupil_{point}_r_x'], dlc[f'pupil_{point}_r_y']))
for point in ['top', 'bottom', 'left', 'right']]
# First compute direct diameters
diameters.append(np.linalg.norm(top - bottom, axis=0))
diameters.append(np.linalg.norm(left - right, axis=0))
# For non-crossing edges, estimate diameter via circle assumption
for pair in [(top, left), (top, right), (bottom, left), (bottom, right)]:
diameters.append(np.linalg.norm(pair[0] - pair[1], axis=0) * 2 ** 0.5)
# Ignore all nan runtime warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
return np.nanmedian(diameters, axis=0)
[docs]
def get_smooth_pupil_diameter(diameter_raw, camera, std_thresh=5, nan_thresh=1):
"""
:param diameter_raw: np.array, raw pupil diameters, calculated from (thresholded) dlc traces
:param camera: str ('left', 'right'), which camera to run the smoothing for
:param std_thresh: threshold (in standard deviations) beyond which a point is labeled as an outlier
:param nan_thresh: threshold (in seconds) above which we will not interpolate nans, but keep them
(for long stretches interpolation may not be appropriate)
:return:
"""
# set framerate of camera
if camera == 'left':
fr = SAMPLING['left'] # set by hardware
window = 31 # works well empirically
elif camera == 'right':
fr = SAMPLING['right'] # set by hardware
window = 75 # works well empirically
else:
raise NotImplementedError("camera has to be 'left' or 'right")
# Raise error if too many NaN time points, in this case it doesn't make sense to interpolate
if np.mean(np.isnan(diameter_raw)) > 0.9:
raise ValueError(f"Raw pupil diameter for {camera} is too often NaN, cannot smooth.")
# run savitzy-golay filter on non-nan time points to denoise
diameter_smoothed = smooth_interpolate_savgol(diameter_raw, window=window, order=3, interp_kind='linear')
# find outliers and set them to nan
difference = diameter_raw - diameter_smoothed
outlier_thresh = std_thresh * np.nanstd(difference)
without_outliers = np.copy(diameter_raw)
without_outliers[(difference < -outlier_thresh) | (difference > outlier_thresh)] = np.nan
# run savitzy-golay filter again on (possibly reduced) non-nan timepoints to denoise
diameter_smoothed = smooth_interpolate_savgol(without_outliers, window=window, order=3, interp_kind='linear')
# don't interpolate long strings of nans
t = np.diff(np.isnan(without_outliers).astype(int))
begs = np.where(t == 1)[0]
ends = np.where(t == -1)[0]
if begs.shape[0] > ends.shape[0]:
begs = begs[:ends.shape[0]]
for b, e in zip(begs, ends):
if (e - b) > (fr * nan_thresh):
diameter_smoothed[(b + 1):(e + 1)] = np.nan # offset by 1 due to earlier diff
return diameter_smoothed
[docs]
def plot_trace_on_frame(frame, dlc_df, cam):
"""
Plots dlc traces as scatter plots on a frame of the video.
For left and right video also plots whisker pad and eye and tongue zoom.
:param frame: np.array, single video frame to plot on
:param dlc_df: pd.Dataframe, dlc traces with _x, _y and _likelihood info for each trace
:param cam: str, which camera to process ('left', 'right', 'body')
:returns: matplolib.axis
"""
# Define colors
colors = {'tail_start': '#636EFA',
'nose_tip': '#636EFA',
'paw_l': '#EF553B',
'paw_r': '#00CC96',
'pupil_bottom_r': '#AB63FA',
'pupil_left_r': '#FFA15A',
'pupil_right_r': '#19D3F3',
'pupil_top_r': '#FF6692',
'tongue_end_l': '#B6E880',
'tongue_end_r': '#FF97FF'}
# Threshold the dlc traces
dlc_df = likelihood_threshold(dlc_df)
# Features without tube
features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc_df.keys() if 'tube' not in x])
# Normalize the number of points across cameras
dlc_df_norm = pd.DataFrame()
for feat in features:
dlc_df_norm[f'{feat}_x'] = dlc_df[f'{feat}_x'][0::int(SAMPLING[cam] / 10)]
dlc_df_norm[f'{feat}_y'] = dlc_df[f'{feat}_y'][0::int(SAMPLING[cam] / 10)]
# Scatter
plt.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=0.05, s=2, label=feat, c=colors[feat])
plt.axis('off')
plt.imshow(frame, cmap='gray')
plt.tight_layout()
ax = plt.gca()
if cam == 'body':
plt.title(f'{cam.capitalize()} camera')
return ax
# For left and right cam plot whisker pad rectangle
# heuristic: square with side length half the distance between nose and pupil and anchored on midpoint
p_nose = np.array(dlc_df[['nose_tip_x', 'nose_tip_y']].mean())
p_pupil = np.array(dlc_df[['pupil_top_r_x', 'pupil_top_r_y']].mean())
p_anchor = np.mean([p_nose, p_pupil], axis=0)
dist = np.linalg.norm(p_nose - p_pupil)
rect = matplotlib.patches.Rectangle((int(p_anchor[0] - dist / 4), int(p_anchor[1])), int(dist / 2), int(dist / 3),
linewidth=1, edgecolor='lime', facecolor='none')
ax.add_patch(rect)
# Plot eye region zoom
inset_anchor = 0 if cam == 'right' else 0.5
ax_ins = ax.inset_axes([inset_anchor, -0.5, 0.5, 0.5])
ax_ins.imshow(frame, cmap='gray', origin="lower")
for feat in features:
ax_ins.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=1, s=0.001, label=feat, c=colors[feat])
ax_ins.set_xlim(int(p_pupil[0] - 33 * RESOLUTION[cam] / 2), int(p_pupil[0] + 33 * RESOLUTION[cam] / 2))
ax_ins.set_ylim(int(p_pupil[1] + 38 * RESOLUTION[cam] / 2), int(p_pupil[1] - 28 * RESOLUTION[cam] / 2))
ax_ins.axis('off')
# Plot tongue region zoom
p1 = np.array(dlc_df[['tube_top_x', 'tube_top_y']].mean())
p2 = np.array(dlc_df[['tube_bottom_x', 'tube_bottom_y']].mean())
p_tongue = np.nanmean([p1, p2], axis=0)
inset_anchor = 0 if cam == 'left' else 0.5
ax_ins = ax.inset_axes([inset_anchor, -0.5, 0.5, 0.5])
ax_ins.imshow(frame, cmap='gray', origin="upper")
for feat in features:
ax_ins.scatter(dlc_df_norm[f'{feat}_x'], dlc_df_norm[f'{feat}_y'], alpha=1, s=0.001, label=feat, c=colors[feat])
ax_ins.set_xlim(int(p_tongue[0] - 60 * RESOLUTION[cam] / 2), int(p_tongue[0] + 100 * RESOLUTION[cam] / 2))
ax_ins.set_ylim(int(p_tongue[1] + 60 * RESOLUTION[cam] / 2), int(p_tongue[1] - 100 * RESOLUTION[cam] / 2))
ax_ins.axis('off')
plt.title(f'{cam.capitalize()} camera')
return ax
[docs]
def plot_wheel_position(wheel_position, wheel_time, trials_df):
"""
Plots wheel position across trials, color by which side was chosen
:param wheel_position: np.array, interpolated wheel position
:param wheel_time: np.array, interpolated wheel timestamps
:param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset times for each trial)
:returns: matplotlib.axis
"""
# Interpolate wheel data
wheel_position, wheel_time = bbox_wheel.interpolate_position(wheel_time, wheel_position, freq=1 / T_BIN)
# Create a window around the stimulus onset
start_window, end_window = plt_window(trials_df['stimOn_times'])
# Translating the time window into an index window
start_idx = insert_idx(wheel_time, start_window)
end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
# Getting the wheel position for each window, normalize to first value of each window
trials_df['wheel_position'] = [wheel_position[start_idx[w]: end_idx[w]] - wheel_position[start_idx[w]]
for w in range(len(start_idx))]
# Plotting
times = np.arange(len(trials_df['wheel_position'].iloc[0])) * T_BIN + WINDOW_LAG
for side, label, color in zip([-1, 1], ['right', 'left'], ['darkred', '#1f77b4']):
side_df = trials_df[trials_df['choice'] == side]
for idx in side_df.index:
plt.plot(times, side_df.loc[idx, 'wheel_position'], c=color, alpha=0.5, linewidth=0.05)
plt.plot(times, side_df['wheel_position'].mean(), c=color, linewidth=2, label=f'{label} turn')
plt.axvline(x=0, linestyle='--', c='k', label='stimOn')
plt.axhline(y=-0.26, linestyle='--', c='g', label='reward')
plt.axhline(y=0.26, linestyle='--', c='g', label='reward')
plt.ylim([-0.27, 0.27])
plt.xlabel('time [sec]')
plt.ylabel('wheel position diff to first value [rad]')
plt.legend(loc='center right')
plt.title('Wheel position trial avg\n(and individual trials)')
plt.tight_layout()
return plt.gca()
def _bin_window_licks(lick_times, trials_df):
"""
Helper function to bin and window the lick times and get them into trials df for plotting
:param lick_times: np.array, timestamps of lick events
:param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial)
:returns: pd.DataFrame with binned, windowed lick times for plotting
"""
# Bin the licks
lick_bins, bin_times, _ = bincount2D(lick_times, np.ones(len(lick_times)), T_BIN)
lick_bins = np.squeeze(lick_bins)
start_window, end_window = plt_window(trials_df['feedback_times'])
# Translating the time window into an index window
try:
start_idx = insert_idx(bin_times, start_window)
except ValueError:
logger.error('Lick time stamps are outside of the trials windows')
raise
end_idx = np.array(start_idx + int(WINDOW_LEN / T_BIN), dtype='int64')
# Get the binned licks for each window
trials_df['lick_bins'] = [lick_bins[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
# Remove windows that the exceed bins
trials_df['end_idx'] = end_idx
trials_df = trials_df[trials_df['end_idx'] <= len(lick_bins)]
return trials_df
[docs]
def plot_lick_hist(lick_times, trials_df):
"""
Plots histogramm of lick events aligned to feedback time, separate for correct and incorrect trials
:param lick_times: np.array, timestamps of lick events
:param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial) and
'feedbackType' (1 for correct, -1 for incorrect trials)
:returns: matplotlib axis
"""
licks_df = _bin_window_licks(lick_times, trials_df)
# Plot
times = np.arange(len(licks_df['lick_bins'].iloc[0])) * T_BIN + WINDOW_LAG
correct = licks_df[licks_df['feedbackType'] == 1]['lick_bins']
incorrect = licks_df[licks_df['feedbackType'] == -1]['lick_bins']
plt.plot(times, pd.DataFrame.from_dict(dict(zip(correct.index, correct.values))).mean(axis=1),
c='k', label='correct trial')
plt.plot(times, pd.DataFrame.from_dict(dict(zip(correct.index, incorrect.values))).mean(axis=1),
c='gray', label='incorrect trial')
plt.axvline(x=0, label='feedback', linestyle='--', c='purple')
plt.title('Lick events trial avg')
plt.xlabel('time [sec]')
plt.ylabel('lick events [a.u.]')
plt.legend(loc='lower right')
return plt.gca()
[docs]
def plot_lick_raster(lick_times, trials_df):
"""
Plots lick raster for correct trials
:param lick_times: np.array, timestamps of lick events
:param trials_df: pd.DataFrame, with column 'feedback_times' (time of feedback for each trial) and
feedbackType (1 for correct, -1 for incorrect trials)
:returns: matplotlib.axis
"""
licks_df = _bin_window_licks(lick_times, trials_df)
plt.imshow(list(licks_df[licks_df['feedbackType'] == 1]['lick_bins']), aspect='auto',
extent=[-0.5, 1.5, len(licks_df['lick_bins'].iloc[0]), 0], cmap='gray_r')
plt.xticks([-0.5, 0, 0.5, 1, 1.5])
plt.ylabel('trials')
plt.xlabel('time [sec]')
plt.axvline(x=0, label='feedback', linestyle='--', c='purple')
plt.title('Lick events per correct trial')
plt.tight_layout()
return plt.gca()
[docs]
def plot_motion_energy_hist(camera_dict, trials_df):
"""
Plots mean motion energy of given cameras, aligned to stimulus onset.
:param camera_dict: dict, one key for each camera to be plotted (e.g. 'left'), value is another dict with items
'motion_energy' (np.array, motion energy calculated from this camera) and
'times' (np.array, camera timestamps)
:param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial)
:returns: matplotlib.axis
"""
colors = {'left': '#bd7a98',
'right': '#2b6f39',
'body': '#035382'}
start_window, end_window = plt_window(trials_df['stimOn_times'])
missing_data = []
for cam in camera_dict.keys():
if (camera_dict[cam]['motion_energy'] is not None and len(camera_dict[cam]['motion_energy']) > 0
and camera_dict[cam]['times'] is not None and len(camera_dict[cam]['times']) > 0):
try:
motion_energy = zscore(camera_dict[cam]['motion_energy'], nan_policy='omit')
try:
start_idx = insert_idx(camera_dict[cam]['times'], start_window)
end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
me_all = [motion_energy[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
me_all = [m for m in me_all if len(m) > 0]
times = np.arange(len(me_all[0])) / SAMPLING[cam] + WINDOW_LAG
me_mean = np.mean(me_all, axis=0)
me_std = np.std(me_all, axis=0) / np.sqrt(len(me_all))
plt.plot(times, me_mean, label=f'{cam} cam', color=colors[cam], linewidth=2)
plt.fill_between(times, me_mean + me_std, me_mean - me_std, color=colors[cam], alpha=0.2)
except ValueError:
logger.error(f"{cam}Camera camera.times are outside of the trial windows")
missing_data.append(cam)
except AttributeError:
logger.warning(f"Cannot load motion energy and/or times data for {cam} camera")
missing_data.append(cam)
else:
logger.warning(f"Data missing or empty for motion energy and/or times data for {cam} camera")
missing_data.append(cam)
plt.xticks([-0.5, 0, 0.5, 1, 1.5])
plt.ylabel('z-scored motion energy [a.u.]')
plt.xlabel('time [sec]')
plt.axvline(x=0, label='stimOn', linestyle='--', c='k')
plt.legend(loc='lower right')
plt.title('Motion Energy trial avg\n(+/- std)')
if len(missing_data) > 0:
ax = plt.gca()
ax.text(.95, .35, f"Data incomplete for\n{' and '.join(missing_data)} camera", color='r', fontsize=10,
fontweight='bold', horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)
return plt.gca()
[docs]
def plot_speed_hist(dlc_df, cam_times, trials_df, feature='paw_r', cam='left', legend=True):
"""
Plots speed histogram of a given dlc feature, aligned to stimulus onset, separate for correct and incorrect trials
:param dlc_df: pd.Dataframe, dlc traces with _x, _y and _likelihood info for each trace
:param cam_times: np.array, camera timestamps
:param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial)
:param feature: str, feature with trace in dlc_df for which to plot speed hist, default is 'paw_r'
:param cam: str, camera to use ('body', 'left', 'right') default is 'left'
:param legend: bool, whether to add legend to the plot, default is True
:returns: matplotlib.axis
"""
# Threshold the dlc traces
dlc_df = likelihood_threshold(dlc_df)
# For pre-GPIO sessions, remove the first few timestamps to match the number of frames
cam_times = cam_times[-len(dlc_df):]
if len(cam_times) != len(dlc_df):
raise ValueError("Camera times length and DLC length are inconsistent")
# Get speeds
speeds = get_speed(dlc_df, cam_times, camera=cam, feature=feature)
# Windows aligned to align_to
start_window, end_window = plt_window(trials_df['stimOn_times'])
start_idx = insert_idx(cam_times, start_window)
end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
# Add speeds to trials_df
trials_df[f'speed_{feature}'] = [speeds[start_idx[i]:end_idx[i]] for i in range(len(start_idx))]
# Plot
times = np.arange(len(trials_df[f'speed_{feature}'].iloc[0])) / SAMPLING[cam] + WINDOW_LAG
# Need to expand the series of lists into a dataframe first, for the nan skipping to work
correct = trials_df[trials_df['feedbackType'] == 1][f'speed_{feature}']
incorrect = trials_df[trials_df['feedbackType'] == -1][f'speed_{feature}']
plt.plot(times, pd.DataFrame.from_dict(dict(zip(correct.index, correct.values))).mean(axis=1),
c='k', label='correct trial')
plt.plot(times, pd.DataFrame.from_dict(dict(zip(incorrect.index, incorrect.values))).mean(axis=1),
c='gray', label='incorrect trial')
plt.axvline(x=0, label='stimOn', linestyle='--', c='r')
plt.title(f'{feature.capitalize()} speed trial avg\n({cam.upper()} cam)')
plt.xticks([-0.5, 0, 0.5, 1, 1.5])
plt.xlabel('time [sec]')
plt.ylabel('speed [px/sec]')
if legend:
plt.legend()
return plt.gca()
[docs]
def plot_pupil_diameter_hist(pupil_diameter, cam_times, trials_df, cam='left'):
"""
Plots histogram of pupil diameter aligned to simulus onset and feedback time.
:param pupil_diameter: np.array, (smoothed) pupil diameter estimate
:param cam_times: np.array, camera timestamps
:param trials_df: pd.DataFrame, with column 'stimOn_times' (time of stimulus onset for each trial) and
feedback_times (time of feedback for each trial)
:param cam: str, camera to use ('body', 'left', 'right') default is 'left'
:returns: matplotlib.axis
"""
for align_to, color in zip(['stimOn_times', 'feedback_times'], ['red', 'purple']):
start_window, end_window = plt_window(trials_df[align_to])
start_idx = insert_idx(cam_times, start_window)
end_idx = np.array(start_idx + int(WINDOW_LEN * SAMPLING[cam]), dtype='int64')
# Per trial norm
pupil_all = [zscore(list(pupil_diameter[start_idx[i]:end_idx[i]])) for i in range(len(start_idx))]
pupil_all_norm = [trial - trial[0] for trial in pupil_all]
pupil_mean = np.nanmean(pupil_all_norm, axis=0)
pupil_std = np.nanstd(pupil_all_norm, axis=0) / np.sqrt(len(pupil_all_norm))
times = np.arange(len(pupil_all_norm[0])) / SAMPLING[cam] + WINDOW_LAG
plt.plot(times, pupil_mean, label=align_to.split("_")[0], color=color)
plt.fill_between(times, pupil_mean + pupil_std, pupil_mean - pupil_std, color=color, alpha=0.5)
plt.axvline(x=0, linestyle='--', c='k')
plt.title(f'Pupil diameter trial avg\n({cam.upper()} cam)')
plt.xlabel('time [sec]')
plt.xticks([-0.5, 0, 0.5, 1, 1.5])
plt.ylabel('z-scored smoothed pupil diameter [px]')
plt.legend(loc='lower right', title='aligned to')