import logging
from pathlib import Path
from datetime import datetime
from itertools import chain
import numpy as np
import pandas as pd
from iblutil.numerical import ismember
import one.alf.io as alfio
from one.alf.exceptions import ALFObjectNotFound
import one.alf.path as alfiles
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.lines import Line2D
import seaborn as sns
import boto3
from botocore.exceptions import ProfileNotFound, ClientError
from ibllib.io.raw_data_loaders import load_bpod
from ibllib.oneibl.registration import _get_session_times
from ibllib.io.extractors.base import get_bpod_extractor_class
from ibllib.io.session_params import read_params
from ibllib.io.extractors.bpod_trials import get_bpod_extractor
from ibllib.plots.snapshot import ReportSnapshot
from brainbox.behavior import training
logger = logging.getLogger(__name__)
TRAINING_STATUS = {'untrainable': (-4, (0, 0, 0, 0)),
'unbiasable': (-3, (0, 0, 0, 0)),
'not_computed': (-2, (0, 0, 0, 0)),
'habituation': (-1, (0, 0, 0, 0)),
'in training': (0, (0, 0, 0, 0)),
'trained 1a': (1, (195, 90, 80, 255)),
'trained 1b': (2, (255, 153, 20, 255)),
'ready4ephysrig': (3, (28, 20, 255, 255)),
'ready4delay': (4, (117, 117, 117, 255)),
'ready4recording': (5, (20, 255, 91, 255))}
def get_training_table_from_aws(lab, subject):
"""
    If AWS credentials exist on the local server, download the latest training table for a subject
    from the AWS S3 private bucket.

    :param lab: lab name
    :param subject: subject nickname
    :return: the training table as a pandas.DataFrame, or None if the credentials or file are not found
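
    Examples
    --------
    A minimal sketch; the lab and subject names are hypothetical:

    >>> df = get_training_table_from_aws('cortexlab', 'KS022')  # doctest: +SKIP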
"""
try:
session = boto3.Session(profile_name='ibl_training')
except ProfileNotFound:
return
local_file_path = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
dst_bucket_name = 'ibl-brain-wide-map-private'
try:
s3 = session.resource('s3')
bucket = s3.Bucket(name=dst_bucket_name)
bucket.download_file(f'resources/training/{lab}/{subject}/training.csv',
local_file_path)
df = pd.read_csv(local_file_path)
except ClientError:
return
return df
def upload_training_table_to_aws(lab, subject):
"""
    If AWS credentials exist on the local server, upload the subject's training table to the AWS S3
    private bucket.

    :param lab: lab name
    :param subject: subject nickname
    :return:
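
    Examples
    --------
    A minimal sketch; the lab and subject names are hypothetical:

    >>> upload_training_table_to_aws('cortexlab', 'KS022')  # doctest: +SKIP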
"""
try:
session = boto3.Session(profile_name='ibl_training')
except ProfileNotFound:
return
local_file_path = f'/mnt/s0/Data/Subjects/{subject}/training.csv'
dst_bucket_name = 'ibl-brain-wide-map-private'
try:
s3 = session.resource('s3')
bucket = s3.Bucket(name=dst_bucket_name)
bucket.upload_file(local_file_path,
f'resources/training/{lab}/{subject}/training.csv')
except (ClientError, FileNotFoundError):
return
def save_path(subj_path):
    """Return the path to the training table file within a subject folder.

    :param subj_path: path to subject folder
    :return: path to training.csv
    """
    return Path(subj_path).joinpath('training.csv')
def save_dataframe(df, subj_path):
"""Save training dataframe to disk.
:param df: dataframe to save
:param subj_path: path to subject folder
:return:
"""
df.to_csv(save_path(subj_path), index=False)
def load_existing_dataframe(subj_path):
"""Load training dataframe from disk, if dataframe doesn't exist returns None.
:param subj_path: path to subject folder
:return:
"""
df_location = save_path(subj_path)
if df_location.exists():
return pd.read_csv(df_location)
else:
df_location.parent.mkdir(exist_ok=True, parents=True)
return None
def load_trials(sess_path, one, collections=None, force=True, mode='raise'):
"""
    Load trials data for a session. First attempts to load from the local session path; if this
    fails, attempts to download via ONE; if this also fails, attempts to re-extract locally.

    :param sess_path: session path
    :param one: ONE instance
    :param collections: list of collections within the session path to load the trials from (optional)
    :param force: when True, if the session trials can't be found, attempt to re-extract from disk
    :param mode: 'raise' or 'warn'; if 'raise', an exception is raised when all loading attempts fail
    :return: the trials object, or None
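
    Examples
    --------
    A minimal sketch, assuming `one` is an instance of one.api.ONE; the session path is hypothetical:

    >>> trials = load_trials(Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/001'), one, mode='warn')  # doctest: +SKIP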
"""
try:
        # First try to load any trials found locally in the session path
if collections is None:
trial_locations = list(sess_path.rglob('_ibl_trials.goCueTrigger_times.*npy'))
else:
trial_locations = [Path(sess_path).joinpath(c, '_ibl_trials.goCueTrigger_times.*npy') for c in collections]
if len(trial_locations) > 1:
trial_dict = {}
for i, loc in enumerate(trial_locations):
trial_dict[i] = alfio.load_object(loc.parent, 'trials', short_keys=True)
trials = training.concatenate_trials(trial_dict)
elif len(trial_locations) == 1:
trials = alfio.load_object(trial_locations[0].parent, 'trials', short_keys=True)
else:
raise ALFObjectNotFound
if 'probabilityLeft' not in trials.keys():
raise ALFObjectNotFound
except ALFObjectNotFound:
        # Next, try to load the trials data through ONE
try:
if not force:
return None
eid = one.path2eid(sess_path)
if collections is None:
trial_collections = one.list_datasets(eid, '_ibl_trials.goCueTrigger_times.npy')
if len(trial_collections) > 0:
trial_collections = ['/'.join(c.split('/')[:-1]) for c in trial_collections]
else:
trial_collections = collections
if len(trial_collections) > 1:
trial_dict = {}
for i, collection in enumerate(trial_collections):
trial_dict[i] = one.load_object(eid, 'trials', collection=collection)
trials = training.concatenate_trials(trial_dict)
elif len(trial_collections) == 1:
trials = one.load_object(eid, 'trials', collection=trial_collections[0])
else:
raise ALFObjectNotFound
if 'probabilityLeft' not in trials.keys():
raise ALFObjectNotFound
except Exception:
# Finally try to re-extract the trials data locally
try:
raw_collections, _ = get_data_collection(sess_path)
if len(raw_collections) == 0:
return None
trials_dict = {}
for i, collection in enumerate(raw_collections):
extractor = get_bpod_extractor(sess_path, task_collection=collection)
trials_data, _ = extractor.extract(task_collection=collection, save=False)
trials_dict[i] = alfio.AlfBunch.from_df(trials_data['table'])
if len(trials_dict) > 1:
trials = training.concatenate_trials(trials_dict)
else:
trials = trials_dict[0]
except Exception as e:
if mode == 'raise':
raise Exception(f'Exhausted all possibilities for loading trials for {sess_path}') from e
else:
logger.warning(f'Exhausted all possibilities for loading trials for {sess_path}')
return
return trials
def load_combined_trials(sess_paths, one, force=True):
"""
    Load and concatenate trials from multiple sessions. Used when we want to concatenate the trials
    for two sessions on the same day.

    :param sess_paths: list of paths to sessions
    :param one: ONE instance
    :param force: when True, if the session trials can't be found, attempt to re-extract from disk
    :return: the concatenated trials object
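
    Examples
    --------
    A minimal sketch, assuming `one` is an instance of one.api.ONE; the session paths are hypothetical:

    >>> sess_paths = [Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/001'),
    ...               Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/002')]
    >>> trials = load_combined_trials(sess_paths, one)  # doctest: +SKIP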
"""
trials_dict = {}
for sess_path in sess_paths:
trials = load_trials(Path(sess_path), one, force=force)
        if trials is not None:
            trials_dict[Path(sess_path).stem] = trials
return training.concatenate_trials(trials_dict)
def find_earliest_recompute_date(df):
"""
    Find the earliest date from which the training status must be recomputed. The training status
    depends on previous sessions, so if a session was missing and has now been added, everything
    from that date onwards needs to be recomputed.

    :param df: the training dataframe
    :return: array of dates from the earliest 'not_computed' session onwards
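
    Examples
    --------
    >>> df = pd.DataFrame({'date': ['2023-01-01', '2023-01-02', '2023-01-03'],
    ...                    'training_status': ['trained 1a', 'not_computed', 'not_computed']})
    >>> find_earliest_recompute_date(df)
    array(['2023-01-02', '2023-01-03'], dtype=object)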
"""
missing_df = df[df['training_status'] == 'not_computed']
if len(missing_df) == 0:
return []
missing_df = missing_df.sort_values('date')
first_index = missing_df.index[0]
return df[first_index:].date.values
def compute_training_status(df, compute_date, one, force=True):
"""
    Compute the training status for a given date, based on the sessions from that date and the two
    previous training days.

Parameters
----------
df : pandas.DataFrame
A training data frame, e.g. one generated from :func:`get_training_info_for_session`.
compute_date : str, datetime.datetime, pandas.Timestamp
The date to compute training on.
one : one.api.One
An instance of ONE for loading trials data.
force : bool
When true and if the session trials can't be found, will attempt to re-extract from disk.
Returns
-------
pandas.DataFrame
The input data frame with a 'training_status' column populated for `compute_date`.
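
    Examples
    --------
    A minimal sketch, assuming `one` is an instance of one.api.ONE and `df` was built with
    :func:`get_training_info_for_session`:

    >>> df = compute_training_status(df, '2023-01-03', one)  # doctest: +SKIP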
"""
# compute_date = str(alfiles.session_path_parts(session_path, as_dict=True)['date'])
df_temp = df[df['date'] <= compute_date]
df_temp = df_temp.drop_duplicates(subset=['session_path', 'task_protocol'])
    df_temp = df_temp.sort_values('date')
dates = df_temp.date.values
n_sess_for_date = len(np.where(dates == compute_date)[0])
n_dates = np.min([2 + n_sess_for_date, len(dates)]).astype(int)
compute_dates = dates[(-1 * n_dates):]
if n_sess_for_date > 1:
compute_dates = compute_dates[:(-1 * (n_sess_for_date - 1))]
assert compute_dates[-1] == compute_date
df_temp_group = df_temp.groupby('date')
trials = {}
n_delay = 0
ephys_sessions = []
protocol = []
status = []
for date in compute_dates:
df_date = df_temp_group.get_group(date)
# If habituation skip
if df_date.iloc[-1]['task_protocol'] == 'habituation':
continue
# Here we should split by protocol in an ideal world but that world isn't today. This is only really relevant for
# chained protocols
trials[df_date.iloc[-1]['date']] = load_combined_trials(df_date.session_path.values, one, force=force)
protocol.append(df_date.iloc[-1]['task_protocol'])
status.append(df_date.iloc[-1]['training_status'])
if df_date.iloc[-1]['combined_n_delay'] >= 900: # delay of 15 mins
n_delay += 1
if df_date.iloc[-1]['location'] == 'ephys_rig':
ephys_sessions.append(df_date.iloc[-1]['date'])
n_status = np.max([-2, -1 * len(status)])
training_status, _ = training.get_training_status(trials, protocol, ephys_sessions, n_delay)
training_status = pass_through_training_hierachy(training_status, status[n_status])
df.loc[df['date'] == compute_date, 'training_status'] = training_status
return df
def pass_through_training_hierachy(status_new, status_old):
"""
    Ensure that the new training status is not lower than the one from the previous day, i.e. a
    subject cannot regress in performance.

    :param status_new: latest training status
    :param status_old: previous training status
    :return: whichever status is higher in the hierarchy
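
    Examples
    --------
    >>> pass_through_training_hierachy('trained 1a', 'trained 1b')
    'trained 1b'
    >>> pass_through_training_hierachy('trained 1b', 'trained 1a')
    'trained 1b'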
"""
if TRAINING_STATUS[status_old][0] > TRAINING_STATUS[status_new][0]:
return status_old
else:
return status_new
def compute_session_duration_delay_location(sess_path, collections=None, **kwargs):
"""
    Get meta information about a task. Extracts the session duration, the delay before the session
    start, and the location of the session.
Parameters
----------
sess_path : pathlib.Path, str
The session path with the pattern subject/yyyy-mm-dd/nnn.
collections : list
The location within the session path directory of task settings and data.
Returns
-------
int
The session duration in minutes, rounded to the nearest minute.
int
The delay between session start time and the first trial in seconds.
str {'ephys_rig', 'training_rig'}
The location of the session.
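
    Examples
    --------
    A minimal sketch; the session path is hypothetical:

    >>> duration, delay, location = compute_session_duration_delay_location(
    ...     Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/001'))  # doctest: +SKIP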
"""
if collections is None:
collections, _ = get_data_collection(sess_path)
session_duration = 0
session_delay = 0
session_location = 'training_rig'
for collection in collections:
md, sess_data = load_bpod(sess_path, task_collection=collection)
if md is None:
continue
try:
start_time, end_time = _get_session_times(sess_path, md, sess_data)
session_duration = session_duration + int((end_time - start_time).total_seconds() / 60)
session_delay = session_delay + md.get('SESSION_START_DELAY_SEC', 0)
        except Exception:
            # leave the accumulated duration and delay unchanged
            pass
        if 'ephys' in md.get('PYBPOD_BOARD', ''):
session_location = 'ephys_rig'
else:
session_location = 'training_rig'
return session_duration, session_delay, session_location
def get_data_collection(session_path):
"""Return the location of the raw behavioral data and extracted trials data for a given session.
For multiple locations in one session (e.g. chained protocols), returns all collections.
Passive protocols are excluded.
Parameters
----------
session_path : pathlib.Path
A session path in the form subject/date/number.
Returns
-------
list of str
A list of sub-directory names that contain raw behaviour data.
list of str
A list of sub-directory names that contain ALF trials data.
Examples
--------
An iblrig v7 session
>>> get_data_collection(Path(r'C:/data/subject/2023-01-01/001'))
    (['raw_behavior_data'], ['alf'])
An iblrig v8 session where two protocols were run
>>> get_data_collection(Path(r'C:/data/subject/2023-01-01/001'))
    (['raw_task_data_00', 'raw_task_data_01'], ['alf/task_00', 'alf/task_01'])
"""
experiment_description = read_params(session_path)
collections = []
if experiment_description is not None:
task_protocols = experiment_description.get('tasks', [])
for i, (protocol, task_info) in enumerate(chain(*map(dict.items, task_protocols))):
if 'passiveChoiceWorld' in protocol:
continue
collection = task_info.get('collection', f'raw_task_data_{i:02}')
if collection == 'raw_passive_data':
continue
collections.append(collection)
else:
settings = Path(session_path).rglob('_iblrig_taskSettings.raw*.json')
for setting in settings:
if setting.parent.name != 'raw_passive_data':
collections.append(setting.parent.name)
if len(collections) == 1 and collections[0] == 'raw_behavior_data':
alf_collections = ['alf']
elif all(['raw_task_data' in c for c in collections]):
alf_collections = [f'alf/task_{c[-2:]}' for c in collections]
else:
alf_collections = None
return collections, alf_collections
def get_sess_dict(session_path, one, protocol, alf_collections=None, raw_collections=None, force=True):
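    """
    Build a dictionary of performance metrics and metadata for a single session.

    For habituation protocols most fields are set to NaN; otherwise the trials are loaded and the
    performance, psychometric fits, reaction time and session metadata are computed. Returns None
    if the trials can't be loaded.
    """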
sess_dict = {}
sess_dict['date'] = str(alfiles.session_path_parts(session_path, as_dict=True)['date'])
sess_dict['session_path'] = str(session_path)
sess_dict['task_protocol'] = protocol
if sess_dict['task_protocol'] == 'habituation':
nan_array = np.array([np.nan])
sess_dict['performance'], sess_dict['contrasts'], _ = (nan_array, nan_array, np.nan)
sess_dict['performance_easy'] = np.nan
sess_dict['reaction_time'] = np.nan
sess_dict['n_trials'] = np.nan
sess_dict['sess_duration'] = np.nan
sess_dict['n_delay'] = np.nan
sess_dict['location'] = np.nan
sess_dict['training_status'] = 'habituation'
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapselow_50'], sess_dict['lapsehigh_50'] = \
(np.nan, np.nan, np.nan, np.nan)
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapselow_20'], sess_dict['lapsehigh_20'] = \
(np.nan, np.nan, np.nan, np.nan)
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapselow_80'], sess_dict['lapsehigh_80'] = \
(np.nan, np.nan, np.nan, np.nan)
else:
        # if the trials can't be loaded, skip this session
trials = load_trials(session_path, one, collections=alf_collections, force=force, mode='warn')
if trials is None:
return
sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
if sess_dict['task_protocol'] == 'training':
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapselow_50'], sess_dict['lapsehigh_50'] = \
training.compute_psychometric(trials)
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapselow_20'], sess_dict['lapsehigh_20'] = \
(np.nan, np.nan, np.nan, np.nan)
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapselow_80'], sess_dict['lapsehigh_80'] = \
(np.nan, np.nan, np.nan, np.nan)
else:
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapselow_50'], sess_dict['lapsehigh_50'] = \
training.compute_psychometric(trials, block=0.5)
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapselow_20'], sess_dict['lapsehigh_20'] = \
training.compute_psychometric(trials, block=0.2)
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapselow_80'], sess_dict['lapsehigh_80'] = \
training.compute_psychometric(trials, block=0.8)
sess_dict['performance_easy'] = training.compute_performance_easy(trials)
sess_dict['reaction_time'] = training.compute_median_reaction_time(trials)
sess_dict['n_trials'] = training.compute_n_trials(trials)
sess_dict['sess_duration'], sess_dict['n_delay'], sess_dict['location'] = \
compute_session_duration_delay_location(session_path, collections=raw_collections)
sess_dict['training_status'] = 'not_computed'
return sess_dict
def get_training_info_for_session(session_paths, one, force=True):
"""
Extract the training information needed for plots for each session.
Parameters
----------
session_paths : list of pathlib.Path
        List of session paths on the same date.
one : one.api.One
An ONE instance.
force : bool
When true and if the session trials can't be found, will attempt to re-extract from disk.
Returns
-------
list of dict
A list of dictionaries the length of `session_paths` containing individual and aggregate
performance information.
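
    Examples
    --------
    A minimal sketch, assuming `one` is an instance of one.api.ONE; the session path is hypothetical:

    >>> sess_dicts = get_training_info_for_session(
    ...     [Path('/mnt/s0/Data/Subjects/KS022/2023-01-01/001')], one)  # doctest: +SKIP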
"""
# return list of dicts to add
sess_dicts = []
for session_path in session_paths:
collections, alf_collections = get_data_collection(session_path)
session_path = Path(session_path)
protocols = []
for c in collections:
prot = get_bpod_extractor_class(session_path, task_collection=c)
prot = prot[:-6].lower()
protocols.append(prot)
un_protocols = np.unique(protocols)
        # Example: training, training, biased - the training sessions would be combined, the biased one not
sess_dict = None
if len(un_protocols) != 1:
print(f'Different protocols in same session {session_path} : {protocols}')
for prot in un_protocols:
if prot is False:
continue
try:
                prot_idx = np.where(np.array(protocols) == prot)[0]
                alf = [alf_collections[i] for i in prot_idx]
                raw = [collections[i] for i in prot_idx]
except TypeError:
alf = None
raw = None
sess_dict = get_sess_dict(session_path, one, prot, alf_collections=alf, raw_collections=raw, force=force)
else:
prot = un_protocols[0]
sess_dict = get_sess_dict(
session_path, one, prot, alf_collections=alf_collections, raw_collections=collections, force=force)
if sess_dict is not None:
sess_dicts.append(sess_dict)
protocols = [s['task_protocol'] for s in sess_dicts]
if len(protocols) > 0 and len(set(protocols)) != 1:
print(f'Different protocols on same date {sess_dicts[0]["date"]} : {protocols}')
# Only if all protocols are the same and are not habituation
    if len(sess_dicts) > 1 and len(set(protocols)) == 1 and protocols[0] != 'habituation':
print(f'{len(sess_dicts)} sessions being combined for date {sess_dicts[0]["date"]}')
combined_trials = load_combined_trials(session_paths, one, force=force)
performance, contrasts, _ = training.compute_performance(combined_trials, prob_right=True)
psychs = {}
psychs['50'] = training.compute_psychometric(combined_trials, block=0.5)
psychs['20'] = training.compute_psychometric(combined_trials, block=0.2)
psychs['80'] = training.compute_psychometric(combined_trials, block=0.8)
performance_easy = training.compute_performance_easy(combined_trials)
reaction_time = training.compute_median_reaction_time(combined_trials)
n_trials = training.compute_n_trials(combined_trials)
sess_duration = np.nansum([s['sess_duration'] for s in sess_dicts])
n_delay = np.nanmax([s['n_delay'] for s in sess_dicts])
for sess_dict in sess_dicts:
sess_dict['combined_performance'] = performance
sess_dict['combined_contrasts'] = contrasts
sess_dict['combined_performance_easy'] = performance_easy
sess_dict['combined_reaction_time'] = reaction_time
sess_dict['combined_n_trials'] = n_trials
sess_dict['combined_sess_duration'] = sess_duration
sess_dict['combined_n_delay'] = n_delay
for bias in [50, 20, 80]:
sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][2]
sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][3]
# Case where two sessions on same day with different number of contrasts! Oh boy
if sess_dict['combined_performance'].size != sess_dict['performance'].size:
sess_dict['performance'] = \
np.r_[sess_dict['performance'],
np.full(sess_dict['combined_performance'].size - sess_dict['performance'].size, np.nan)]
sess_dict['contrasts'] = \
np.r_[sess_dict['contrasts'],
np.full(sess_dict['combined_contrasts'].size - sess_dict['contrasts'].size, np.nan)]
else:
for sess_dict in sess_dicts:
sess_dict['combined_performance'] = sess_dict['performance']
sess_dict['combined_contrasts'] = sess_dict['contrasts']
sess_dict['combined_performance_easy'] = sess_dict['performance_easy']
sess_dict['combined_reaction_time'] = sess_dict['reaction_time']
sess_dict['combined_n_trials'] = sess_dict['n_trials']
sess_dict['combined_sess_duration'] = sess_dict['sess_duration']
sess_dict['combined_n_delay'] = sess_dict['n_delay']
for bias in [50, 20, 80]:
sess_dict[f'combined_bias_{bias}'] = sess_dict[f'bias_{bias}']
sess_dict[f'combined_thres_{bias}'] = sess_dict[f'thres_{bias}']
sess_dict[f'combined_lapsehigh_{bias}'] = sess_dict[f'lapsehigh_{bias}']
sess_dict[f'combined_lapselow_{bias}'] = sess_dict[f'lapselow_{bias}']
return sess_dicts
def check_up_to_date(subj_path, df):
"""
Check which sessions on local file system are missing from the computed training table.
Parameters
----------
subj_path : pathlib.Path
The path to the subject's dated session folders.
df : pandas.DataFrame
The computed training table.
Returns
-------
pandas.DataFrame
A table of dates and session paths that are missing from the computed training table.
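
    Examples
    --------
    A minimal sketch; the subject path is hypothetical:

    >>> subj_path = Path('/mnt/s0/Data/Subjects/KS022')
    >>> missing = check_up_to_date(subj_path, load_existing_dataframe(subj_path))  # doctest: +SKIP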
"""
df_session = pd.DataFrame(columns=['date', 'session_path'])
for session in alfio.iter_sessions(subj_path, pattern='????-??-??/*'):
s_df = pd.DataFrame({'date': session.parts[-2], 'session_path': str(session)}, index=[0])
df_session = pd.concat([df_session, s_df], ignore_index=True)
if df is None or 'combined_thres_50' not in df.columns:
return df_session
else:
# recorded_session_paths = df['session_path'].values
isin, _ = ismember(df_session.date.unique(), df.date.unique())
missing_dates = df_session.date.unique()[~isin]
return df_session[df_session['date'].isin(missing_dates)].sort_values('date')
def plot_trial_count_and_session_duration(df, subject):
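    """Plot the trial counts and session durations over days on twinned y-axes."""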
df = df.drop_duplicates('date').reset_index(drop=True)
y1 = {'column': 'combined_n_trials',
'title': 'Trial counts',
'lim': None,
'color': 'k',
'join': True}
y2 = {'column': 'combined_sess_duration',
'title': 'Session duration (mins)',
'lim': None,
'color': 'r',
'log': False,
'join': True}
ax = plot_over_days(df, subject, y1, y2)
return ax
def plot_fit_params(df, subject):
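    """Plot the psychometric fit parameters (bias, threshold and lapse rates) for each block over days."""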
fig, axs = plt.subplots(2, 2, figsize=(12, 6))
axs = axs.ravel()
df = df.drop_duplicates('date').reset_index(drop=True)
cmap = sns.diverging_palette(20, 220, n=3, center="dark")
y50 = {'column': 'combined_bias_50',
'title': 'Bias',
'lim': [-100, 100],
'color': cmap[1],
'join': False}
y80 = {'column': 'combined_bias_80',
'title': 'Bias',
'lim': [-100, 100],
'color': cmap[2],
'join': False}
y20 = {'column': 'combined_bias_20',
'title': 'Bias',
'lim': [-100, 100],
'color': cmap[0],
'join': False}
plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False)
plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False)
plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False)
axs[0].axhline(16, linewidth=2, linestyle='--', color='k')
axs[0].axhline(-16, linewidth=2, linestyle='--', color='k')
y50['column'] = 'combined_thres_50'
y50['title'] = 'Threshold'
y50['lim'] = [0, 100]
    y80['column'] = 'combined_thres_80'
    y80['title'] = 'Threshold'
    y80['lim'] = [0, 100]
    y20['column'] = 'combined_thres_20'
    y20['title'] = 'Threshold'
    y20['lim'] = [0, 100]
plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False)
plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False)
plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False)
axs[1].axhline(19, linewidth=2, linestyle='--', color='k')
y50['column'] = 'combined_lapselow_50'
y50['title'] = 'Lapse Low'
y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapselow_80'
    y80['title'] = 'Lapse Low'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapselow_20'
    y20['title'] = 'Lapse Low'
    y20['lim'] = [0, 1]
plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False)
plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False)
plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False)
axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k')
y50['column'] = 'combined_lapsehigh_50'
y50['title'] = 'Lapse High'
y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapsehigh_80'
    y80['title'] = 'Lapse High'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapsehigh_20'
    y20['title'] = 'Lapse High'
    y20['lim'] = [0, 1]
plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True)
plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False)
plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False)
axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k')
fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
lines, labels = axs[3].get_legend_handles_labels()
fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5)
legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8),
Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8),
Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)]
legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True)
fig.add_artist(legend2)
return axs
def plot_psychometric_curve(df, subject, one):
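    """Plot the psychometric curve for the most recent session in the training table."""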
df = df.drop_duplicates('date').reset_index(drop=True)
sess_path = Path(df.iloc[-1]["session_path"])
trials = load_trials(sess_path, one)
fig, ax1 = plt.subplots(figsize=(8, 6))
training.plot_psychometric(trials, ax=ax1, title=f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
return ax1
def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, training_lines=True):
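    """
    Plot one or two metrics over training days.

    `y1` (and optionally `y2`, drawn on a twinned axis) are dicts with keys 'column' (the dataframe
    column to plot), 'title' (the y-axis label), 'lim' (the y-axis limits), 'color' and 'join'
    (whether to connect the points with a line); `y2` may also carry a 'log' key to set a log scale.
    """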
if ax is None:
fig, ax1 = plt.subplots(figsize=(12, 6))
else:
ax1 = ax
dates = [datetime.strptime(dat, '%Y-%m-%d') for dat in df['date']]
if y1['join']:
ax1.plot(dates, df[y1['column']], color=y1['color'])
ax1.scatter(dates, df[y1['column']], color=y1['color'])
ax1.set_ylabel(y1['title'])
ax1.set_ylim(y1['lim'])
if y2 is not None:
ax2 = ax1.twinx()
if y2['join']:
ax2.plot(dates, df[y2['column']], color=y2['color'])
ax2.scatter(dates, df[y2['column']], color=y2['color'])
ax2.set_ylabel(y2['title'])
ax2.yaxis.label.set_color(y2['color'])
ax2.tick_params(axis='y', colors=y2['color'])
ax2.set_ylim(y2['lim'])
if y2['log']:
ax2.set_yscale('log')
ax2.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.spines['left'].set_visible(False)
month_format = mdates.DateFormatter('%b %Y')
month_locator = mdates.MonthLocator()
ax1.xaxis.set_major_locator(month_locator)
ax1.xaxis.set_major_formatter(month_format)
week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
ax1.xaxis.set_minor_locator(week_locator)
ax1.grid(True, which='minor', axis='x', linestyle='--')
ax1.spines['left'].set_visible(False)
ax1.spines['right'].set_visible(False)
ax1.spines['top'].set_visible(False)
if training_lines:
ax1 = add_training_lines(df, ax1)
if title:
ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
# Put a legend below current axis
box = ax1.get_position()
ax1.set_position([box.x0, box.y0 + box.height * 0.1,
box.width, box.height * 0.9])
if legend:
ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
fancybox=True, shadow=True, ncol=5)
return ax1
def add_training_lines(df, ax):
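    """Add a vertical line at the date each training status above 'in training' was first reached."""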
status = df.drop_duplicates(subset='training_status', keep='first')
for _, st in status.iterrows():
if st['training_status'] in ['untrainable', 'unbiasable']:
continue
if TRAINING_STATUS[st['training_status']][0] <= 0:
continue
ax.axvline(datetime.strptime(st['date'], '%Y-%m-%d'), linewidth=2,
color=np.array(TRAINING_STATUS[st['training_status']][1]) / 255, label=st['training_status'])
return ax
def make_plots(session_path, one, df=None, save=False, upload=False, task_collection='raw_behavior_data'):
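    """
    Generate the training plots for a subject and optionally save them to disk and upload them as
    snapshot images attached to the subject on Alyx.
    """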
subject = one.path2ref(session_path)['subject']
subj_path = session_path.parent.parent
df = load_existing_dataframe(subj_path) if df is None else df
df = df[df['task_protocol'] != 'habituation']
if len(df) == 0:
return
ax1 = plot_trial_count_and_session_duration(df, subject)
ax2 = plot_performance_easy_median_reaction_time(df, subject)
ax3 = plot_heatmap_performance_over_days(df, subject)
ax4 = plot_fit_params(df, subject)
ax5 = plot_psychometric_curve(df, subject, one)
outputs = []
if save:
save_path = Path(subj_path)
save_name = save_path.joinpath('subj_trial_count_session_duration.png')
outputs.append(save_name)
ax1.get_figure().savefig(save_name, bbox_inches='tight')
save_name = save_path.joinpath('subj_performance_easy_reaction_time.png')
outputs.append(save_name)
ax2.get_figure().savefig(save_name, bbox_inches='tight')
save_name = save_path.joinpath('subj_performance_heatmap.png')
outputs.append(save_name)
ax3.get_figure().savefig(save_name, bbox_inches='tight')
save_name = save_path.joinpath('subj_psychometric_fit_params.png')
outputs.append(save_name)
ax4[0].get_figure().savefig(save_name, bbox_inches='tight')
save_name = save_path.joinpath('subj_psychometric_curve.png')
outputs.append(save_name)
ax5.get_figure().savefig(save_name, bbox_inches='tight')
if upload:
subj = one.alyx.rest('subjects', 'list', nickname=subject)[0]
snp = ReportSnapshot(session_path, subj['id'], content_type='subject', one=one)
snp.outputs = outputs
snp.register_images(widths=['orig'])