Source code for brainbox.behavior.training

"""Computing and testing IBL training status criteria.

For an in-depth description of each training status, see `Appendix 2`_ of the IBL Protocol For Mice
Training.

.. _Appendix 2: https://figshare.com/articles/preprint/A_standardized_and_reproducible_method_to_\
measure_decision-making_in_mice_Appendix_2_IBL_protocol_for_mice_training/11634729

Examples
--------
Plot the psychometric curve for a given session.

>>> trials = ONE().load_object(eid, 'trials')
>>> fix, ax = plot_psychometric(trials)

Compute 'response times', defined as the duration of open-loop for each contrast.

>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(trials)

Compute 'reaction times', defined as the time between go cue and first detected movement.
NB: These may be negative!

>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
...     trials, stim_on_type='goCue_times', stim_off_type='firstMovement_times')

Compute 'response times', defined as the time between first detected movement and response.

>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
...     trials, stim_on_type='firstMovement_times', stim_off_type='response_times')

Compute 'movement times', defined as the time between last detected movement and response threshold.

>>> import brainbox.behavior.wheel as wh
>>> wheel_moves = ONE().load_object(eid, 'wheeMoves')
>>> trials['lastMovement_times'] = wh.get_movement_onset(wheel_moves.intervals, trial_data.response_times)
>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
...     trials, stim_on_type='lastMovement_times', stim_off_type='response_times')

"""
import logging
import datetime
import re
from enum import IntFlag, auto, unique

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy.stats import bootstrap
from iblutil.util import Bunch
from one.api import ONE
from one.alf.io import AlfBunch
from one.alf.exceptions import ALFObjectNotFound
import psychofit as psy

_logger = logging.getLogger('ibllib')

TRIALS_KEYS = ['contrastLeft',
               'contrastRight',
               'feedbackType',
               'probabilityLeft',
               'choice',
               'response_times',
               'stimOn_times']
"""list of str: The required keys in the trials object for computing training status."""


[docs] @unique class TrainingStatus(IntFlag): """Standard IBL training criteria. Enumeration allows for comparisons between training status. Examples -------- >>> status = 'ready4delay' ... assert TrainingStatus[status.upper()] is TrainingStatus.READY4DELAY ... assert TrainingStatus[status.upper()] not in TrainingStatus.FAILED, 'Subject failed training' ... assert TrainingStatus[status.upper()] >= TrainingStatus.TRAINED, 'Subject untrained' ... assert TrainingStatus[status.upper()] > TrainingStatus.IN_TRAINING, 'Subject untrained' ... assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained' ... assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY Get the next training status >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()]) <TrainingStatus.READY4RECORDING: 128> Notes ----- - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b. - A subject may acheive both TRAINED_1A and TRAINED_1B within a single session, therefore it is possible to have skipped the TRAINED_1A session status. """ UNTRAINABLE = auto() UNBIASABLE = auto() IN_TRAINING = auto() TRAINED_1A = auto() TRAINED_1B = auto() READY4EPHYSRIG = auto() READY4DELAY = auto() READY4RECORDING = auto() # Compound training statuses for convenience FAILED = UNTRAINABLE | UNBIASABLE READY = READY4EPHYSRIG | READY4DELAY | READY4RECORDING TRAINED = TRAINED_1A | TRAINED_1B
[docs] def get_lab_training_status(lab, date=None, details=True, one=None): """ Computes the training status of all alive and water restricted subjects in a specified lab. The response are printed to std out. Parameters ---------- lab : str Lab name (must match the name registered on Alyx). date : str The ISO date from which to compute training status. If not specified will compute from the latest date with available data. Format should be 'YYYY-MM-DD'. details : bool Whether to display all information about training status computation e.g. performance, number of trials, psychometric fit parameters. one : one.api.OneAlyx An instance of ONE. """ one = one or ONE() subj_lab = one.alyx.rest('subjects', 'list', lab=lab, alive=True, water_restricted=True) subjects = [subj['nickname'] for subj in subj_lab] for subj in subjects: get_subject_training_status(subj, date=date, details=details, one=one)
[docs] def get_subject_training_status(subj, date=None, details=True, one=None): """ Computes the training status of specified subject and prints results to std out. Parameters ---------- subj : str Subject nickname (must match the name registered on Alyx). date : str The ISO date from which to compute training status. If not specified will compute from the latest date with available data. Format should be 'YYYY-MM-DD'. details : bool Whether to display all information about training status computation e.g. performance, number of trials, psychometric fit parameters. one : one.api.OneAlyx An instance of ONE. """ one = one or ONE() trials, task_protocol, ephys_sess, n_delay = get_sessions(subj, date=date, one=one) if not trials: return sess_dates = list(trials.keys()) status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay) if details: if np.any(info.get('psych')): display_status(subj, sess_dates, status, perf_easy=info.perf_easy, n_trials=info.n_trials, psych=info.psych, rt=info.rt) elif np.any(info.get('psych_20')): display_status(subj, sess_dates, status, perf_easy=info.perf_easy, n_trials=info.n_trials, psych_20=info.psych_20, psych_80=info.psych_80, rt=info.rt) else: display_status(subj, sess_dates, status)
[docs] def get_sessions(subj, date=None, one=None): """ Download and load in training data for a specified subject. If a date is given it will load data from the three (or as many as are available) previous sessions up to the specified date. If not it will load data from the last three training sessions that have data available. Parameters ---------- subj : str Subject nickname (must match the name registered on Alyx). date : str The ISO date from which to compute training status. If not specified will compute from the latest date with available data. Format should be 'YYYY-MM-DD'. one : one.api.OneAlyx An instance of ONE. Returns ------- iblutil.util.Bunch Dictionary of trials objects where each key is the ISO session date string. list of str List of the task protocol used for each of the sessions. list of str List of ISO date strings where training was conducted on ephys rig. Empty list if all sessions on training rig. n_delay : int Number of sessions on ephys rig that had delay prior to starting session > 15min. Returns 0 if no sessions detected. """ one = one or ONE() if date is None: # compute from yesterday specified_date = (datetime.date.today() - datetime.timedelta(days=1)) latest_sess = specified_date.strftime("%Y-%m-%d") latest_minus_week = (datetime.date.today() - datetime.timedelta(days=8)).strftime("%Y-%m-%d") else: # compute from the date specified specified_date = datetime.datetime.strptime(date, '%Y-%m-%d') latest_minus_week = (specified_date - datetime.timedelta(days=7)).strftime("%Y-%m-%d") latest_sess = date sessions = one.alyx.rest('sessions', 'list', subject=subj, date_range=[latest_minus_week, latest_sess], dataset_types='trials.goCueTrigger_times') # If not enough sessions in the last week, then just fetch them all if len(sessions) < 3: specified_date_plus = (specified_date + datetime.timedelta(days=1)).strftime("%Y-%m-%d") django_query = 'start_time__lte,' + specified_date_plus sessions = one.alyx.rest('sessions', 'list', subject=subj, dataset_types='trials.goCueTrigger_times', django=django_query) # If still 0 sessions then return with warning if len(sessions) == 0: _logger.warning(f"No training sessions detected for {subj}") return [None] * 4 trials = Bunch() task_protocol = [] sess_dates = [] if len(sessions) < 3: for n, _ in enumerate(sessions): try: trials_ = one.load_object(sessions[n]['url'].split('/')[-1], 'trials') except ALFObjectNotFound: trials_ = None if trials_: task_protocol.append(re.search('tasks_(.*)Choice', sessions[n]['task_protocol']).group(1)) sess_dates.append(sessions[n]['start_time'][:10]) trials[sessions[n]['start_time'][:10]] = trials_ else: n = 0 while len(trials) < 3: print(sessions[n]['url'].split('/')[-1]) try: trials_ = one.load_object(sessions[n]['url'].split('/')[-1], 'trials') except ALFObjectNotFound: trials_ = None if trials_: task_protocol.append(re.search('tasks_(.*)Choice', sessions[n]['task_protocol']).group(1)) sess_dates.append(sessions[n]['start_time'][:10]) trials[sessions[n]['start_time'][:10]] = trials_ n += 1 if not np.any(np.array(task_protocol) == 'training'): ephys_sess = one.alyx.rest('sessions', 'list', subject=subj, date_range=[sess_dates[-1], sess_dates[0]], django='json__PYBPOD_BOARD__icontains,ephys') if len(ephys_sess) > 0: ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess] n_delay = len(one.alyx.rest('sessions', 'list', subject=subj, date_range=[sess_dates[-1], sess_dates[0]], django='json__SESSION_START_DELAY_SEC__gte,900')) else: ephys_sess_dates = [] n_delay = 0 else: ephys_sess_dates = [] n_delay = 0 return trials, task_protocol, ephys_sess_dates, n_delay
[docs] def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay): """ Compute training status of a subject from consecutive training datasets. For IBL, training status is calculated using trials from the last three consecutive sessions. Parameters ---------- trials : dict of str Dictionary of trials objects where each key is the ISO session date string. task_protocol : list of str Task protocol used for each training session in `trials`, can be 'training', 'biased' or 'ephys'. ephys_sess_dates : list of str List of ISO date strings where training was conducted on ephys rig. Empty list if all sessions on training rig. n_delay : int Number of sessions on ephys rig that had delay prior to starting session > 15min. Returns 0 if no sessions detected. Returns ------- str Training status of the subject. iblutil.util.Bunch Bunch containing performance metrics that decide training status i.e. performance on easy trials, number of trials, psychometric fit parameters, reaction time. """ info = Bunch() trials_all = concatenate_trials(trials) # Case when all sessions are trainingChoiceWorld if np.all(np.array(task_protocol) == 'training'): signed_contrast = get_signed_contrast(trials_all) (info.perf_easy, info.n_trials, info.psych, info.rt) = compute_training_info(trials, trials_all) if not np.any(signed_contrast == 0): status = 'in training' else: if criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt): status = 'trained 1b' elif criterion_1a(info.psych, info.n_trials, info.perf_easy): status = 'trained 1a' else: status = 'in training' return status, info # Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterion if ~np.all(np.array(task_protocol) == 'training') and \ np.any(np.array(task_protocol) == 'training'): status = 'trained 1b' (info.perf_easy, info.n_trials, info.psych, info.rt) = compute_training_info(trials, trials_all) return status, info # Case when there is biasedChoiceWorld or ephysChoiceWorld in last three sessions if not np.any(np.array(task_protocol) == 'training'): (info.perf_easy, info.n_trials, info.psych_20, info.psych_80, info.rt) = compute_bias_info(trials, trials_all) # We are still on training rig and so all sessions should be biased if len(ephys_sess_dates) == 0: assert np.all(np.array(task_protocol) == 'biased') if criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy, info.rt): status = 'ready4ephysrig' else: status = 'trained 1b' elif len(ephys_sess_dates) < 3: assert all(date in trials for date in ephys_sess_dates) perf_ephys_easy = np.array([compute_performance_easy(trials[k]) for k in ephys_sess_dates]) n_ephys_trials = np.array([compute_n_trials(trials[k]) for k in ephys_sess_dates]) if criterion_delay(n_ephys_trials, perf_ephys_easy): status = 'ready4delay' else: status = 'ready4ephysrig' elif len(ephys_sess_dates) >= 3: if n_delay > 0 and \ criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy, info.rt): status = 'ready4recording' elif criterion_delay(info.n_trials, info.perf_easy): status = 'ready4delay' else: status = 'ready4ephysrig' return status, info
[docs] def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None, psych_20=None, psych_80=None, rt=None): """ Display training status of subject to terminal. Parameters ---------- subj : str Subject nickname (must match the name registered on Alyx). sess_dates : list of str ISO date strings of training sessions used to determine training status. status : str Training status of subject. perf_easy : numpy.array Proportion of correct high contrast trials for each training session. n_trials : numpy.array Total number of trials for each training session. psych : numpy.array Psychometric parameters fit to data from all training sessions - bias, threshold, lapse high, lapse low. psych_20 : numpy.array The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. psych_80 : numpy.array The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. rt : float The median response time for zero contrast trials across all training sessions. NaN indicates no zero contrast stimuli in training sessions. """ if perf_easy is None: print(f"\n{subj} : {status} \nSession dates=[{sess_dates[0]}, {sess_dates[1]}, " f"{sess_dates[2]}]") elif psych_20 is None: print(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, " f"Perf easy={[np.around(pe,2) for pe in perf_easy]}, " f"N trials={[nt for nt in n_trials]} " f"\nPsych fit over last 3 sessions: " f"bias={np.around(psych[0],2)}, thres={np.around(psych[1],2)}, " f"lapse_low={np.around(psych[2],2)}, lapse_high={np.around(psych[3],2)} " f"\nMedian reaction time at 0 contrast over last 3 sessions = " f"{np.around(rt,2)}") else: print(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, " f"Perf easy={[np.around(pe,2) for pe in perf_easy]}, " f"N trials={[nt for nt in n_trials]} " f"\nPsych fit over last 3 sessions (20): " f"bias={np.around(psych_20[0],2)}, thres={np.around(psych_20[1],2)}, " f"lapse_low={np.around(psych_20[2],2)}, lapse_high={np.around(psych_20[3],2)} " f"\nPsych fit over last 3 sessions (80): bias={np.around(psych_80[0],2)}, " f"thres={np.around(psych_80[1],2)}, lapse_low={np.around(psych_80[2],2)}, " f"lapse_high={np.around(psych_80[3],2)} " f"\nMedian reaction time at 0 contrast over last 3 sessions = " f"{np.around(rt, 2)}")
[docs] def concatenate_trials(trials): """ Concatenate trials from different training sessions. Parameters ---------- trials : dict of str Dictionary of trials objects where each key is the ISO session date string. Returns ------- one.alf.io.AlfBunch Trials object with data concatenated over three training sessions. """ trials_all = AlfBunch() for k in TRIALS_KEYS: trials_all[k] = np.concatenate(list(trials[kk][k] for kk in trials.keys())) return trials_all
[docs] def compute_training_info(trials, trials_all): """ Compute all relevant performance metrics for when subject is on trainingChoiceWorld. Parameters ---------- trials : dict of str Dictionary of trials objects where each key is the ISO session date string. trials_all : one.alf.io.AlfBunch Trials object with data concatenated over three training sessions. Returns ------- numpy.array Proportion of correct high contrast trials for each session. numpy.array Total number of trials for each training session. numpy.array Array of psychometric parameters fit to `all_trials` - bias, threshold, lapse high, lapse low. float The median response time for all zero-contrast trials across all sessions. Returns NaN if no trials zero-contrast trials). """ signed_contrast = get_signed_contrast(trials_all) perf_easy = np.array([compute_performance_easy(trials[k]) for k in trials.keys()]) n_trials = np.array([compute_n_trials(trials[k]) for k in trials.keys()]) psych = compute_psychometric(trials_all, signed_contrast=signed_contrast) rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast) return perf_easy, n_trials, psych, rt
[docs] def compute_bias_info(trials, trials_all): """ Compute all relevant performance metrics for when subject is on biasedChoiceWorld :param trials: dict containing trials objects from three consecutive training sessions, keys are session dates :type trials: Bunch :param trials_all: trials object with data concatenated over three training sessions :type trials_all: Bunch :returns: - perf_easy - performance of easy trials for each session - n_trials - number of trials in each session - psych_20 - parameters for psychometric curve fit to trials in 20 block over all sessions - psych_80 - parameters for psychometric curve fit to trials in 80 block over all sessions - rt - median reaction time for zero contrast stimuli over all sessions """ signed_contrast = get_signed_contrast(trials_all) perf_easy = np.array([compute_performance_easy(trials[k]) for k in trials.keys()]) n_trials = np.array([compute_n_trials(trials[k]) for k in trials.keys()]) psych_20 = compute_psychometric(trials_all, signed_contrast=signed_contrast, block=0.2) psych_80 = compute_psychometric(trials_all, signed_contrast=signed_contrast, block=0.8) rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast) return perf_easy, n_trials, psych_20, psych_80, rt
[docs] def get_signed_contrast(trials): """ Compute signed contrast from trials object :param trials: trials object that must contain contrastLeft and contrastRight keys :type trials: dict returns: array of signed contrasts in percent, where -ve values are on the left """ # Replace NaNs with zeros, stack and take the difference contrast = np.nan_to_num(np.c_[trials['contrastLeft'], trials['contrastRight']]) return np.diff(contrast).flatten() * 100
[docs] def compute_performance_easy(trials): """ Compute performance on easy trials (stimulus >= 50 %) from trials object :param trials: trials object that must contain contrastLeft, contrastRight and feedbackType keys :type trials: dict returns: float containing performance on easy contrast trials """ signed_contrast = get_signed_contrast(trials) easy_trials = np.where(np.abs(signed_contrast) >= 50)[0] return np.sum(trials['feedbackType'][easy_trials] == 1) / easy_trials.shape[0]
[docs] def compute_performance(trials, signed_contrast=None, block=None, prob_right=False): """ Compute performance on all trials at each contrast level from trials object :param trials: trials object that must contain contrastLeft, contrastRight and feedbackType keys :type trials: dict returns: float containing performance on easy contrast trials """ if signed_contrast is None: signed_contrast = get_signed_contrast(trials) if block is None: block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool) else: block_idx = trials.probabilityLeft == block if not np.any(block_idx): return np.nan * np.zeros(3) contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True) if not prob_right: correct = trials.feedbackType == 1 performance = np.vectorize(lambda x: np.mean(correct[(x == signed_contrast) & block_idx]))(contrasts) else: rightward = trials.choice == -1 # Calculate the proportion rightward for each contrast type performance = np.vectorize(lambda x: np.mean(rightward[(x == signed_contrast) & block_idx]))(contrasts) return performance, contrasts, n_contrasts
[docs] def compute_n_trials(trials): """ Compute number of trials in trials object :param trials: trials object :type trials: dict returns: int containing number of trials in session """ return trials['choice'].shape[0]
[docs] def compute_psychometric(trials, signed_contrast=None, block=None, plotting=False, compute_ci=False, alpha=.032): """ Compute psychometric fit parameters for trials object. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. signed_contrast : numpy.array An array of signed contrasts in percent the length of trials, where left contrasts are -ve. If None, these are computed from the trials object. block : float The block type to compute. If None, all trials are included, otherwise only trials where probabilityLeft matches this value are included. For biasedChoiceWorld, the probabilityLeft set is {0.5, 0.2, 0.8}. plotting : bool Which set of psychofit model parameters to use (see notes). compute_ci : bool If true, computes and returns the confidence intervals for response at each contrast. alpha : float, default=0.032 Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false, this value is ignored. Returns ------- numpy.array Array of psychometric fit parameters - bias, threshold, lapse high, lapse low. (tuple of numpy.array) If `compute_ci` is true, a tuple of See Also -------- statsmodels.stats.proportion.proportion_confint - The function used to compute confidence interval. psychofit.mle_fit_psycho - The function used to fit the psychometric parameters. Notes ----- The psychofit starting parameters and model constraints used for the fit when computing the training status (e.g. trained_1a, etc.) are sub-optimal and can produce a poor fit. To keep the precise criteria the same for all subjects, these parameters have not changed. To produce a better fit for plotting purposes, or to calculate the training status in a manner inconsistent with the IBL training pipeline, use plotting=True. """ if signed_contrast is None: signed_contrast = get_signed_contrast(trials) if block is None: block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool) else: block_idx = trials.probabilityLeft == block if not np.any(block_idx): return np.nan * np.zeros(4) prob_choose_right, contrasts, n_contrasts = compute_performance( trials, signed_contrast=signed_contrast, block=block, prob_right=True) if plotting: # These starting parameters and constraints tend to produce a better fit, and are therefore # used for plotting. psych, _ = psy.mle_fit_psycho( np.vstack([contrasts, n_contrasts, prob_choose_right]), P_model='erf_psycho_2gammas', parstart=np.array([0., 40., 0.1, 0.1]), parmin=np.array([-50., 10., 0., 0.]), parmax=np.array([50., 50., 0.2, 0.2]), nfits=10) else: # These starting parameters and constraints are not ideal but are still used for computing # the training status for consistency. psych, _ = psy.mle_fit_psycho( np.vstack([contrasts, n_contrasts, prob_choose_right]), P_model='erf_psycho_2gammas', parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]), parmin=np.array([np.min(contrasts), 0., 0., 0.]), parmax=np.array([np.max(contrasts), 100., 1, 1])) if compute_ci: import statsmodels.stats.proportion as smp # noqa # choice == -1 means contrast on right hand side n_right = np.vectorize(lambda x: np.sum(trials['choice'][(x == signed_contrast) & block_idx] == -1))(contrasts) ci = smp.proportion_confint(n_right, n_contrasts, alpha=alpha, method='normal') - prob_choose_right return psych, ci else: return psych
[docs] def compute_median_reaction_time(trials, stim_on_type='stimOn_times', contrast=None, signed_contrast=None): """ Compute median response time on zero contrast trials from trials object Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. stim_on_type : str, default='stimOn_times' The trials key to use when calculating the response times. The difference between this and 'feedback_times' is used (see notes). contrast : float If None, the median response time is calculated for all trials, regardless of contrast, otherwise only trials where the matching signed percent contrast was presented are used. signed_contrast : numpy.array An array of signed contrasts in percent the length of trials, where left contrasts are -ve. If None, these are computed from the trials object. Returns ------- float The median response time for trials with `contrast` (returns NaN if no trials matching `contrast` in trials object). Notes ----- - The `stim_on_type` is 'stimOn_times' by default, however for IBL rig data, the photodiode is sometimes not calibrated properly which can lead to inaccurate (or absent, i.e. NaN) stim on times. Therefore, it is sometimes more accurate to use the 'stimOnTrigger_times' (the time of the stimulus onset command), if available, or the 'goCue_times' (the time of the soundcard output TTL when the audio go cue is played) or the 'goCueTrigger_times' (the time of the audio go cue command). - The response/reaction time here is defined as the time between stim on and feedback, i.e. the entire open-loop trial duration. """ if signed_contrast is None: signed_contrast = get_signed_contrast(trials) if contrast is None: contrast_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool) else: contrast_idx = signed_contrast == contrast if np.any(contrast_idx): reaction_time = np.nanmedian((trials.response_times - trials[stim_on_type]) [contrast_idx]) else: reaction_time = np.nan return reaction_time
[docs] def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='response_times', signed_contrast=None, block=None, compute_ci=False, alpha=0.32): """ Compute median response time for all contrasts. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. stim_on_type : str, default='stimOn_times' The trials key to use when calculating the response times. The difference between this and `stim_off_type` is used (see notes). stim_off_type : str, default='response_times' The trials key to use when calculating the response times. The difference between this and `stim_on_type` is used (see notes). signed_contrast : numpy.array An array of signed contrasts in percent the length of trials, where left contrasts are -ve. If None, these are computed from the trials object. block : float The block type to compute. If None, all trials are included, otherwise only trials where probabilityLeft matches this value are included. For biasedChoiceWorld, the probabilityLeft set is {0.5, 0.2, 0.8}. compute_ci : bool If true, computes and returns the confidence intervals for response time at each contrast. alpha : float, default=0.32 Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false, this value is ignored. Returns ------- numpy.array The median response times for each unique signed contrast. numpy.array The set of unique signed contrasts. numpy.array The number of trials for each unique signed contrast. (numpy.array) If `compute_ci` is true, an array of confidence intervals is return in the shape (n_trials, 2). Notes ----- - The response/reaction time by default is the time between stim on and response, i.e. the entire open-loop trial duration. One could use 'stimOn_times' and 'firstMovement_times' to get the true reaction time, or 'firstMovement_times' and 'response_times' to get the true response times, or calculate the last movement onset times and calculate the true movement times. See module examples for how to calculate this. See Also -------- scipy.stats.bootstrap - the function used to compute the confidence interval. """ if signed_contrast is None: signed_contrast = get_signed_contrast(trials) if block is None: block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool) else: block_idx = trials.probabilityLeft == block contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True) reaction_time = np.vectorize( lambda x: np.nanmedian((trials[stim_off_type] - trials[stim_on_type])[(x == signed_contrast) & block_idx]), otypes=[float] )(contrasts) if compute_ci: ci = np.full((contrasts.size, 2), np.nan) for i, x in enumerate(contrasts): data = (trials[stim_off_type] - trials[stim_on_type])[(x == signed_contrast) & block_idx] bt = bootstrap((data,), np.nanmedian, confidence_level=1 - alpha) ci[i, 0] = bt.confidence_interval.low ci[i, 1] = bt.confidence_interval.high return reaction_time, contrasts, n_contrasts, ci else: return reaction_time, contrasts, n_contrasts,
[docs] def criterion_1a(psych, n_trials, perf_easy): """ Returns bool indicating whether criteria for status 'trained_1a' are met. Criteria -------- - Bias is less than 16 - Threshold is less than 19 - Lapse rate on both sides is less than 0.2 - The total number of trials is greater than 200 for each session - Performance on easy contrasts > 80% for all sessions Parameters ---------- psych : numpy.array The fit psychometric parameters three consecutive sessions. Parameters are bias, threshold, lapse high, lapse low. n_trials : numpy.array of int The number for trials for each session. perf_easy : numpy.array of float The proportion of correct high contrast trials for each session. Returns ------- bool True if the criteria are met for 'trained_1a'. Notes ----- The parameter thresholds chosen here were originally determined by averaging the parameter fits for a number of sessions determined to be of 'good' performance by an experimenter. """ criterion = (abs(psych[0]) < 16 and psych[1] < 19 and psych[2] < 0.2 and psych[3] < 0.2 and np.all(n_trials > 200) and np.all(perf_easy > 0.8)) return criterion
[docs] def criterion_1b(psych, n_trials, perf_easy, rt): """ Returns bool indicating whether criteria for trained_1b are met. Criteria -------- - Bias is less than 10 - Threshold is less than 20 (see notes) - Lapse rate on both sides is less than 0.1 - The total number of trials is greater than 400 for each session - Performance on easy contrasts > 90% for all sessions - The median response time across all zero contrast trials is less than 2 seconds Parameters ---------- psych : numpy.array The fit psychometric parameters three consecutive sessions. Parameters are bias, threshold, lapse high, lapse low. n_trials : numpy.array of int The number for trials for each session. perf_easy : numpy.array of float The proportion of correct high contrast trials for each session. rt : float The median response time for zero contrast trials. Returns ------- bool True if the criteria are met for 'trained_1b'. Notes ----- The parameter thresholds chosen here were originally chosen to be slightly stricter than 1a, however it was decided to use round numbers so that readers would not assume a level of precision that isn't there (remember, these parameters were not chosen with any rigor). This regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the slope of the psychometric curve may be slightly less steep than 1a. """ criterion = (abs(psych[0]) < 10 and psych[1] < 20 and psych[2] < 0.1 and psych[3] < 0.1 and np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2) return criterion
[docs] def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt): """ Returns bool indicating whether criteria for ready4ephysrig or ready4recording are met. NB: The difference between these two is whether the sessions were acquired ot a recording rig with a delay before the first trial. Neither of these two things are tested here. Criteria -------- - Lapse on both sides < 0.1 for both bias blocks - Bias shift between blocks > 5 - Total number of trials > 400 for all sessions - Performance on easy contrasts > 90% for all sessions - Median response time for zero contrast stimuli < 2 seconds Parameters ---------- psych_20 : numpy.array The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. Parameters are bias, threshold, lapse high, lapse low. psych_80 : numpy.array The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. Parameters are bias, threshold, lapse high, lapse low. n_trials : numpy.array The number of trials for each session (typically three consecutive sessions). perf_easy : numpy.array The proportion of correct high contrast trials for each session (typically three consecutive sessions). rt : float The median response time for zero contrast trials. Returns ------- bool True if subject passes the ready4ephysrig or ready4recording criteria. """ criterion = (np.all(np.r_[psych_20[2:4], psych_80[2:4]] < 0.1) and # lapse psych_80[0] - psych_20[0] > 5 and np.all(n_trials > 400) and # bias shift and n trials np.all(perf_easy > 0.9) and rt < 2) # overall performance and response times return criterion
[docs] def criterion_delay(n_trials, perf_easy): """ Returns bool indicating whether criteria for 'ready4delay' is met. Criteria -------- - Total number of trials for any of the sessions is greater than 400 - Performance on easy contrasts is greater than 90% for any of the sessions Parameters ---------- n_trials : numpy.array of int The number of trials for each session (typically three consecutive sessions). perf_easy : numpy.array The proportion of correct high contrast trials for each session (typically three consecutive sessions). Returns ------- bool True if subject passes the 'ready4delay' criteria. """ criterion = np.any(n_trials > 400) and np.any(perf_easy > 0.9) return criterion
[docs] def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs): """ Function to plot psychometric curve plots a la datajoint webpage. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. ax : matplotlib.pyplot.Axes An axis object to plot onto. title : str An optional plot title. plot_ci : bool If true, computes and plots the confidence intervals for response at each contrast. ci_alpha : float, default=0.032 Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false, this value is ignored. **kwargs If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. Returns ------- matplotlib.pyplot.Figure The figure handle containing the plot. matplotlib.pyplot.Axes The plotted axes. See Also -------- statsmodels.stats.proportion.proportion_confint - The function used to compute confidence interval. psychofit.mle_fit_psycho - The function used to fit the psychometric parameters. psychofit.erf_psycho_2gammas - The function used to transform contrast to response probability using the fit parameters. """ signed_contrast = get_signed_contrast(trials) contrasts_fit = np.arange(-100, 100) prob_right_50, contrasts_50, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.5, prob_right=True) out_50 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.5, plotting=True, compute_ci=plot_ci, alpha=ci_alpha) pars_50 = out_50[0] if plot_ci else out_50 prob_right_fit_50 = psy.erf_psycho_2gammas(pars_50, contrasts_fit) prob_right_20, contrasts_20, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.2, prob_right=True) out_20 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.2, plotting=True, compute_ci=plot_ci, alpha=ci_alpha) pars_20 = out_20[0] if plot_ci else out_20 prob_right_fit_20 = psy.erf_psycho_2gammas(pars_20, contrasts_fit) prob_right_80, contrasts_80, _ = compute_performance(trials, signed_contrast=signed_contrast, block=0.8, prob_right=True) out_80 = compute_psychometric(trials, signed_contrast=signed_contrast, block=0.8, plotting=True, compute_ci=plot_ci, alpha=ci_alpha) pars_80 = out_80[0] if plot_ci else out_80 prob_right_fit_80 = psy.erf_psycho_2gammas(pars_80, contrasts_fit) cmap = sns.diverging_palette(20, 220, n=3, center='dark') if not ax: fig, ax = plt.subplots(**kwargs) else: fig = plt.gcf() fit_50 = ax.plot(contrasts_fit, prob_right_fit_50, color=cmap[1]) data_50 = ax.scatter(contrasts_50, prob_right_50, color=cmap[1]) fit_20 = ax.plot(contrasts_fit, prob_right_fit_20, color=cmap[0]) data_20 = ax.scatter(contrasts_20, prob_right_20, color=cmap[0]) fit_80 = ax.plot(contrasts_fit, prob_right_fit_80, color=cmap[2]) data_80 = ax.scatter(contrasts_80, prob_right_80, color=cmap[2]) if plot_ci: errbar_50 = np.c_[np.abs(out_50[1][0]), np.abs(out_50[1][1])].T errbar_20 = np.c_[np.abs(out_20[1][0]), np.abs(out_20[1][1])].T errbar_80 = np.c_[np.abs(out_80[1][0]), np.abs(out_80[1][1])].T ax.errorbar(contrasts_50, prob_right_50, yerr=errbar_50, ecolor=cmap[1], fmt='none', capsize=5, alpha=0.4) ax.errorbar(contrasts_20, prob_right_20, yerr=errbar_20, ecolor=cmap[0], fmt='none', capsize=5, alpha=0.4) ax.errorbar(contrasts_80, prob_right_80, yerr=errbar_80, ecolor=cmap[2], fmt='none', capsize=5, alpha=0.4) ax.legend([fit_50[0], data_50, fit_20[0], data_20, fit_80[0], data_80], ['p_left=0.5 fit', 'p_left=0.5 data', 'p_left=0.2 fit', 'p_left=0.2 data', 'p_left=0.8 fit', 'p_left=0.8 data'], loc='upper left') ax.set_ylim(-0.05, 1.05) ax.set_ylabel('Probability choosing right') ax.set_xlabel('Contrasts') if title: ax.set_title(title) return fig, ax
[docs] def plot_reaction_time(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.32, **kwargs): """ Function to plot reaction time against contrast a la datajoint webpage. The reaction times are plotted individually for the following three blocks: {0.5, 0.2, 0.8}. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. ax : matplotlib.pyplot.Axes An axis object to plot onto. title : str An optional plot title. plot_ci : bool If true, computes and plots the confidence intervals for response at each contrast. ci_alpha : float, default=0.32 Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false, this value is ignored. **kwargs If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. Returns ------- matplotlib.pyplot.Figure The figure handle containing the plot. matplotlib.pyplot.Axes The plotted axes. See Also -------- scipy.stats.bootstrap - the function used to compute the confidence interval. """ signed_contrast = get_signed_contrast(trials) out_50 = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.5, compute_ci=plot_ci, alpha=ci_alpha) out_20 = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.2, compute_ci=plot_ci, alpha=ci_alpha) out_80 = compute_reaction_time(trials, signed_contrast=signed_contrast, block=0.8, compute_ci=plot_ci, alpha=ci_alpha) cmap = sns.diverging_palette(20, 220, n=3, center='dark') if not ax: fig, ax = plt.subplots(**kwargs) else: fig = plt.gcf() data_50 = ax.plot(out_50[1], out_50[0], '-o', color=cmap[1]) data_20 = ax.plot(out_20[1], out_20[0], '-o', color=cmap[0]) data_80 = ax.plot(out_80[1], out_80[0], '-o', color=cmap[2]) if plot_ci: errbar_50 = np.c_[out_50[0] - out_50[3][:, 0], out_50[3][:, 1] - out_50[0]].T errbar_20 = np.c_[out_20[0] - out_20[3][:, 0], out_20[3][:, 1] - out_20[0]].T errbar_80 = np.c_[out_80[0] - out_80[3][:, 0], out_80[3][:, 1] - out_80[0]].T ax.errorbar(out_50[1], out_50[0], yerr=errbar_50, ecolor=cmap[1], fmt='none', capsize=5, alpha=0.4) ax.errorbar(out_20[1], out_20[0], yerr=errbar_20, ecolor=cmap[0], fmt='none', capsize=5, alpha=0.4) ax.errorbar(out_80[1], out_80[0], yerr=errbar_80, ecolor=cmap[2], fmt='none', capsize=5, alpha=0.4) ax.legend([data_50[0], data_20[0], data_80[0]], ['p_left=0.5 data', 'p_left=0.2 data', 'p_left=0.8 data'], loc='upper left') ax.set_ylabel('Reaction time (s)') ax.set_xlabel('Contrasts') if title: ax.set_title(title) return fig, ax
[docs] def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs): """ Function to plot reaction time with trial number a la datajoint webpage. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. stim_on_type : str, default='stimOn_times' The trials key to use when calculating the response times. The difference between this and 'feedback_times' is used (see notes for `compute_median_reaction_time`). ax : matplotlib.pyplot.Axes An axis object to plot onto. title : str An optional plot title. **kwargs If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. Returns ------- matplotlib.pyplot.Figure The figure handle containing the plot. matplotlib.pyplot.Axes The plotted axes. """ reaction_time = pd.DataFrame() reaction_time['reaction_time'] = trials.response_times - trials[stim_on_type] reaction_time.index = reaction_time.index + 1 reaction_time_rolled = reaction_time['reaction_time'].rolling(window=10).median() reaction_time_rolled = reaction_time_rolled.where((pd.notnull(reaction_time_rolled)), None) reaction_time = reaction_time.where((pd.notnull(reaction_time)), None) if not ax: fig, ax = plt.subplots(**kwargs) else: fig = plt.gcf() ax.scatter(np.arange(len(reaction_time.values)), reaction_time.values, s=16, color='darkgray') ax.plot(np.arange(len(reaction_time_rolled.values)), reaction_time_rolled.values, color='k', linewidth=2) ax.set_yscale('log') ax.set_ylim(0.1, 100) ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter()) ax.set_ylabel('Reaction time (s)') ax.set_xlabel('Trial number') if title: ax.set_title(title) return fig, ax
[docs] def query_criterion(subject, status, from_status=None, one=None, validate=True): """Get the session for which a given training criterion was met. Parameters ---------- subject : str The subject name. status : str The training status to query for. from_status : str, optional Count number of sessions and days from reaching `from_status` to `status`. one : one.api.OneAlyx, optional An instance of ONE. validate : bool If true, check if status in TrainingStatus enumeration. Set to false for non-standard training pipelines. Returns ------- str The eID of the first session where this training status was reached. int The number of sessions it took to reach `status` (optionally from reaching `from_status`). int The number of days it tool to reach `status` (optionally from reaching `from_status`). """ if validate: status = status.lower().replace(' ', '_') try: status = TrainingStatus[status.upper().replace(' ', '_')].name.lower() except KeyError as ex: raise ValueError( f'Unknown status "{status}". For non-standard training protocols set validate=False' ) from ex one = one or ONE() subject_json = one.alyx.rest('subjects', 'read', id=subject)['json'] if not (criteria := subject_json.get('trained_criteria')) or status not in criteria: return None, None, None to_date, eid = criteria[status] from_date, _ = criteria.get(from_status, (None, None)) eids, det = one.search(subject=subject, date_range=[from_date, to_date], details=True) if len(eids) == 0: return eid, None, None delta_date = det[0]['date'] - det[-1]['date'] return eid, len(eids), delta_date.days