"""Computing and testing IBL training status criteria.For an in-depth description of each training status, see `Appendix 2`_ of the IBL Protocol For MiceTraining... _Appendix 2: https://figshare.com/articles/preprint/A_standardized_and_reproducible_method_to_\measure_decision-making_in_mice_Appendix_2_IBL_protocol_for_mice_training/11634729Examples--------Plot the psychometric curve for a given session.>>> trials = ONE().load_object(eid, 'trials')>>> fix, ax = plot_psychometric(trials)Compute 'response times', defined as the duration of open-loop for each contrast.>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(trials)Compute 'reaction times', defined as the time between go cue and first detected movement.NB: These may be negative!>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(... trials, stim_on_type='goCue_times', stim_off_type='firstMovement_times')Compute 'response times', defined as the time between first detected movement and response.>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(... trials, stim_on_type='firstMovement_times', stim_off_type='response_times')Compute 'movement times', defined as the time between last detected movement and response threshold.>>> import brainbox.behavior.wheel as wh>>> wheel_moves = ONE().load_object(eid, 'wheeMoves')>>> trials['lastMovement_times'] = wh.get_movement_onset(wheel_moves.intervals, trial_data.response_times)>>> reaction_time, contrasts, n_contrasts = compute_reaction_time(... trials, stim_on_type='lastMovement_times', stim_off_type='response_times')"""importloggingimportdatetimeimportrefromenumimportIntFlag,auto,uniqueimportnumpyasnpimportmatplotlibimportmatplotlib.pyplotaspltimportseabornassnsimportpandasaspdfromscipy.statsimportbootstrapfromiblutil.utilimportBunchfromone.apiimportONEfromone.alf.ioimportAlfBunchfromone.alf.exceptionsimportALFObjectNotFoundimportpsychofitaspsy_logger=logging.getLogger('ibllib')TRIALS_KEYS=['contrastLeft','contrastRight','feedbackType','probabilityLeft','choice','response_times','stimOn_times']"""list of str: The required keys in the trials object for computing training status."""
[docs]@uniqueclassTrainingStatus(IntFlag):"""Standard IBL training criteria. Enumeration allows for comparisons between training status. Examples -------- >>> status = 'ready4delay' ... assert TrainingStatus[status.upper()] is TrainingStatus.READY4DELAY ... assert TrainingStatus[status.upper()] not in TrainingStatus.FAILED, 'Subject failed training' ... assert TrainingStatus[status.upper()] >= TrainingStatus.TRAINED, 'Subject untrained' ... assert TrainingStatus[status.upper()] > TrainingStatus.IN_TRAINING, 'Subject untrained' ... assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained' ... assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY Get the next training status >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()]) <TrainingStatus.READY4RECORDING: 128> Notes ----- - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b. - A subject may acheive both TRAINED_1A and TRAINED_1B within a single session, therefore it is possible to have skipped the TRAINED_1A session status. """UNTRAINABLE=auto()UNBIASABLE=auto()IN_TRAINING=auto()TRAINED_1A=auto()TRAINED_1B=auto()READY4EPHYSRIG=auto()READY4DELAY=auto()READY4RECORDING=auto()# Compound training statuses for convenienceFAILED=UNTRAINABLE|UNBIASABLEREADY=READY4EPHYSRIG|READY4DELAY|READY4RECORDINGTRAINED=TRAINED_1A|TRAINED_1B
[docs]defget_lab_training_status(lab,date=None,details=True,one=None):""" Computes the training status of all alive and water restricted subjects in a specified lab. The response are printed to std out. Parameters ---------- lab : str Lab name (must match the name registered on Alyx). date : str The ISO date from which to compute training status. If not specified will compute from the latest date with available data. Format should be 'YYYY-MM-DD'. details : bool Whether to display all information about training status computation e.g. performance, number of trials, psychometric fit parameters. one : one.api.OneAlyx An instance of ONE. """one=oneorONE()subj_lab=one.alyx.rest('subjects','list',lab=lab,alive=True,water_restricted=True)subjects=[subj['nickname']forsubjinsubj_lab]forsubjinsubjects:get_subject_training_status(subj,date=date,details=details,one=one)
[docs]defget_subject_training_status(subj,date=None,details=True,one=None):""" Computes the training status of specified subject and prints results to std out. Parameters ---------- subj : str Subject nickname (must match the name registered on Alyx). date : str The ISO date from which to compute training status. If not specified will compute from the latest date with available data. Format should be 'YYYY-MM-DD'. details : bool Whether to display all information about training status computation e.g. performance, number of trials, psychometric fit parameters. one : one.api.OneAlyx An instance of ONE. """one=oneorONE()trials,task_protocol,ephys_sess,n_delay=get_sessions(subj,date=date,one=one)ifnottrials:returnsess_dates=list(trials.keys())status,info,_=get_training_status(trials,task_protocol,ephys_sess,n_delay)ifdetails:ifnp.any(info.get('psych')):display_status(subj,sess_dates,status,perf_easy=info.perf_easy,n_trials=info.n_trials,psych=info.psych,rt=info.rt)elifnp.any(info.get('psych_20')):display_status(subj,sess_dates,status,perf_easy=info.perf_easy,n_trials=info.n_trials,psych_20=info.psych_20,psych_80=info.psych_80,rt=info.rt)else:display_status(subj,sess_dates,status)
[docs]defget_sessions(subj,date=None,one=None):""" Download and load in training data for a specified subject. If a date is given it will load data from the three (or as many as are available) previous sessions up to the specified date. If not it will load data from the last three training sessions that have data available. Parameters ---------- subj : str Subject nickname (must match the name registered on Alyx). date : str The ISO date from which to compute training status. If not specified will compute from the latest date with available data. Format should be 'YYYY-MM-DD'. one : one.api.OneAlyx An instance of ONE. Returns ------- iblutil.util.Bunch Dictionary of trials objects where each key is the ISO session date string. list of str List of the task protocol used for each of the sessions. list of str List of ISO date strings where training was conducted on ephys rig. Empty list if all sessions on training rig. n_delay : int Number of sessions on ephys rig that had delay prior to starting session > 15min. Returns 0 if no sessions detected. """one=oneorONE()ifdateisNone:# compute from yesterdayspecified_date=(datetime.date.today()-datetime.timedelta(days=1))latest_sess=specified_date.strftime("%Y-%m-%d")latest_minus_week=(datetime.date.today()-datetime.timedelta(days=8)).strftime("%Y-%m-%d")else:# compute from the date specifiedspecified_date=datetime.datetime.strptime(date,'%Y-%m-%d')latest_minus_week=(specified_date-datetime.timedelta(days=7)).strftime("%Y-%m-%d")latest_sess=datesessions=one.alyx.rest('sessions','list',subject=subj,date_range=[latest_minus_week,latest_sess],dataset_types='trials.goCueTrigger_times')# If not enough sessions in the last week, then just fetch them alliflen(sessions)<3:specified_date_plus=(specified_date+datetime.timedelta(days=1)).strftime("%Y-%m-%d")django_query='start_time__lte,'+specified_date_plussessions=one.alyx.rest('sessions','list',subject=subj,dataset_types='trials.goCueTrigger_times',django=django_query)# If still 0 sessions then return with warningiflen(sessions)==0:_logger.warning(f"No training sessions detected for {subj}")return[None]*4trials=Bunch()task_protocol=[]sess_dates=[]iflen(sessions)<3:forn,_inenumerate(sessions):try:trials_=one.load_object(sessions[n]['url'].split('/')[-1],'trials')exceptALFObjectNotFound:trials_=Noneiftrials_:task_protocol.append(re.search('tasks_(.*)Choice',sessions[n]['task_protocol']).group(1))sess_dates.append(sessions[n]['start_time'][:10])trials[sessions[n]['start_time'][:10]]=trials_else:n=0whilelen(trials)<3:print(sessions[n]['url'].split('/')[-1])try:trials_=one.load_object(sessions[n]['url'].split('/')[-1],'trials')exceptALFObjectNotFound:trials_=Noneiftrials_:task_protocol.append(re.search('tasks_(.*)Choice',sessions[n]['task_protocol']).group(1))sess_dates.append(sessions[n]['start_time'][:10])trials[sessions[n]['start_time'][:10]]=trials_n+=1ifnotnp.any(np.array(task_protocol)=='training'):ephys_sess=one.alyx.rest('sessions','list',subject=subj,date_range=[sess_dates[-1],sess_dates[0]],django='location__name__icontains,ephys')iflen(ephys_sess)>0:ephys_sess_dates=[sess['start_time'][:10]forsessinephys_sess]n_delay=len(one.alyx.rest('sessions','list',subject=subj,date_range=[sess_dates[-1],sess_dates[0]],django='json__SESSION_DELAY_START__gte,900'))else:ephys_sess_dates=[]n_delay=0else:ephys_sess_dates=[]n_delay=0returntrials,task_protocol,ephys_sess_dates,n_delay
[docs]defget_training_status(trials,task_protocol,ephys_sess_dates,n_delay):""" Compute training status of a subject from consecutive training datasets. For IBL, training status is calculated using trials from the last three consecutive sessions. Parameters ---------- trials : dict of str Dictionary of trials objects where each key is the ISO session date string. task_protocol : list of str Task protocol used for each training session in `trials`, can be 'training', 'biased' or 'ephys'. ephys_sess_dates : list of str List of ISO date strings where training was conducted on ephys rig. Empty list if all sessions on training rig. n_delay : int Number of sessions on ephys rig that had delay prior to starting session > 15min. Returns 0 if no sessions detected. Returns ------- str Training status of the subject. iblutil.util.Bunch Bunch containing performance metrics that decide training status i.e. performance on easy trials, number of trials, psychometric fit parameters, reaction time. """info=Bunch()trials_all=concatenate_trials(trials)info.session_dates=list(trials.keys())info.protocols=[pforpintask_protocol]# Case when all sessions are trainingChoiceWorldifnp.all(np.array(task_protocol)=='training'):signed_contrast=np.unique(get_signed_contrast(trials_all))(info.perf_easy,info.n_trials,info.psych,info.rt)=compute_training_info(trials,trials_all)pass_criteria,criteria=criterion_1b(info.psych,info.n_trials,info.perf_easy,info.rt,signed_contrast)ifpass_criteria:failed_criteria=Bunch()failed_criteria['NBiased']={'val':info.protocols,'pass':False}failed_criteria['Criteria']={'val':'ready4ephysrig','pass':False}status='trained 1b'else:failed_criteria=criteriapass_criteria,criteria=criterion_1a(info.psych,info.n_trials,info.perf_easy,signed_contrast)ifpass_criteria:status='trained 1a'else:failed_criteria=criteriastatus='in training'returnstatus,info,failed_criteria# Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterionif~np.all(np.array(task_protocol)=='training')and \
np.any(np.array(task_protocol)=='training'):status='trained 1b'(info.perf_easy,info.n_trials,info.psych,info.rt)=compute_training_info(trials,trials_all)criteria=Bunch()criteria['NBiased']={'val':info.protocols,'pass':False}criteria['Criteria']={'val':'ready4ephysrig','pass':False}returnstatus,info,criteria# Case when there is biasedChoiceWorld or ephysChoiceWorld in last three sessionsifnotnp.any(np.array(task_protocol)=='training'):(info.perf_easy,info.n_trials,info.psych_20,info.psych_80,info.rt)=compute_bias_info(trials,trials_all)n_ephys=len(ephys_sess_dates)info.n_ephys=n_ephysinfo.n_delay=n_delay# Criterion recordingpass_criteria,criteria=criteria_recording(n_ephys,n_delay,info.psych_20,info.psych_80,info.n_trials,info.perf_easy,info.rt)ifpass_criteria:# Here the criteria doesn't actually fail but we have no other criteria to meet so we return thisfailed_criteria=criteriastatus='ready4recording'else:failed_criteria=criteriaassertall(dateintrialsfordateinephys_sess_dates)perf_ephys_easy=np.array([compute_performance_easy(trials[k])forkinephys_sess_dates])n_ephys_trials=np.array([compute_n_trials(trials[k])forkinephys_sess_dates])pass_criteria,criteria=criterion_delay(n_ephys_trials,perf_ephys_easy,n_ephys=n_ephys)ifpass_criteria:status='ready4delay'else:failed_criteria=criteriapass_criteria,criteria=criterion_ephys(info.psych_20,info.psych_80,info.n_trials,info.perf_easy,info.rt)ifpass_criteria:status='ready4ephysrig'else:failed_criteria=criteriastatus='trained 1b'returnstatus,info,failed_criteria
[docs]defdisplay_status(subj,sess_dates,status,perf_easy=None,n_trials=None,psych=None,psych_20=None,psych_80=None,rt=None):""" Display training status of subject to terminal. Parameters ---------- subj : str Subject nickname (must match the name registered on Alyx). sess_dates : list of str ISO date strings of training sessions used to determine training status. status : str Training status of subject. perf_easy : numpy.array Proportion of correct high contrast trials for each training session. n_trials : numpy.array Total number of trials for each training session. psych : numpy.array Psychometric parameters fit to data from all training sessions - bias, threshold, lapse high, lapse low. psych_20 : numpy.array The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. psych_80 : numpy.array The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. rt : float The median response time for zero contrast trials across all training sessions. NaN indicates no zero contrast stimuli in training sessions. """ifperf_easyisNone:print(f"\n{subj} : {status}\nSession dates=[{sess_dates[0]}, {sess_dates[1]}, "f"{sess_dates[2]}]")elifpsych_20isNone:print(f"\n{subj} : {status}\nSession dates={[xforxinsess_dates]}, "f"Perf easy={[np.around(pe,2)forpeinperf_easy]}, "f"N trials={[ntforntinn_trials]} "f"\nPsych fit over last 3 sessions: "f"bias={np.around(psych[0],2)}, thres={np.around(psych[1],2)}, "f"lapse_low={np.around(psych[2],2)}, lapse_high={np.around(psych[3],2)} "f"\nMedian reaction time at 0 contrast over last 3 sessions = "f"{np.around(rt,2)}")else:print(f"\n{subj} : {status}\nSession dates={[xforxinsess_dates]}, "f"Perf easy={[np.around(pe,2)forpeinperf_easy]}, "f"N trials={[ntforntinn_trials]} "f"\nPsych fit over last 3 sessions (20): "f"bias={np.around(psych_20[0],2)}, thres={np.around(psych_20[1],2)}, "f"lapse_low={np.around(psych_20[2],2)}, lapse_high={np.around(psych_20[3],2)} "f"\nPsych fit over last 3 sessions (80): bias={np.around(psych_80[0],2)}, "f"thres={np.around(psych_80[1],2)}, lapse_low={np.around(psych_80[2],2)}, "f"lapse_high={np.around(psych_80[3],2)} "f"\nMedian reaction time at 0 contrast over last 3 sessions = "f"{np.around(rt,2)}")
[docs]defconcatenate_trials(trials):""" Concatenate trials from different training sessions. Parameters ---------- trials : dict of str Dictionary of trials objects where each key is the ISO session date string. Returns ------- one.alf.io.AlfBunch Trials object with data concatenated over three training sessions. """trials_all=AlfBunch()forkinTRIALS_KEYS:trials_all[k]=np.concatenate(list(trials[kk][k]forkkintrials.keys()))returntrials_all
[docs]defcompute_training_info(trials,trials_all):""" Compute all relevant performance metrics for when subject is on trainingChoiceWorld. Parameters ---------- trials : dict of str Dictionary of trials objects where each key is the ISO session date string. trials_all : one.alf.io.AlfBunch Trials object with data concatenated over three training sessions. Returns ------- numpy.array Proportion of correct high contrast trials for each session. numpy.array Total number of trials for each training session. numpy.array Array of psychometric parameters fit to `all_trials` - bias, threshold, lapse high, lapse low. float The median response time for all zero-contrast trials across all sessions. Returns NaN if no trials zero-contrast trials). """signed_contrast=get_signed_contrast(trials_all)perf_easy=np.array([compute_performance_easy(trials[k])forkintrials.keys()])n_trials=np.array([compute_n_trials(trials[k])forkintrials.keys()])psych=compute_psychometric(trials_all,signed_contrast=signed_contrast)rt=compute_median_reaction_time(trials_all,contrast=0,signed_contrast=signed_contrast)returnperf_easy,n_trials,psych,rt
[docs]defcompute_bias_info(trials,trials_all):""" Compute all relevant performance metrics for when subject is on biasedChoiceWorld :param trials: dict containing trials objects from three consecutive training sessions, keys are session dates :type trials: Bunch :param trials_all: trials object with data concatenated over three training sessions :type trials_all: Bunch :returns: - perf_easy - performance of easy trials for each session - n_trials - number of trials in each session - psych_20 - parameters for psychometric curve fit to trials in 20 block over all sessions - psych_80 - parameters for psychometric curve fit to trials in 80 block over all sessions - rt - median reaction time for zero contrast stimuli over all sessions """signed_contrast=get_signed_contrast(trials_all)perf_easy=np.array([compute_performance_easy(trials[k])forkintrials.keys()])n_trials=np.array([compute_n_trials(trials[k])forkintrials.keys()])psych_20=compute_psychometric(trials_all,signed_contrast=signed_contrast,block=0.2)psych_80=compute_psychometric(trials_all,signed_contrast=signed_contrast,block=0.8)rt=compute_median_reaction_time(trials_all,contrast=0,signed_contrast=signed_contrast)returnperf_easy,n_trials,psych_20,psych_80,rt
[docs]defget_signed_contrast(trials):""" Compute signed contrast from trials object :param trials: trials object that must contain contrastLeft and contrastRight keys :type trials: dict returns: array of signed contrasts in percent, where -ve values are on the left """# Replace NaNs with zeros, stack and take the differencecontrast=np.nan_to_num(np.c_[trials['contrastLeft'],trials['contrastRight']])returnnp.diff(contrast).flatten()*100
[docs]defcompute_performance_easy(trials):""" Compute performance on easy trials (stimulus >= 50 %) from trials object :param trials: trials object that must contain contrastLeft, contrastRight and feedbackType keys :type trials: dict returns: float containing performance on easy contrast trials """signed_contrast=get_signed_contrast(trials)easy_trials=np.where(np.abs(signed_contrast)>=50)[0]returnnp.sum(trials['feedbackType'][easy_trials]==1)/easy_trials.shape[0]
[docs]defcompute_performance(trials,signed_contrast=None,block=None,prob_right=False):""" Compute performance on all trials at each contrast level from trials object :param trials: trials object that must contain contrastLeft, contrastRight and feedbackType keys :type trials: dict returns: float containing performance on easy contrast trials """ifsigned_contrastisNone:signed_contrast=get_signed_contrast(trials)ifblockisNone:block_idx=np.full(trials.probabilityLeft.shape,True,dtype=bool)else:block_idx=trials.probabilityLeft==blockifnotnp.any(block_idx):returnnp.nan*np.zeros(3)contrasts,n_contrasts=np.unique(signed_contrast[block_idx],return_counts=True)ifnotprob_right:correct=trials.feedbackType==1performance=np.vectorize(lambdax:np.mean(correct[(x==signed_contrast)&block_idx]))(contrasts)else:rightward=trials.choice==-1# Calculate the proportion rightward for each contrast typeperformance=np.vectorize(lambdax:np.mean(rightward[(x==signed_contrast)&block_idx]))(contrasts)returnperformance,contrasts,n_contrasts
[docs]defcompute_n_trials(trials):""" Compute number of trials in trials object :param trials: trials object :type trials: dict returns: int containing number of trials in session """returntrials['choice'].shape[0]
[docs]defcompute_psychometric(trials,signed_contrast=None,block=None,plotting=False,compute_ci=False,alpha=.032):""" Compute psychometric fit parameters for trials object. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. signed_contrast : numpy.array An array of signed contrasts in percent the length of trials, where left contrasts are -ve. If None, these are computed from the trials object. block : float The block type to compute. If None, all trials are included, otherwise only trials where probabilityLeft matches this value are included. For biasedChoiceWorld, the probabilityLeft set is {0.5, 0.2, 0.8}. plotting : bool Which set of psychofit model parameters to use (see notes). compute_ci : bool If true, computes and returns the confidence intervals for response at each contrast. alpha : float, default=0.032 Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false, this value is ignored. Returns ------- numpy.array Array of psychometric fit parameters - bias, threshold, lapse high, lapse low. (tuple of numpy.array) If `compute_ci` is true, a tuple of See Also -------- statsmodels.stats.proportion.proportion_confint - The function used to compute confidence interval. psychofit.mle_fit_psycho - The function used to fit the psychometric parameters. Notes ----- The psychofit starting parameters and model constraints used for the fit when computing the training status (e.g. trained_1a, etc.) are sub-optimal and can produce a poor fit. To keep the precise criteria the same for all subjects, these parameters have not changed. To produce a better fit for plotting purposes, or to calculate the training status in a manner inconsistent with the IBL training pipeline, use plotting=True. """ifsigned_contrastisNone:signed_contrast=get_signed_contrast(trials)ifblockisNone:block_idx=np.full(trials.probabilityLeft.shape,True,dtype=bool)else:block_idx=trials.probabilityLeft==blockifnotnp.any(block_idx):returnnp.nan*np.zeros(4)prob_choose_right,contrasts,n_contrasts=compute_performance(trials,signed_contrast=signed_contrast,block=block,prob_right=True)ifplotting:# These starting parameters and constraints tend to produce a better fit, and are therefore# used for plotting.psych,_=psy.mle_fit_psycho(np.vstack([contrasts,n_contrasts,prob_choose_right]),P_model='erf_psycho_2gammas',parstart=np.array([0.,40.,0.1,0.1]),parmin=np.array([-50.,10.,0.,0.]),parmax=np.array([50.,50.,0.2,0.2]),nfits=10)else:# These starting parameters and constraints are not ideal but are still used for computing# the training status for consistency.psych,_=psy.mle_fit_psycho(np.vstack([contrasts,n_contrasts,prob_choose_right]),P_model='erf_psycho_2gammas',parstart=np.array([np.mean(contrasts),20.,0.05,0.05]),parmin=np.array([np.min(contrasts),0.,0.,0.]),parmax=np.array([np.max(contrasts),100.,1,1]))ifcompute_ci:importstatsmodels.stats.proportionassmp# noqa# choice == -1 means contrast on right hand siden_right=np.vectorize(lambdax:np.sum(trials['choice'][(x==signed_contrast)&block_idx]==-1))(contrasts)ci=smp.proportion_confint(n_right,n_contrasts,alpha=alpha,method='normal')-prob_choose_rightreturnpsych,cielse:returnpsych
[docs]defcompute_median_reaction_time(trials,stim_on_type='stimOn_times',contrast=None,signed_contrast=None):""" Compute median response time on zero contrast trials from trials object Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. stim_on_type : str, default='stimOn_times' The trials key to use when calculating the response times. The difference between this and 'feedback_times' is used (see notes). contrast : float If None, the median response time is calculated for all trials, regardless of contrast, otherwise only trials where the matching signed percent contrast was presented are used. signed_contrast : numpy.array An array of signed contrasts in percent the length of trials, where left contrasts are -ve. If None, these are computed from the trials object. Returns ------- float The median response time for trials with `contrast` (returns NaN if no trials matching `contrast` in trials object). Notes ----- - The `stim_on_type` is 'stimOn_times' by default, however for IBL rig data, the photodiode is sometimes not calibrated properly which can lead to inaccurate (or absent, i.e. NaN) stim on times. Therefore, it is sometimes more accurate to use the 'stimOnTrigger_times' (the time of the stimulus onset command), if available, or the 'goCue_times' (the time of the soundcard output TTL when the audio go cue is played) or the 'goCueTrigger_times' (the time of the audio go cue command). - The response/reaction time here is defined as the time between stim on and feedback, i.e. the entire open-loop trial duration. """ifsigned_contrastisNone:signed_contrast=get_signed_contrast(trials)ifcontrastisNone:contrast_idx=np.full(trials.probabilityLeft.shape,True,dtype=bool)else:contrast_idx=signed_contrast==contrastifnp.any(contrast_idx):reaction_time=np.nanmedian((trials.response_times-trials[stim_on_type])[contrast_idx])else:reaction_time=np.nanreturnreaction_time
[docs]defcompute_reaction_time(trials,stim_on_type='stimOn_times',stim_off_type='response_times',signed_contrast=None,block=None,compute_ci=False,alpha=0.32):""" Compute median response time for all contrasts. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. stim_on_type : str, default='stimOn_times' The trials key to use when calculating the response times. The difference between this and `stim_off_type` is used (see notes). stim_off_type : str, default='response_times' The trials key to use when calculating the response times. The difference between this and `stim_on_type` is used (see notes). signed_contrast : numpy.array An array of signed contrasts in percent the length of trials, where left contrasts are -ve. If None, these are computed from the trials object. block : float The block type to compute. If None, all trials are included, otherwise only trials where probabilityLeft matches this value are included. For biasedChoiceWorld, the probabilityLeft set is {0.5, 0.2, 0.8}. compute_ci : bool If true, computes and returns the confidence intervals for response time at each contrast. alpha : float, default=0.32 Significance level for confidence interval. Must be in (0, 1). If `compute_ci` is false, this value is ignored. Returns ------- numpy.array The median response times for each unique signed contrast. numpy.array The set of unique signed contrasts. numpy.array The number of trials for each unique signed contrast. (numpy.array) If `compute_ci` is true, an array of confidence intervals is return in the shape (n_trials, 2). Notes ----- - The response/reaction time by default is the time between stim on and response, i.e. the entire open-loop trial duration. One could use 'stimOn_times' and 'firstMovement_times' to get the true reaction time, or 'firstMovement_times' and 'response_times' to get the true response times, or calculate the last movement onset times and calculate the true movement times. See module examples for how to calculate this. See Also -------- scipy.stats.bootstrap - the function used to compute the confidence interval. """ifsigned_contrastisNone:signed_contrast=get_signed_contrast(trials)ifblockisNone:block_idx=np.full(trials.probabilityLeft.shape,True,dtype=bool)else:block_idx=trials.probabilityLeft==blockcontrasts,n_contrasts=np.unique(signed_contrast[block_idx],return_counts=True)reaction_time=np.vectorize(lambdax:np.nanmedian((trials[stim_off_type]-trials[stim_on_type])[(x==signed_contrast)&block_idx]),otypes=[float])(contrasts)ifcompute_ci:ci=np.full((contrasts.size,2),np.nan)fori,xinenumerate(contrasts):data=(trials[stim_off_type]-trials[stim_on_type])[(x==signed_contrast)&block_idx]bt=bootstrap((data,),np.nanmedian,confidence_level=1-alpha)ci[i,0]=bt.confidence_interval.lowci[i,1]=bt.confidence_interval.highreturnreaction_time,contrasts,n_contrasts,cielse:returnreaction_time,contrasts,n_contrasts,
[docs]defcriterion_1a(psych,n_trials,perf_easy,signed_contrast):""" Returns bool indicating whether criteria for status 'trained_1a' are met. Criteria -------- - Bias is less than 16 - Threshold is less than 19 - Lapse rate on both sides is less than 0.2 - The total number of trials is greater than 200 for each session - Performance on easy contrasts > 80% for all sessions - Zero contrast trials must be present Parameters ---------- psych : numpy.array The fit psychometric parameters three consecutive sessions. Parameters are bias, threshold, lapse high, lapse low. n_trials : numpy.array of int The number for trials for each session. perf_easy : numpy.array of float The proportion of correct high contrast trials for each session. signed_contrast: numpy.array Unique list of contrasts displayed Returns ------- bool True if the criteria are met for 'trained_1a'. Bunch Bunch containing breakdown of the passing/ failing critieria Notes ----- The parameter thresholds chosen here were originally determined by averaging the parameter fits for a number of sessions determined to be of 'good' performance by an experimenter. """criteria=Bunch()criteria['Zero_contrast']={'val':signed_contrast,'pass':np.any(signed_contrast==0)}criteria['LapseLow_50']={'val':psych[2],'pass':psych[2]<0.2}criteria['LapseHigh_50']={'val':psych[3],'pass':psych[3]<0.2}criteria['Bias']={'val':psych[0],'pass':abs(psych[0])<16}criteria['Threshold']={'val':psych[1],'pass':psych[1]<19}criteria['N_trials']={'val':n_trials,'pass':np.all(n_trials>200)}criteria['Perf_easy']={'val':perf_easy,'pass':np.all(perf_easy>0.8)}passing=np.all([v['pass']fork,vincriteria.items()])criteria['Criteria']={'val':'trained_1a','pass':passing}returnpassing,criteria
[docs]defcriterion_1b(psych,n_trials,perf_easy,rt,signed_contrast):""" Returns bool indicating whether criteria for trained_1b are met. Criteria -------- - Bias is less than 10 - Threshold is less than 20 (see notes) - Lapse rate on both sides is less than 0.1 - The total number of trials is greater than 400 for each session - Performance on easy contrasts > 90% for all sessions - The median response time across all zero contrast trials is less than 2 seconds - Zero contrast trials must be present Parameters ---------- psych : numpy.array The fit psychometric parameters three consecutive sessions. Parameters are bias, threshold, lapse high, lapse low. n_trials : numpy.array of int The number for trials for each session. perf_easy : numpy.array of float The proportion of correct high contrast trials for each session. rt : float The median response time for zero contrast trials. signed_contrast: numpy.array Unique list of contrasts displayed Returns ------- bool True if the criteria are met for 'trained_1b'. Bunch Bunch containing breakdown of the passing/ failing critieria Notes ----- The parameter thresholds chosen here were originally chosen to be slightly stricter than 1a, however it was decided to use round numbers so that readers would not assume a level of precision that isn't there (remember, these parameters were not chosen with any rigor). This regrettably means that the maximum threshold fit for 1b is greater than for 1a, meaning the slope of the psychometric curve may be slightly less steep than 1a. """criteria=Bunch()criteria['Zero_contrast']={'val':signed_contrast,'pass':np.any(signed_contrast==0)}criteria['LapseLow_50']={'val':psych[2],'pass':psych[2]<0.1}criteria['LapseHigh_50']={'val':psych[3],'pass':psych[3]<0.1}criteria['Bias']={'val':psych[0],'pass':abs(psych[0])<10}criteria['Threshold']={'val':psych[1],'pass':psych[1]<20}criteria['N_trials']={'val':n_trials,'pass':np.all(n_trials>400)}criteria['Perf_tasy']={'val':perf_easy,'pass':np.all(perf_easy>0.9)}criteria['Reaction_time']={'val':rt,'pass':rt<2}passing=np.all([v['pass']fork,vincriteria.items()])criteria['Criteria']={'val':'trained_1b','pass':passing}returnpassing,criteria
[docs]defcriterion_ephys(psych_20,psych_80,n_trials,perf_easy,rt):""" Returns bool indicating whether criteria for ready4ephysrig are met. Criteria -------- - Lapse on both sides < 0.1 for both bias blocks - Bias shift between blocks > 5 - Total number of trials > 400 for all sessions - Performance on easy contrasts > 90% for all sessions - Median response time for zero contrast stimuli < 2 seconds Parameters ---------- psych_20 : numpy.array The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. Parameters are bias, threshold, lapse high, lapse low. psych_80 : numpy.array The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. Parameters are bias, threshold, lapse high, lapse low. n_trials : numpy.array The number of trials for each session (typically three consecutive sessions). perf_easy : numpy.array The proportion of correct high contrast trials for each session (typically three consecutive sessions). rt : float The median response time for zero contrast trials. Returns ------- bool True if subject passes the ready4ephysrig criteria. Bunch Bunch containing breakdown of the passing/ failing critieria """criteria=Bunch()criteria['LapseLow_80']={'val':psych_80[2],'pass':psych_80[2]<0.1}criteria['LapseHigh_80']={'val':psych_80[3],'pass':psych_80[3]<0.1}criteria['LapseLow_20']={'val':psych_20[2],'pass':psych_20[2]<0.1}criteria['LapseHigh_20']={'val':psych_20[3],'pass':psych_20[3]<0.1}criteria['Bias_shift']={'val':psych_80[0]-psych_20[0],'pass':psych_80[0]-psych_20[0]>5}criteria['N_trials']={'val':n_trials,'pass':np.all(n_trials>400)}criteria['Perf_easy']={'val':perf_easy,'pass':np.all(perf_easy>0.9)}criteria['Reaction_time']={'val':rt,'pass':rt<2}passing=np.all([v['pass']fork,vincriteria.items()])criteria['Criteria']={'val':'ready4ephysrig','pass':passing}returnpassing,criteria
[docs]defcriterion_delay(n_trials,perf_easy,n_ephys=1):""" Returns bool indicating whether criteria for 'ready4delay' is met. Criteria -------- - At least one session on an ephys rig - Total number of trials for any of the sessions is greater than 400 - Performance on easy contrasts is greater than 90% for any of the sessions Parameters ---------- n_trials : numpy.array of int The number of trials for each session (typically three consecutive sessions). perf_easy : numpy.array The proportion of correct high contrast trials for each session (typically three consecutive sessions). Returns ------- bool True if subject passes the 'ready4delay' criteria. Bunch Bunch containing breakdown of the passing/ failing critieria """criteria=Bunch()criteria['N_ephys']={'val':n_ephys,'pass':n_ephys>0}criteria['N_trials']={'val':n_trials,'pass':np.any(n_trials>400)}criteria['Perf_easy']={'val':perf_easy,'pass':np.any(perf_easy>0.9)}passing=np.all([v['pass']fork,vincriteria.items()])criteria['Criteria']={'val':'ready4delay','pass':passing}returnpassing,criteria
[docs]defcriteria_recording(n_ephys,delay,psych_20,psych_80,n_trials,perf_easy,rt):""" Returns bool indicating whether criteria for ready4recording are met. Criteria -------- - At least 3 ephys sessions - Delay on any session > 0 - Lapse on both sides < 0.1 for both bias blocks - Bias shift between blocks > 5 - Total number of trials > 400 for all sessions - Performance on easy contrasts > 90% for all sessions - Median response time for zero contrast stimuli < 2 seconds Parameters ---------- psych_20 : numpy.array The fit psychometric parameters for the blocks where probability of a left stimulus is 0.2. Parameters are bias, threshold, lapse high, lapse low. psych_80 : numpy.array The fit psychometric parameters for the blocks where probability of a left stimulus is 0.8. Parameters are bias, threshold, lapse high, lapse low. n_trials : numpy.array The number of trials for each session (typically three consecutive sessions). perf_easy : numpy.array The proportion of correct high contrast trials for each session (typically three consecutive sessions). rt : float The median response time for zero contrast trials. Returns ------- bool True if subject passes the ready4recording criteria. Bunch Bunch containing breakdown of the passing/ failing critieria """_,criteria=criterion_ephys(psych_20,psych_80,n_trials,perf_easy,rt)criteria['N_ephys']={'val':n_ephys,'pass':n_ephys>=3}criteria['N_delay']={'val':delay,'pass':delay>0}passing=np.all([v['pass']fork,vincriteria.items()])criteria['Criteria']={'val':'ready4recording','pass':passing}returnpassing,criteria
[docs]defplot_psychometric(trials,ax=None,title=None,plot_ci=False,ci_alpha=0.032,**kwargs):""" Function to plot psychometric curve plots a la datajoint webpage. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. ax : matplotlib.pyplot.Axes An axis object to plot onto. title : str An optional plot title. plot_ci : bool If true, computes and plots the confidence intervals for response at each contrast. ci_alpha : float, default=0.032 Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false, this value is ignored. **kwargs If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. Returns ------- matplotlib.pyplot.Figure The figure handle containing the plot. matplotlib.pyplot.Axes The plotted axes. See Also -------- statsmodels.stats.proportion.proportion_confint - The function used to compute confidence interval. psychofit.mle_fit_psycho - The function used to fit the psychometric parameters. psychofit.erf_psycho_2gammas - The function used to transform contrast to response probability using the fit parameters. """signed_contrast=get_signed_contrast(trials)contrasts_fit=np.arange(-100,100)prob_right_50,contrasts_50,_=compute_performance(trials,signed_contrast=signed_contrast,block=0.5,prob_right=True)out_50=compute_psychometric(trials,signed_contrast=signed_contrast,block=0.5,plotting=True,compute_ci=plot_ci,alpha=ci_alpha)pars_50=out_50[0]ifplot_cielseout_50prob_right_fit_50=psy.erf_psycho_2gammas(pars_50,contrasts_fit)prob_right_20,contrasts_20,_=compute_performance(trials,signed_contrast=signed_contrast,block=0.2,prob_right=True)out_20=compute_psychometric(trials,signed_contrast=signed_contrast,block=0.2,plotting=True,compute_ci=plot_ci,alpha=ci_alpha)pars_20=out_20[0]ifplot_cielseout_20prob_right_fit_20=psy.erf_psycho_2gammas(pars_20,contrasts_fit)prob_right_80,contrasts_80,_=compute_performance(trials,signed_contrast=signed_contrast,block=0.8,prob_right=True)out_80=compute_psychometric(trials,signed_contrast=signed_contrast,block=0.8,plotting=True,compute_ci=plot_ci,alpha=ci_alpha)pars_80=out_80[0]ifplot_cielseout_80prob_right_fit_80=psy.erf_psycho_2gammas(pars_80,contrasts_fit)cmap=sns.diverging_palette(20,220,n=3,center='dark')ifnotax:fig,ax=plt.subplots(**kwargs)else:fig=plt.gcf()fit_50=ax.plot(contrasts_fit,prob_right_fit_50,color=cmap[1])data_50=ax.scatter(contrasts_50,prob_right_50,color=cmap[1])fit_20=ax.plot(contrasts_fit,prob_right_fit_20,color=cmap[0])data_20=ax.scatter(contrasts_20,prob_right_20,color=cmap[0])fit_80=ax.plot(contrasts_fit,prob_right_fit_80,color=cmap[2])data_80=ax.scatter(contrasts_80,prob_right_80,color=cmap[2])ifplot_ci:errbar_50=np.c_[np.abs(out_50[1][0]),np.abs(out_50[1][1])].Terrbar_20=np.c_[np.abs(out_20[1][0]),np.abs(out_20[1][1])].Terrbar_80=np.c_[np.abs(out_80[1][0]),np.abs(out_80[1][1])].Tax.errorbar(contrasts_50,prob_right_50,yerr=errbar_50,ecolor=cmap[1],fmt='none',capsize=5,alpha=0.4)ax.errorbar(contrasts_20,prob_right_20,yerr=errbar_20,ecolor=cmap[0],fmt='none',capsize=5,alpha=0.4)ax.errorbar(contrasts_80,prob_right_80,yerr=errbar_80,ecolor=cmap[2],fmt='none',capsize=5,alpha=0.4)ax.legend([fit_50[0],data_50,fit_20[0],data_20,fit_80[0],data_80],['p_left=0.5 fit','p_left=0.5 data','p_left=0.2 fit','p_left=0.2 data','p_left=0.8 fit','p_left=0.8 data'],loc='upper left')ax.set_ylim(-0.05,1.05)ax.set_ylabel('Probability choosing right')ax.set_xlabel('Contrasts')iftitle:ax.set_title(title)returnfig,ax
[docs]defplot_reaction_time(trials,ax=None,title=None,plot_ci=False,ci_alpha=0.32,**kwargs):""" Function to plot reaction time against contrast a la datajoint webpage. The reaction times are plotted individually for the following three blocks: {0.5, 0.2, 0.8}. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. ax : matplotlib.pyplot.Axes An axis object to plot onto. title : str An optional plot title. plot_ci : bool If true, computes and plots the confidence intervals for response at each contrast. ci_alpha : float, default=0.32 Significance level for confidence interval. Must be in (0, 1). If `plot_ci` is false, this value is ignored. **kwargs If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. Returns ------- matplotlib.pyplot.Figure The figure handle containing the plot. matplotlib.pyplot.Axes The plotted axes. See Also -------- scipy.stats.bootstrap - the function used to compute the confidence interval. """signed_contrast=get_signed_contrast(trials)out_50=compute_reaction_time(trials,signed_contrast=signed_contrast,block=0.5,compute_ci=plot_ci,alpha=ci_alpha)out_20=compute_reaction_time(trials,signed_contrast=signed_contrast,block=0.2,compute_ci=plot_ci,alpha=ci_alpha)out_80=compute_reaction_time(trials,signed_contrast=signed_contrast,block=0.8,compute_ci=plot_ci,alpha=ci_alpha)cmap=sns.diverging_palette(20,220,n=3,center='dark')ifnotax:fig,ax=plt.subplots(**kwargs)else:fig=plt.gcf()data_50=ax.plot(out_50[1],out_50[0],'-o',color=cmap[1])data_20=ax.plot(out_20[1],out_20[0],'-o',color=cmap[0])data_80=ax.plot(out_80[1],out_80[0],'-o',color=cmap[2])ifplot_ci:errbar_50=np.c_[out_50[0]-out_50[3][:,0],out_50[3][:,1]-out_50[0]].Terrbar_20=np.c_[out_20[0]-out_20[3][:,0],out_20[3][:,1]-out_20[0]].Terrbar_80=np.c_[out_80[0]-out_80[3][:,0],out_80[3][:,1]-out_80[0]].Tax.errorbar(out_50[1],out_50[0],yerr=errbar_50,ecolor=cmap[1],fmt='none',capsize=5,alpha=0.4)ax.errorbar(out_20[1],out_20[0],yerr=errbar_20,ecolor=cmap[0],fmt='none',capsize=5,alpha=0.4)ax.errorbar(out_80[1],out_80[0],yerr=errbar_80,ecolor=cmap[2],fmt='none',capsize=5,alpha=0.4)ax.legend([data_50[0],data_20[0],data_80[0]],['p_left=0.5 data','p_left=0.2 data','p_left=0.8 data'],loc='upper left')ax.set_ylabel('Reaction time (s)')ax.set_xlabel('Contrasts')iftitle:ax.set_title(title)returnfig,ax
[docs]defplot_reaction_time_over_trials(trials,stim_on_type='stimOn_times',ax=None,title=None,**kwargs):""" Function to plot reaction time with trial number a la datajoint webpage. Parameters ---------- trials : one.alf.io.AlfBunch An ALF trials object containing the keys {'probabilityLeft', 'contrastLeft', 'contrastRight', 'feedbackType', 'choice', 'response_times', 'stimOn_times'}. stim_on_type : str, default='stimOn_times' The trials key to use when calculating the response times. The difference between this and 'feedback_times' is used (see notes for `compute_median_reaction_time`). ax : matplotlib.pyplot.Axes An axis object to plot onto. title : str An optional plot title. **kwargs If `ax` is None, these arguments are passed to matplotlib.pyplot.subplots. Returns ------- matplotlib.pyplot.Figure The figure handle containing the plot. matplotlib.pyplot.Axes The plotted axes. """reaction_time=pd.DataFrame()reaction_time['reaction_time']=trials.response_times-trials[stim_on_type]reaction_time.index=reaction_time.index+1reaction_time_rolled=reaction_time['reaction_time'].rolling(window=10).median()reaction_time_rolled=reaction_time_rolled.where((pd.notnull(reaction_time_rolled)),None)reaction_time=reaction_time.where((pd.notnull(reaction_time)),None)ifnotax:fig,ax=plt.subplots(**kwargs)else:fig=plt.gcf()ax.scatter(np.arange(len(reaction_time.values)),reaction_time.values,s=16,color='darkgray')ax.plot(np.arange(len(reaction_time_rolled.values)),reaction_time_rolled.values,color='k',linewidth=2)ax.set_yscale('log')ax.set_ylim(0.1,100)ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())ax.set_ylabel('Reaction time (s)')ax.set_xlabel('Trial number')iftitle:ax.set_title(title)returnfig,ax
[docs]defquery_criterion(subject,status,from_status=None,one=None,validate=True):"""Get the session for which a given training criterion was met. Parameters ---------- subject : str The subject name. status : str The training status to query for. from_status : str, optional Count number of sessions and days from reaching `from_status` to `status`. one : one.api.OneAlyx, optional An instance of ONE. validate : bool If true, check if status in TrainingStatus enumeration. Set to false for non-standard training pipelines. Returns ------- str The eID of the first session where this training status was reached. int The number of sessions it took to reach `status` (optionally from reaching `from_status`). int The number of days it tool to reach `status` (optionally from reaching `from_status`). """ifvalidate:status=status.lower().replace(' ','_')try:status=TrainingStatus[status.upper().replace(' ','_')].name.lower()exceptKeyErrorasex:raiseValueError(f'Unknown status "{status}". For non-standard training protocols set validate=False')fromexone=oneorONE()subject_json=one.alyx.rest('subjects','read',id=subject)['json']ifnot(criteria:=subject_json.get('trained_criteria'))orstatusnotincriteria:returnNone,None,Noneto_date,eid=criteria[status]from_date,_=criteria.get(from_status,(None,None))eids,det=one.search(subject=subject,date_range=[from_date,to_date],details=True)iflen(eids)==0:returneid,None,Nonedelta_date=det[0]['date']-det[-1]['date']returneid,len(eids),delta_date.days