"""Computing and testing IBL training status criteria.
For an in-depth description of each training status, see `Appendix 2`_ of the IBL Protocol For Mice
Training.
.. _Appendix 2: https://figshare.com/articles/preprint/A_standardized_and_reproducible_method_to_\
measure_decision-making_in_mice_Appendix_2_IBL_protocol_for_mice_training/11634729
Examples
--------
Plot the psychometric curve for a given session.
>>> trials = ONE().load_object(eid, 'trials')
>>> fig, ax = plot_psychometric(trials)
Compute 'response times', defined as the duration of open-loop for each contrast.
>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(trials)
Compute 'reaction times', defined as the time between go cue and first detected movement.
NB: These may be negative!
>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
... trials, stim_on_type='goCue_times', stim_off_type='firstMovement_times')
Compute 'response times', defined as the time between first detected movement and response.
>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
... trials, stim_on_type='firstMovement_times', stim_off_type='response_times')
Compute 'movement times', defined as the time between last detected movement and response threshold.
>>> import brainbox.behavior.wheel as wh
>>> wheel_moves = ONE().load_object(eid, 'wheelMoves')
>>> trials['lastMovement_times'] = wh.get_movement_onset(wheel_moves.intervals, trials.response_times)
>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(
... trials, stim_on_type='lastMovement_times', stim_off_type='response_times')
"""
import logging
import datetime
import re
from enum import IntFlag, auto, unique
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy.stats import bootstrap
from iblutil.util import Bunch
from one.api import ONE
from one.alf.io import AlfBunch
from one.alf.exceptions import ALFObjectNotFound
import psychofit as psy
# Module logger, registered under the package-wide 'ibllib' name rather than __name__
_logger = logging.getLogger('ibllib')
TRIALS_KEYS = [
    'contrastLeft',
    'contrastRight',
    'feedbackType',
    'probabilityLeft',
    'choice',
    'response_times',
    'stimOn_times',
]
"""list of str: The required keys in the trials object for computing training status."""
@unique
class TrainingStatus(IntFlag):
    """Standard IBL training criteria.

    Enumeration allows for comparisons between training status.

    Examples
    --------
    >>> status = 'ready4delay'
    >>> assert TrainingStatus[status.upper()] is TrainingStatus.READY4DELAY
    >>> assert TrainingStatus[status.upper()] not in TrainingStatus.FAILED, 'Subject failed training'
    >>> assert TrainingStatus[status.upper()] >= TrainingStatus.TRAINED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] > TrainingStatus.IN_TRAINING, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY

    Get the next training status

    >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()])
    <TrainingStatus.READY4RECORDING: 128>

    Notes
    -----
    - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b.
    - A subject may achieve both TRAINED_1A and TRAINED_1B within a single session, therefore it
      is possible to have skipped the TRAINED_1A session status.
    - Each single status occupies its own bit (1, 2, 4, ... 128, in declaration order), so the
      integer comparisons in the examples follow declaration order.
    """
    # Single statuses: one bit each, the same values enum.auto() yields for an IntFlag
    UNTRAINABLE = 1 << 0
    UNBIASABLE = 1 << 1
    IN_TRAINING = 1 << 2
    TRAINED_1A = 1 << 3
    TRAINED_1B = 1 << 4
    READY4EPHYSRIG = 1 << 5
    READY4DELAY = 1 << 6
    READY4RECORDING = 1 << 7

    # Compound training statuses for convenience
    FAILED = UNTRAINABLE | UNBIASABLE
    READY = READY4EPHYSRIG | READY4DELAY | READY4RECORDING
    TRAINED = TRAINED_1A | TRAINED_1B
def get_lab_training_status(lab, date=None, details=True, one=None):
    """
    Compute the training status of all alive, water-restricted subjects in a lab.

    The results are printed to stdout, one subject at a time.

    Parameters
    ----------
    lab : str
        Lab name (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    details : bool
        Whether to display all information about training status computation e.g. performance,
        number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE. A new instance is created if not provided.
    """
    one = one or ONE()
    query = dict(lab=lab, alive=True, water_restricted=True)
    # Collect every nickname up front before computing any status
    nicknames = [record['nickname'] for record in one.alyx.rest('subjects', 'list', **query)]
    for nickname in nicknames:
        get_subject_training_status(nickname, date=date, details=details, one=one)
def get_subject_training_status(subj, date=None, details=True, one=None):
    """
    Compute the training status of a single subject and print the result to stdout.

    Nothing is printed if no training sessions with trials data are found.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    details : bool
        Whether to display all information about training status computation e.g. performance,
        number of trials, psychometric fit parameters.
    one : one.api.OneAlyx
        An instance of ONE. A new instance is created if not provided.
    """
    one = one or ONE()
    trials, task_protocol, ephys_sess, n_delay = get_sessions(subj, date=date, one=one)
    if not trials:
        return

    sess_dates = list(trials.keys())
    status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay)
    if not details:
        return

    # Which metrics are available depends on whether the unbiased ('psych') or the biased-block
    # ('psych_20' / 'psych_80') psychometric fits were computed.
    if np.any(info.get('psych')):
        metrics = dict(perf_easy=info.perf_easy, n_trials=info.n_trials, psych=info.psych,
                       rt=info.rt)
    elif np.any(info.get('psych_20')):
        metrics = dict(perf_easy=info.perf_easy, n_trials=info.n_trials,
                       psych_20=info.psych_20, psych_80=info.psych_80, rt=info.rt)
    else:
        metrics = {}
    display_status(subj, sess_dates, status, **metrics)
def get_sessions(subj, date=None, one=None):
    """
    Download and load in training data for a specified subject.

    If a date is given it will load data from the three (or as many as are available) previous
    sessions up to the specified date. If not it will load data from the last three training
    sessions that have data available.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    date : str
        The ISO date from which to compute training status. If not specified will compute from the
        latest date with available data. Format should be 'YYYY-MM-DD'.
    one : one.api.OneAlyx
        An instance of ONE. A new instance is created if not provided.

    Returns
    -------
    iblutil.util.Bunch
        Dictionary of trials objects where each key is the ISO session date string.
    list of str
        List of the task protocol used for each of the sessions.
    list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions on ephys rig that had delay prior to starting session > 15min.
        Returns 0 if no sessions detected.

    Notes
    -----
    If no sessions are found, or none of the sessions found has loadable trials data, a warning
    is logged and a list of four None values is returned instead.
    """
    one = one or ONE()
    if date is None:
        # compute from yesterday
        specified_date = datetime.date.today() - datetime.timedelta(days=1)
        latest_sess = specified_date.strftime('%Y-%m-%d')
        latest_minus_week = (datetime.date.today() -
                             datetime.timedelta(days=8)).strftime('%Y-%m-%d')
    else:
        # compute from the date specified
        specified_date = datetime.datetime.strptime(date, '%Y-%m-%d')
        latest_minus_week = (specified_date - datetime.timedelta(days=7)).strftime('%Y-%m-%d')
        latest_sess = date

    sessions = one.alyx.rest('sessions', 'list', subject=subj,
                             date_range=[latest_minus_week, latest_sess],
                             dataset_types='trials.goCueTrigger_times')
    # If not enough sessions in the last week, then fetch all sessions up to the specified date
    if len(sessions) < 3:
        specified_date_plus = (specified_date + datetime.timedelta(days=1)).strftime('%Y-%m-%d')
        django_query = 'start_time__lte,' + specified_date_plus
        sessions = one.alyx.rest('sessions', 'list', subject=subj,
                                 dataset_types='trials.goCueTrigger_times', django=django_query)
    # If still 0 sessions then return with warning
    if len(sessions) == 0:
        _logger.warning(f"No training sessions detected for {subj}")
        return [None] * 4

    trials = Bunch()
    task_protocol = []
    sess_dates = []
    # Walk the sessions in the order returned by Alyx until three sessions with trials data have
    # been loaded, or the session list is exhausted (some sessions may have no trials object).
    for sess in sessions:
        if len(trials) >= 3:
            break
        eid = sess['url'].split('/')[-1]
        _logger.debug('Loading trials for session %s', eid)
        try:
            trials_ = one.load_object(eid, 'trials')
        except ALFObjectNotFound:
            trials_ = None
        if not trials_:
            continue
        sess_date = sess['start_time'][:10]
        # NOTE(review): two sessions on the same date share one key in `trials` (the later one in
        # the list overwrites) but are both appended to `task_protocol` and `sess_dates`.
        task_protocol.append(re.search('tasks_(.*)Choice', sess['task_protocol']).group(1))
        sess_dates.append(sess_date)
        trials[sess_date] = trials_

    if not trials:
        # Sessions exist but none had loadable trials; nothing to compute a status from
        _logger.warning(f"No trials data could be loaded for {subj}")
        return [None] * 4

    ephys_sess_dates = []
    n_delay = 0
    if not np.any(np.array(task_protocol) == 'training'):
        # presumably Alyx lists sessions most recent first, so sess_dates[-1] is the oldest
        # date - TODO confirm
        date_range = [sess_dates[-1], sess_dates[0]]
        ephys_sess = one.alyx.rest('sessions', 'list', subject=subj, date_range=date_range,
                                   django='json__PYBPOD_BOARD__icontains,ephys')
        if len(ephys_sess) > 0:
            ephys_sess_dates = [s['start_time'][:10] for s in ephys_sess]
            # Sessions with a start delay of at least 900 s (15 min)
            n_delay = len(one.alyx.rest('sessions', 'list', subject=subj, date_range=date_range,
                                        django='json__SESSION_START_DELAY_SEC__gte,900'))
    return trials, task_protocol, ephys_sess_dates, n_delay
def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):
    """
    Compute training status of a subject from consecutive training datasets.

    For IBL, training status is calculated using trials from the last three consecutive sessions.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    task_protocol : list of str
        Task protocol used for each training session in `trials`, can be 'training', 'biased' or
        'ephys'.
    ephys_sess_dates : list of str
        List of ISO date strings where training was conducted on ephys rig. Empty list if all
        sessions on training rig.
    n_delay : int
        Number of sessions on ephys rig that had delay prior to starting session > 15min.
        Returns 0 if no sessions detected.

    Returns
    -------
    str
        Training status of the subject.
    iblutil.util.Bunch
        Bunch containing performance metrics that decide training status i.e. performance on easy
        trials, number of trials, psychometric fit parameters, reaction time.
    """
    info = Bunch()
    trials_all = concatenate_trials(trials)
    on_training = np.array(task_protocol) == 'training'

    # Case 1: every session is trainingChoiceWorld -> test trained 1a / 1b criteria
    if np.all(on_training):
        signed_contrast = get_signed_contrast(trials_all)
        info.perf_easy, info.n_trials, info.psych, info.rt = \
            compute_training_info(trials, trials_all)
        if not np.any(signed_contrast == 0):
            # No zero-contrast trials in these sessions: classed as in training
            return 'in training', info
        if criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt):
            return 'trained 1b', info
        if criterion_1a(info.psych, info.n_trials, info.perf_easy):
            return 'trained 1a', info
        return 'in training', info

    # Case 2: a mix of trainingChoiceWorld and other protocols, i.e. fewer than three
    # biasedChoiceWorld sessions since reaching the trained 1b criterion
    if np.any(on_training):
        info.perf_easy, info.n_trials, info.psych, info.rt = \
            compute_training_info(trials, trials_all)
        return 'trained 1b', info

    # Case 3: only biasedChoiceWorld / ephysChoiceWorld sessions
    info.perf_easy, info.n_trials, info.psych_20, info.psych_80, info.rt = \
        compute_bias_info(trials, trials_all)
    n_ephys = len(ephys_sess_dates)
    if n_ephys == 0:
        # We are still on training rig and so all sessions should be biased
        assert np.all(np.array(task_protocol) == 'biased')
        passed = criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
                                 info.rt)
        status = 'ready4ephysrig' if passed else 'trained 1b'
    elif n_ephys < 3:
        assert all(date in trials for date in ephys_sess_dates)
        perf_ephys_easy = np.array([compute_performance_easy(trials[d]) for d in ephys_sess_dates])
        n_ephys_trials = np.array([compute_n_trials(trials[d]) for d in ephys_sess_dates])
        if criterion_delay(n_ephys_trials, perf_ephys_easy):
            status = 'ready4delay'
        else:
            status = 'ready4ephysrig'
    else:
        if n_delay > 0 and criterion_ephys(info.psych_20, info.psych_80, info.n_trials,
                                           info.perf_easy, info.rt):
            status = 'ready4recording'
        elif criterion_delay(info.n_trials, info.perf_easy):
            status = 'ready4delay'
        else:
            status = 'ready4ephysrig'
    return status, info
def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None,
                   psych_20=None, psych_80=None, rt=None):
    """
    Display training status of subject to terminal.

    Parameters
    ----------
    subj : str
        Subject nickname (must match the name registered on Alyx).
    sess_dates : list of str
        ISO date strings of training sessions used to determine training status. Any number of
        dates is accepted.
    status : str
        Training status of subject.
    perf_easy : numpy.array
        Proportion of correct high contrast trials for each training session. If None, only the
        status and session dates are printed.
    n_trials : numpy.array
        Total number of trials for each training session.
    psych : numpy.array
        Psychometric parameters fit to data from all training sessions - bias, threshold and the
        two lapse rates.
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
    rt : float
        The median response time for zero contrast trials across all training sessions. NaN
        indicates no zero contrast stimuli in training sessions.
    """
    if perf_easy is None:
        # Join however many dates were given; the output format for three dates is unchanged
        print(f"\n{subj} : {status} \nSession dates=[{', '.join(map(str, sess_dates))}]")
        return

    # Convert to built-in Python scalars so the list reprs read e.g. [0.91, 0.85] rather than
    # [np.float64(0.91), ...] (numpy >= 2 changed the repr of numpy scalars).
    perf_easy_list = np.around(np.asarray(perf_easy), 2).tolist()
    n_trials_list = np.asarray(n_trials).tolist()
    summary = (f"\n{subj} : {status} \nSession dates={list(sess_dates)}, "
               f"Perf easy={perf_easy_list}, "
               f"N trials={n_trials_list} ")

    if psych_20 is None:
        print(summary +
              f"\nPsych fit over last 3 sessions: "
              f"bias={np.around(psych[0],2)}, thres={np.around(psych[1],2)}, "
              f"lapse_low={np.around(psych[2],2)}, lapse_high={np.around(psych[3],2)} "
              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
              f"{np.around(rt,2)}")
    else:
        print(summary +
              f"\nPsych fit over last 3 sessions (20): "
              f"bias={np.around(psych_20[0],2)}, thres={np.around(psych_20[1],2)}, "
              f"lapse_low={np.around(psych_20[2],2)}, lapse_high={np.around(psych_20[3],2)} "
              f"\nPsych fit over last 3 sessions (80): bias={np.around(psych_80[0],2)}, "
              f"thres={np.around(psych_80[1],2)}, lapse_low={np.around(psych_80[2],2)}, "
              f"lapse_high={np.around(psych_80[3],2)} "
              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
              f"{np.around(rt, 2)}")
def concatenate_trials(trials):
    """
    Concatenate trials from different training sessions.

    Only the keys listed in `TRIALS_KEYS` are concatenated, in session (dict) order.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.

    Returns
    -------
    one.alf.io.AlfBunch
        Trials object with data concatenated over three training sessions.
    """
    sessions = list(trials.values())
    trials_all = AlfBunch()
    for key in TRIALS_KEYS:
        trials_all[key] = np.concatenate([session[key] for session in sessions])
    return trials_all
def compute_training_info(trials, trials_all):
    """
    Compute all relevant performance metrics for when subject is on trainingChoiceWorld.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects where each key is the ISO session date string.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over three training sessions.

    Returns
    -------
    numpy.array
        Performance on easy (high contrast) trials for each session.
    numpy.array
        Total number of trials for each training session.
    numpy.array
        Array of psychometric parameters fit to `trials_all` - bias, threshold and the two lapse
        rates.
    float
        The median response time for all zero-contrast trials across all sessions. Returns NaN if
        there are no zero-contrast trials.
    """
    # Per-session metrics
    perf_easy = np.array([compute_performance_easy(session) for session in trials.values()])
    n_trials = np.array([compute_n_trials(session) for session in trials.values()])
    # Metrics pooled over all sessions
    signed_contrast = get_signed_contrast(trials_all)
    psych = compute_psychometric(trials_all, signed_contrast=signed_contrast)
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)
    return perf_easy, n_trials, psych, rt
def compute_bias_info(trials, trials_all):
    """
    Compute all relevant performance metrics for when subject is on biasedChoiceWorld.

    Parameters
    ----------
    trials : dict of str
        Dictionary of trials objects from consecutive training sessions, keyed by ISO session
        date string.
    trials_all : one.alf.io.AlfBunch
        Trials object with data concatenated over the training sessions.

    Returns
    -------
    numpy.array
        Performance on easy trials for each session.
    numpy.array
        Number of trials in each session.
    numpy.array
        Psychometric fit parameters for trials with probabilityLeft == 0.2, over all sessions.
    numpy.array
        Psychometric fit parameters for trials with probabilityLeft == 0.8, over all sessions.
    float
        Median reaction time for zero contrast stimuli over all sessions.
    """
    # Per-session metrics
    perf_easy = np.array([compute_performance_easy(session) for session in trials.values()])
    n_trials = np.array([compute_n_trials(session) for session in trials.values()])
    # Pooled metrics: one psychometric fit per biased block
    signed_contrast = get_signed_contrast(trials_all)
    psych_20, psych_80 = (
        compute_psychometric(trials_all, signed_contrast=signed_contrast, block=prob_left)
        for prob_left in (0.2, 0.8)
    )
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)
    return perf_easy, n_trials, psych_20, psych_80, rt
def get_signed_contrast(trials):
    """
    Compute the signed contrast of each trial.

    Parameters
    ----------
    trials : dict
        Trials object that must contain 'contrastLeft' and 'contrastRight' keys.

    Returns
    -------
    numpy.array
        1-D array of signed contrasts in percent (right minus left), so that negative values are
        stimuli on the left. NaN contrasts are treated as zero.
    """
    # NaNs are replaced by zero before taking the difference
    left = np.nan_to_num(trials['contrastLeft'])
    right = np.nan_to_num(trials['contrastRight'])
    return np.ravel(right - left) * 100
def compute_n_trials(trials):
    """
    Compute the number of trials in a trials object.

    Parameters
    ----------
    trials : dict
        Trials object containing a 'choice' array.

    Returns
    -------
    int
        The number of trials, i.e. the length of the first axis of ``trials['choice']``.
    """
    choices = trials['choice']
    return choices.shape[0]
def compute_psychometric(trials, signed_contrast=None, block=None, plotting=False,
                         compute_ci=False, alpha=.032):
    """
    Compute psychometric fit parameters for trials object.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    plotting : bool
        Which set of psychofit model parameters to use (see notes).
    compute_ci : bool
        If true, computes and returns the confidence intervals for response at each contrast.
    alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        Array of psychometric fit parameters - bias, threshold and the two lapse rates. All NaN
        if no trials match `block`.
    numpy.array
        Only returned if `compute_ci` is true. Array of shape (2, n_contrasts) holding the lower
        and upper confidence bounds on the proportion of rightward choices at each contrast,
        minus the observed proportion. Shape (2, 0) if no trials match `block`.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
        interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.

    Notes
    -----
    The psychofit starting parameters and model constraints used for the fit when computing the
    training status (e.g. trained_1a, etc.) are sub-optimal and can produce a poor fit. To keep
    the precise criteria the same for all subjects, these parameters have not changed. To produce a
    better fit for plotting purposes, or to calculate the training status in a manner inconsistent
    with the IBL training pipeline, use plotting=True.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        # No trials to fit. Keep the return type consistent with `compute_ci` so callers can
        # always unpack / index the result in the same way.
        psych = np.full(4, np.nan)
        return (psych, np.full((2, 0), np.nan)) if compute_ci else psych

    prob_choose_right, contrasts, n_contrasts = compute_performance(
        trials, signed_contrast=signed_contrast, block=block, prob_right=True)

    if plotting:
        # These starting parameters and constraints tend to produce a better fit, and are therefore
        # used for plotting.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([0., 40., 0.1, 0.1]),
            parmin=np.array([-50., 10., 0., 0.]),
            parmax=np.array([50., 50., 0.2, 0.2]),
            nfits=10)
    else:
        # These starting parameters and constraints are not ideal but are still used for computing
        # the training status for consistency.
        psych, _ = psy.mle_fit_psycho(
            np.vstack([contrasts, n_contrasts, prob_choose_right]),
            P_model='erf_psycho_2gammas',
            parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]),
            parmin=np.array([np.min(contrasts), 0., 0., 0.]),
            parmax=np.array([np.max(contrasts), 100., 1, 1]))

    if not compute_ci:
        return psych

    import statsmodels.stats.proportion as smp  # noqa
    # choice == -1 is counted as a rightward choice here
    n_right = np.vectorize(
        lambda x: np.sum(trials['choice'][(x == signed_contrast) & block_idx] == -1))(contrasts)
    # proportion_confint returns (lower, upper); subtracting the observed proportion gives a
    # (2, n_contrasts) array of bounds relative to the data points
    ci = smp.proportion_confint(n_right, n_contrasts, alpha=alpha,
                                method='normal') - prob_choose_right
    return psych, ci
def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='response_times',
                          signed_contrast=None, block=None, compute_ci=False, alpha=0.32):
    """
    Compute median response time for all contrasts.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use when calculating the response times. The difference between this and
        `stim_off_type` is used (see notes).
    stim_off_type : str, default='response_times'
        The trials key to use when calculating the response times. The difference between this and
        `stim_on_type` is used (see notes).
    signed_contrast : numpy.array
        An array of signed contrasts in percent the length of trials, where left contrasts are -ve.
        If None, these are computed from the trials object.
    block : float
        The block type to compute. If None, all trials are included, otherwise only trials where
        probabilityLeft matches this value are included. For biasedChoiceWorld, the
        probabilityLeft set is {0.5, 0.2, 0.8}.
    compute_ci : bool
        If true, computes and returns the confidence intervals for response time at each contrast.
    alpha : float, default=0.32
        Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false,
        this value is ignored.

    Returns
    -------
    numpy.array
        The median response times for each unique signed contrast.
    numpy.array
        The set of unique signed contrasts.
    numpy.array
        The number of trials for each unique signed contrast.
    (numpy.array)
        If `compute_ci` is true, an array of confidence intervals of shape (n_contrasts, 2),
        columns being the lower and upper bounds. Rows for contrasts with fewer than two trials
        are NaN, as a bootstrap interval cannot be computed from a single observation.

    Notes
    -----
    - The response/reaction time by default is the time between stim on and response, i.e. the
      entire open-loop trial duration. One could use 'stimOn_times' and 'firstMovement_times' to
      get the true reaction time, or 'firstMovement_times' and 'response_times' to get the true
      response times, or calculate the last movement onset times and calculate the true movement
      times. See module examples for how to calculate this.

    See Also
    --------
    scipy.stats.bootstrap - the function used to compute the confidence interval.
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
    # Interval between the two events, computed once for all trials
    durations = trials[stim_off_type] - trials[stim_on_type]
    masks = [(x == signed_contrast) & block_idx for x in contrasts]
    reaction_time = np.array([np.nanmedian(durations[mask]) for mask in masks], dtype=float)

    if not compute_ci:
        return reaction_time, contrasts, n_contrasts

    ci = np.full((contrasts.size, 2), np.nan)
    for i, mask in enumerate(masks):
        data = durations[mask]
        if data.size < 2:
            # scipy.stats.bootstrap requires at least two observations; leave this row as NaN
            continue
        bt = bootstrap((data,), np.nanmedian, confidence_level=1 - alpha)
        ci[i, 0] = bt.confidence_interval.low
        ci[i, 1] = bt.confidence_interval.high
    return reaction_time, contrasts, n_contrasts, ci
def criterion_1a(psych, n_trials, perf_easy):
    """
    Returns bool indicating whether criteria for status 'trained_1a' are met.

    Criteria
    --------
    - Bias is less than 16
    - Threshold is less than 19
    - Lapse rate on both sides is less than 0.2
    - The total number of trials is greater than 200 for each session
    - Performance on easy contrasts > 80% for all sessions

    Parameters
    ----------
    psych : numpy.array
        The fit psychometric parameters three consecutive sessions. Parameters are bias, threshold
        and the two lapse rates.
    n_trials : numpy.array of int
        The number for trials for each session.
    perf_easy : numpy.array of float
        The proportion of correct high contrast trials for each session.

    Returns
    -------
    bool
        True if the criteria are met for 'trained_1a'. NaN fit parameters fail the criteria.

    Notes
    -----
    The parameter thresholds chosen here were originally determined by averaging the parameter fits
    for a number of sessions determined to be of 'good' performance by an experimenter.
    """
    # Each step short-circuits: once a check fails, the failing value is carried through
    passes = abs(psych[0]) < 16 and psych[1] < 19  # bias and threshold
    passes = passes and psych[2] < 0.2 and psych[3] < 0.2  # lapse rates
    passes = passes and np.all(n_trials > 200)
    return passes and np.all(perf_easy > 0.8)
def criterion_1b(psych, n_trials, perf_easy, rt):
    """
    Returns bool indicating whether criteria for trained_1b are met.

    Criteria
    --------
    - Bias is less than 10
    - Threshold is less than 20 (see notes)
    - Lapse rate on both sides is less than 0.1
    - The total number of trials is greater than 400 for each session
    - Performance on easy contrasts > 90% for all sessions
    - The median response time across all zero contrast trials is less than 2 seconds

    Parameters
    ----------
    psych : numpy.array
        The fit psychometric parameters three consecutive sessions. Parameters are bias, threshold
        and the two lapse rates.
    n_trials : numpy.array of int
        The number for trials for each session.
    perf_easy : numpy.array of float
        The proportion of correct high contrast trials for each session.
    rt : float
        The median response time for zero contrast trials.

    Returns
    -------
    bool
        True if the criteria are met for 'trained_1b'.

    Notes
    -----
    The parameter thresholds chosen here were originally chosen to be slightly stricter than 1a,
    however it was decided to use round numbers so that readers would not assume a level of
    precision that isn't there (remember, these parameters were not chosen with any rigor). This
    regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the
    slope of the psychometric curve may be slightly less steep than 1a.
    """
    # Each step short-circuits: once a check fails, the failing value is carried through
    passes = abs(psych[0]) < 10 and psych[1] < 20  # bias and threshold
    passes = passes and psych[2] < 0.1 and psych[3] < 0.1  # lapse rates
    passes = passes and np.all(n_trials > 400) and np.all(perf_easy > 0.9)
    return passes and rt < 2
def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
    """
    Returns bool indicating whether criteria for ready4ephysrig or ready4recording are met.

    NB: The difference between these two is whether the sessions were acquired on a recording rig
    with a delay before the first trial. Neither of these two things are tested here.

    Criteria
    --------
    - Lapse on both sides < 0.1 for both bias blocks
    - Bias shift between blocks > 5
    - Total number of trials > 400 for all sessions
    - Performance on easy contrasts > 90% for all sessions
    - Median response time for zero contrast stimuli < 2 seconds

    Parameters
    ----------
    psych_20 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2.
        Parameters are bias, threshold and the two lapse rates.
    psych_80 : numpy.array
        The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8.
        Parameters are bias, threshold and the two lapse rates.
    n_trials : numpy.array
        The number of trials for each session (typically three consecutive sessions).
    perf_easy : numpy.array
        The proportion of correct high contrast trials for each session (typically three
        consecutive sessions).
    rt : float
        The median response time for zero contrast trials.

    Returns
    -------
    bool
        True if subject passes the ready4ephysrig or ready4recording criteria.
    """
    # Both lapse parameters (indices 2 and 3) of both block fits
    lapses = np.r_[psych_20[2:4], psych_80[2:4]]
    # Each step short-circuits: once a check fails, the failing value is carried through
    passes = np.all(lapses < 0.1)
    passes = passes and psych_80[0] - psych_20[0] > 5  # bias shift between blocks
    passes = passes and np.all(n_trials > 400)
    passes = passes and np.all(perf_easy > 0.9)  # overall performance
    return passes and rt < 2  # response times
def criterion_delay(n_trials, perf_easy):
    """
    Returns bool indicating whether criteria for 'ready4delay' is met.

    Criteria
    --------
    - Total number of trials for any of the sessions is greater than 400
    - Performance on easy contrasts is greater than 90% for any of the sessions

    Parameters
    ----------
    n_trials : numpy.array of int
        The number of trials for each session (typically three consecutive sessions).
    perf_easy : numpy.array
        The proportion of correct high contrast trials for each session (typically three
        consecutive sessions).

    Returns
    -------
    bool
        True if subject passes the 'ready4delay' criteria.
    """
    # NOTE(review): the two conditions are tested independently, so they may be satisfied by
    # different sessions - confirm this matches the intended protocol.
    enough_trials = np.any(n_trials > 400)
    return enough_trials and np.any(perf_easy > 0.9)
[docs]
def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.032, **kwargs):
    """
    Plot psychometric curves (fit and data) a la datajoint webpage.

    One fitted curve and one set of data points is drawn for each of the probabilityLeft
    blocks 0.5, 0.2 and 0.8.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    plot_ci : bool
        If true, computes and plots the confidence intervals for response at each contrast.
    ci_alpha : float, default=0.032
        Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false,
        this value is ignored.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing `ax` (a new figure if `ax` was None).
    matplotlib.pyplot.Axes
        The plotted axes.

    See Also
    --------
    statsmodels.stats.proportion.proportion_confint - The function used to compute confidence
        interval.
    psychofit.mle_fit_psycho - The function used to fit the psychometric parameters.
    psychofit.erf_psycho_2gammas - The function used to transform contrast to response probability
        using the fit parameters.
    """
    signed_contrast = get_signed_contrast(trials)
    # np.arange excludes its stop value; use 101 so the fitted curve reaches +100 contrast.
    contrasts_fit = np.arange(-100, 101)
    cmap = sns.diverging_palette(20, 220, n=3, center='dark')
    # (block probability, colour) in plotting / legend order.
    blocks = ((0.5, cmap[1]), (0.2, cmap[0]), (0.8, cmap[2]))

    results = []
    for block, color in blocks:
        prob_right, contrasts, _ = compute_performance(
            trials, signed_contrast=signed_contrast, block=block, prob_right=True)
        out = compute_psychometric(trials, signed_contrast=signed_contrast, block=block,
                                   plotting=True, compute_ci=plot_ci, alpha=ci_alpha)
        # With compute_ci the output is indexed as (fit parameters, confidence interval).
        pars, ci = (out[0], out[1]) if plot_ci else (out, None)
        prob_right_fit = psy.erf_psycho_2gammas(pars, contrasts_fit)
        results.append((block, color, contrasts, prob_right, prob_right_fit, ci))

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Use the figure that owns `ax`; plt.gcf() may be a different, unrelated figure.
        fig = ax.figure

    handles, labels = [], []
    for block, color, contrasts, prob_right, prob_right_fit, _ in results:
        fit = ax.plot(contrasts_fit, prob_right_fit, color=color)
        data = ax.scatter(contrasts, prob_right, color=color)
        handles += [fit[0], data]
        labels += [f'p_left={block} fit', f'p_left={block} data']

    if plot_ci:
        for _, color, contrasts, prob_right, _, ci in results:
            # errorbar expects non-negative (2, N) offsets; ci[0] / ci[1] are used as the
            # lower / upper offsets, made positive with abs().
            yerr = np.c_[np.abs(ci[0]), np.abs(ci[1])].T
            ax.errorbar(contrasts, prob_right, yerr=yerr, ecolor=color, fmt='none',
                        capsize=5, alpha=0.4)

    ax.legend(handles, labels, loc='upper left')
    ax.set_ylim(-0.05, 1.05)
    ax.set_ylabel('Probability choosing right')
    ax.set_xlabel('Contrasts')
    if title:
        ax.set_title(title)
    return fig, ax
[docs]
def plot_reaction_time(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.32, **kwargs):
    """
    Plot reaction time against contrast a la datajoint webpage.

    The reaction times are plotted individually for the following three blocks: {0.5, 0.2, 0.8}.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    plot_ci : bool
        If true, computes and plots the confidence intervals for response at each contrast.
    ci_alpha : float, default=0.32
        Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false,
        this value is ignored.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing `ax` (a new figure if `ax` was None).
    matplotlib.pyplot.Axes
        The plotted axes.

    See Also
    --------
    scipy.stats.bootstrap - the function used to compute the confidence interval.
    """
    signed_contrast = get_signed_contrast(trials)
    cmap = sns.diverging_palette(20, 220, n=3, center='dark')
    # (block probability, colour) in plotting / legend order.
    blocks = ((0.5, cmap[1]), (0.2, cmap[0]), (0.8, cmap[2]))
    # compute_reaction_time output is indexed as: [0] reaction time, [1] contrasts,
    # and, when compute_ci is set, [3] a (n_contrasts, 2) array of lower/upper CI bounds.
    outputs = [
        (block, color, compute_reaction_time(trials, signed_contrast=signed_contrast, block=block,
                                             compute_ci=plot_ci, alpha=ci_alpha))
        for block, color in blocks
    ]

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Use the figure that owns `ax`; plt.gcf() may be a different, unrelated figure.
        fig = ax.figure

    handles = [ax.plot(out[1], out[0], '-o', color=color)[0] for _, color, out in outputs]

    if plot_ci:
        for _, color, out in outputs:
            reaction_time, contrasts, ci = out[0], out[1], out[3]
            # Convert absolute CI bounds into the (2, N) offsets errorbar expects.
            yerr = np.c_[reaction_time - ci[:, 0], ci[:, 1] - reaction_time].T
            ax.errorbar(contrasts, reaction_time, yerr=yerr, ecolor=color, fmt='none',
                        capsize=5, alpha=0.4)

    ax.legend(handles, [f'p_left={block} data' for block, _, _ in outputs], loc='upper left')
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Contrasts')
    if title:
        ax.set_title(title)
    return fig, ax
[docs]
def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs):
    """
    Plot reaction time against trial number a la datajoint webpage.

    Each trial's reaction time is drawn as a grey point, with a 10-trial rolling median
    overlaid as a black line. The y-axis is logarithmic and limited to [0.1, 100] s.

    Parameters
    ----------
    trials : one.alf.io.AlfBunch
        An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft',
        'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}.
    stim_on_type : str, default='stimOn_times'
        The trials key to use when calculating the reaction times. The difference between
        'response_times' and this key is plotted.
    ax : matplotlib.pyplot.Axes
        An axis object to plot onto.
    title : str
        An optional plot title.
    **kwargs
        If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots.

    Returns
    -------
    matplotlib.pyplot.Figure
        The figure containing `ax` (a new figure if `ax` was None).
    matplotlib.pyplot.Axes
        The plotted axes.

    Notes
    -----
    Trial numbers on the x-axis start at 1. The rolling median needs a full 10-trial window,
    so it is NaN (not drawn) for the first 9 trials.
    """
    reaction_time = pd.Series(
        np.asarray(trials.response_times - trials[stim_on_type], dtype=float))
    # 1-based trial numbers, used as the x coordinates below.
    reaction_time.index = np.arange(1, reaction_time.size + 1)
    reaction_time_rolled = reaction_time.rolling(window=10).median()
    # NaN values need no special handling: scatter skips non-finite points and plot leaves gaps.

    if ax is None:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Use the figure that owns `ax`; plt.gcf() may be a different, unrelated figure.
        fig = ax.figure

    ax.scatter(reaction_time.index, reaction_time.values, s=16, color='darkgray')
    ax.plot(reaction_time_rolled.index, reaction_time_rolled.values, color='k', linewidth=2)
    ax.set_yscale('log')
    ax.set_ylim(0.1, 100)
    ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Trial number')
    if title:
        ax.set_title(title)
    return fig, ax
[docs]
def query_criterion(subject, status, from_status=None, one=None, validate=True):
    """Get the session for which a given training criterion was met.

    Parameters
    ----------
    subject : str
        The subject name.
    status : str
        The training status to query for.
    from_status : str, optional
        Count number of sessions and days from reaching `from_status` to `status`.
    one : one.api.OneAlyx, optional
        An instance of ONE.
    validate : bool
        If true, check that `status` (and `from_status`, if given) is in the TrainingStatus
        enumeration and normalize it to the lower-case, underscore-separated member name
        (e.g. 'trained 1a' -> 'trained_1a'). Set to false for non-standard training pipelines.

    Returns
    -------
    str
        The eID of the first session where this training status was reached.
    int
        The number of sessions it took to reach `status` (optionally from reaching `from_status`).
    int
        The number of days it took to reach `status` (optionally from reaching `from_status`).

    All three values are None if the subject has no record of reaching `status`; the counts
    are None if no sessions are found in the date range.

    Raises
    ------
    ValueError
        If `validate` is true and a status is not a TrainingStatus member.
    """
    def _normalize(name):
        # Map e.g. 'Trained 1a' onto the TrainingStatus member name 'trained_1a'.
        key = name.lower().replace(' ', '_')
        try:
            return TrainingStatus[key.upper()].name.lower()
        except KeyError as ex:
            raise ValueError(
                f'Unknown status "{key}". For non-standard training protocols set validate=False'
            ) from ex

    if validate:
        status = _normalize(status)
        # Normalize from_status too, otherwise e.g. 'in training' would silently miss the
        # criteria lookup below and the count would start from the first session.
        if from_status is not None:
            from_status = _normalize(from_status)

    one = one or ONE()
    # Guard against a null json field on the subject record.
    subject_json = one.alyx.rest('subjects', 'read', id=subject)['json'] or {}
    criteria = subject_json.get('trained_criteria')
    if not criteria or status not in criteria:
        return None, None, None
    to_date, eid = criteria[status]
    # A missing from_status yields from_date=None, presumably an open-ended start in
    # one.search — confirm against the ONE date_range documentation.
    from_date, _ = criteria.get(from_status, (None, None))
    eids, det = one.search(subject=subject, date_range=[from_date, to_date], details=True)
    if len(eids) == 0:
        return eid, None, None
    # NOTE(review): assumes one.search returns sessions newest-first (det[0] latest) — confirm.
    delta_date = det[0]['date'] - det[-1]['date']
    return eid, len(eids), delta_date.days