import logging
import argparse
from itertools import cycle
import random
from collections.abc import Sized
from pathlib import Path
import pandas as pd
import numpy as np
from matplotlib.colors import TABLEAU_COLORS
from one.api import ONE
from one.alf.spec import is_session_path
import ibllib.plots as plots
from ibllib.misc import qt
from ibllib.qc.task_metrics import TaskQC
from ibllib.qc.task_qc_viewer import ViewEphysQC
from ibllib.pipes.dynamic_pipeline import get_trials_tasks
from ibllib.pipes.base_tasks import BehaviourTask
from ibllib.pipes.behavior_tasks import HabituationTrialsBpod, ChoiceWorldTrialsBpod
EVENT_MAP = {'goCue_times': ['#2ca02c', 'solid'], # green
'goCueTrigger_times': ['#2ca02c', 'dotted'], # green
'errorCue_times': ['#d62728', 'solid'], # red
'errorCueTrigger_times': ['#d62728', 'dotted'], # red
'valveOpen_times': ['#17becf', 'solid'], # cyan
'stimFreeze_times': ['#0000ff', 'solid'], # blue
'stimFreezeTrigger_times': ['#0000ff', 'dotted'], # blue
'stimOff_times': ['#9400d3', 'solid'], # dark violet
'stimOffTrigger_times': ['#9400d3', 'dotted'], # dark violet
'stimOn_times': ['#e377c2', 'solid'], # pink
'stimOnTrigger_times': ['#e377c2', 'dotted'], # pink
'response_times': ['#8c564b', 'solid'], # brown
}
cm = [EVENT_MAP[k][0] for k in EVENT_MAP]
ls = [EVENT_MAP[k][1] for k in EVENT_MAP]
CRITICAL_CHECKS = (
'check_audio_pre_trial',
'check_correct_trial_event_sequence',
'check_error_trial_event_sequence',
'check_n_trial_events',
'check_response_feedback_delays',
'check_response_stimFreeze_delays',
'check_reward_volume_set',
'check_reward_volumes',
'check_stimOn_goCue_delays',
'check_stimulus_move_before_goCue',
'check_wheel_move_before_feedback',
'check_wheel_freeze_during_quiescence'
)
_logger = logging.getLogger(__name__)
[docs]
class QcFrame:
qc = None
"""ibllib.qc.task_metrics.TaskQC: A TaskQC object containing extracted data"""
frame = None
"""pandas.DataFrame: A table of failing trial-level QC metrics."""
def __init__(self, qc):
"""
An interactive display of task QC data.
Parameters
----------
qc : ibllib.qc.task_metrics.TaskQC
A TaskQC object containing extracted data for plotting.
"""
assert qc.extractor and qc.metrics, 'Please run QC before passing to QcFrame'
self.qc = qc
# Print failed
outcome, results, outcomes = self.qc.compute_session_status()
map = {k: [] for k in set(outcomes.values())}
for k, v in outcomes.items():
map[v].append(k[6:])
for k, v in map.items():
if k == 'PASS':
continue
print(f'The following checks were labelled {k}:')
print('\n'.join(v), '\n')
print('The following *critical* checks did not pass:')
critical_checks = [f'_{x.replace("check", "task")}' for x in CRITICAL_CHECKS]
for k, v in outcomes.items():
if v != 'PASS' and k in critical_checks:
print(k[6:])
# Make DataFrame from the trail level metrics
def get_trial_level_failed(d):
new_dict = {k[6:]: v for k, v in d.items() if
isinstance(v, Sized) and len(v) == self.n_trials}
return pd.DataFrame.from_dict(new_dict)
self.frame = get_trial_level_failed(self.qc.metrics)
self.frame['intervals_0'] = self.qc.extractor.data['intervals'][:, 0]
self.frame['intervals_1'] = self.qc.extractor.data['intervals'][:, 1]
self.frame.insert(loc=0, column='trial_no', value=self.frame.index)
@property
def n_trials(self):
return self.qc.extractor.data['intervals'].shape[0]
[docs]
def get_wheel_data(self):
return {'re_pos': self.qc.extractor.data.get('wheel_position', np.array([])),
're_ts': self.qc.extractor.data.get('wheel_timestamps', np.array([]))}
[docs]
def create_plots(self, axes,
wheel_axes=None, trial_events=None, color_map=None, linestyle=None):
"""
Plots the data for bnc1 (sound) and bnc2 (frame2ttl).
:param axes: An axes handle on which to plot the TTL events
:param wheel_axes: An axes handle on which to plot the wheel trace
:param trial_events: A list of Bpod trial events to plot, e.g. ['stimFreeze_times'],
if None, valve, sound and stimulus events are plotted
:param color_map: A color map to use for the events, default is the tableau color map
linestyle: A line style map to use for the events, default is random.
:return: None
"""
color_map = color_map or TABLEAU_COLORS.keys()
if trial_events is None:
# Default trial events to plot as vertical lines
trial_events = [
'goCue_times',
'goCueTrigger_times',
'feedback_times',
('stimCenter_times'
if 'stimCenter_times' in self.qc.extractor.data
else 'stimFreeze_times'), # handle habituationChoiceWorld exception
'stimOff_times',
'stimOn_times'
]
plot_args = {
'ymin': 0,
'ymax': 4,
'linewidth': 2,
'ax': axes
}
bnc1 = self.qc.extractor.frame_ttls
bnc2 = self.qc.extractor.audio_ttls
trial_data = self.qc.extractor.data
if bnc1['times'].size:
plots.squares(bnc1['times'], bnc1['polarities'] * 0.4 + 1, ax=axes, color='k')
if bnc2['times'].size:
plots.squares(bnc2['times'], bnc2['polarities'] * 0.4 + 2, ax=axes, color='k')
linestyle = linestyle or random.choices(('-', '--', '-.', ':'), k=len(trial_events))
if self.qc.extractor.bpod_ttls is not None:
bpttls = self.qc.extractor.bpod_ttls
plots.squares(bpttls['times'], bpttls['polarities'] * 0.4 + 3, ax=axes, color='k')
plot_args['ymax'] = 4
ylabels = ['', 'frame2ttl', 'sound', 'bpod', '']
else:
plot_args['ymax'] = 3
ylabels = ['', 'frame2ttl', 'sound', '']
for event, c, l in zip(trial_events, cycle(color_map), linestyle):
if event in trial_data:
plots.vertical_lines(trial_data[event], label=event, color=c, linestyle=l, **plot_args)
axes.legend(loc='upper left', fontsize='xx-small', bbox_to_anchor=(1, 0.5))
axes.set_yticks(list(range(plot_args['ymax'] + 1)))
axes.set_yticklabels(ylabels)
axes.set_ylim([0, plot_args['ymax']])
if wheel_axes:
wheel_data = self.get_wheel_data()
wheel_plot_args = {
'ax': wheel_axes,
'ymin': wheel_data['re_pos'].min() if wheel_data['re_pos'].size else 0,
'ymax': wheel_data['re_pos'].max() if wheel_data['re_pos'].size else 1}
plot_args = {**plot_args, **wheel_plot_args}
wheel_axes.plot(wheel_data['re_ts'], wheel_data['re_pos'], 'k-x')
for event, c, ln in zip(trial_events, cycle(color_map), linestyle):
if event in trial_data:
plots.vertical_lines(trial_data[event],
label=event, color=c, linestyle=ln, **plot_args)
[docs]
def get_bpod_trials_task(task):
"""
Return the correct trials task for extracting only the Bpod trials.
Parameters
----------
task : ibllib.pipes.tasks.Task
A pipeline task from which to derive the Bpod trials task.
Returns
-------
ibllib.pipes.tasks.Task
A Bpod choice world trials task instance.
"""
if task.__class__ in (ChoiceWorldTrialsBpod, HabituationTrialsBpod):
pass # do nothing; already Bpod only
else:
assert isinstance(task, BehaviourTask)
# A dynamic pipeline task
trials_class = HabituationTrialsBpod if 'habituation' in task.protocol else ChoiceWorldTrialsBpod
task = trials_class(task.session_path,
collection=task.collection, protocol_number=task.protocol_number,
protocol=task.protocol, one=task.one)
return task
[docs]
def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=None, protocol_number=None):
"""
Displays the task QC for a given session.
NB: For this to work, all behaviour trials task classes must implement a `run_qc` method.
Parameters
----------
qc_or_session : str, pathlib.Path, ibllib.qc.task_metrics.TaskQC, QcFrame
An experiment ID, session path, or TaskQC object.
bpod_only : bool
If true, display Bpod extracted events instead of data from the DAQ.
local : bool
If true, asserts all data local (i.e. do not attempt to download missing datasets).
one : one.api.One
An instance of ONE.
protocol_number : int
If not None, displays the QC for the protocol number provided. Argument is ignored if
`qc_or_session` is a TaskQC object or QcFrame instance.
Returns
-------
QcFrame
The QcFrame object.
"""
if isinstance(qc_or_session, QcFrame):
qc = qc_or_session
elif isinstance(qc_or_session, TaskQC):
qc = QcFrame(qc_or_session)
else: # assumed to be eid or session path
one = one or ONE(mode='local' if local else 'auto')
if not is_session_path(Path(qc_or_session)):
eid = one.to_eid(qc_or_session)
session_path = one.eid2path(eid)
else:
session_path = Path(qc_or_session)
tasks = get_trials_tasks(session_path, one=None if local else one, bpod_only=bpod_only)
# Get the correct task and ensure not passive
if protocol_number is None:
if not (task := next((t for t in tasks if 'passive' not in t.name.lower()), None)):
raise ValueError('No non-passive behaviour tasks found for session ' + '/'.join(session_path.parts[-3:]))
elif not isinstance(protocol_number, int) or protocol_number < 0:
raise TypeError('Protocol number must be a positive integer')
elif protocol_number > len(tasks) - 1:
raise ValueError('Invalid protocol number')
else:
task = tasks[protocol_number]
if 'passive' in task.name.lower():
raise ValueError('QC display not supported for passive protocols')
_logger.debug('Using %s task', task.name)
# Ensure required data are present
task.location = 'server' if local else 'remote' # affects whether missing data are downloaded
task.setUp()
if local: # currently setUp does not raise on missing data
task.assert_expected_inputs(raise_error=True)
# Compute the QC and build the frame
task_qc = task.run_qc(update=False)
qc = QcFrame(task_qc)
# Handle trial event names in habituationChoiceWorld
events = EVENT_MAP.keys()
if 'stimCenter_times' in qc.qc.extractor.data:
events = map(lambda x: x.replace('stimFreeze', 'stimCenter'), events)
# Run QC and plot
w = ViewEphysQC.viewqc(wheel=qc.get_wheel_data())
qc.create_plots(w.wplot.canvas.ax,
wheel_axes=w.wplot.canvas.ax2,
trial_events=list(events),
color_map=cm,
linestyle=ls)
# Update table and callbacks
w.update_df(qc.frame)
qt.run_app()
return qc
[docs]
def qc_gui_cli():
"""Run TaskQC viewer with wheel data.
For information on the QC checks see the QC Flags & failures document:
https://docs.google.com/document/d/1X-ypFEIxqwX6lU9pig4V_zrcR5lITpd8UJQWzW9I9zI/edit#
Examples
--------
>>> ipython task_qc.py c9fec76e-7a20-4da4-93ad-04510a89473b
>>> ipython task_qc.py ./KS022/2019-12-10/001 --local
"""
# Parse parameters
parser = argparse.ArgumentParser(description='Quick viewer to see the behaviour data from'
'choice world sessions.')
parser.add_argument('session', help='session uuid or path')
parser.add_argument('--bpod', action='store_true', help='run QC on Bpod data only (no FPGA)')
parser.add_argument('--local', action='store_true', help='run from disk location (lab server')
args = parser.parse_args() # returns data from the options specified (echo)
show_session_task_qc(qc_or_session=args.session, bpod_only=args.bpod, local=args.local)
if __name__ == '__main__':
qc_gui_cli()