"""Task QC viewer (ibllib.qc.task_qc_viewer.task_qc): interactive display of task QC metrics."""

import logging
import argparse
from itertools import cycle
import random
from collections.abc import Sized
from pathlib import Path

import pandas as pd
import numpy as np
from matplotlib.colors import TABLEAU_COLORS
from one.api import ONE
from one.alf.spec import is_session_path

import ibllib.plots as plots
from ibllib.misc import qt
from ibllib.qc.task_metrics import TaskQC
from ibllib.qc.task_qc_viewer import ViewEphysQC
from ibllib.pipes.dynamic_pipeline import get_trials_tasks
from ibllib.pipes.base_tasks import BehaviourTask
from ibllib.pipes.behavior_tasks import HabituationTrialsBpod, ChoiceWorldTrialsBpod
from ibllib.pipes.training_preprocessing import TrainingTrials

# Plotting style of each trial event as [hex colour, matplotlib line style]. Trigger times
# share the colour of their corresponding event but are drawn with a dotted line.
EVENT_MAP = {
    'goCue_times': ['#2ca02c', 'solid'],  # green
    'goCueTrigger_times': ['#2ca02c', 'dotted'],  # green
    'errorCue_times': ['#d62728', 'solid'],  # red
    'errorCueTrigger_times': ['#d62728', 'dotted'],  # red
    'valveOpen_times': ['#17becf', 'solid'],  # cyan
    'stimFreeze_times': ['#0000ff', 'solid'],  # blue
    'stimFreezeTrigger_times': ['#0000ff', 'dotted'],  # blue
    'stimOff_times': ['#9400d3', 'solid'],  # dark violet
    'stimOffTrigger_times': ['#9400d3', 'dotted'],  # dark violet
    'stimOn_times': ['#e377c2', 'solid'],  # pink
    'stimOnTrigger_times': ['#e377c2', 'dotted'],  # pink
    'response_times': ['#8c564b', 'solid'],  # brown
}
# Colours and line styles of the events above, in EVENT_MAP order
cm = [colour for colour, _ in EVENT_MAP.values()]
ls = [style for _, style in EVENT_MAP.values()]
# QC checks whose non-passing outcome is reported separately as *critical* by QcFrame
CRITICAL_CHECKS = (
    'check_audio_pre_trial',
    'check_correct_trial_event_sequence',
    'check_error_trial_event_sequence',
    'check_n_trial_events',
    'check_response_feedback_delays',
    'check_response_stimFreeze_delays',
    'check_reward_volume_set',
    'check_reward_volumes',
    'check_stimOn_goCue_delays',
    'check_stimulus_move_before_goCue',
    'check_wheel_move_before_feedback',
    'check_wheel_freeze_during_quiescence'
)


_logger = logging.getLogger(__name__)


class QcFrame:
    """
    An interactive display of task QC data.

    On construction the non-passing QC outcomes are printed to stdout and a table of the
    trial-level QC metrics is built in :attr:`frame`.
    """

    qc = None
    """ibllib.qc.task_metrics.TaskQC: A TaskQC object containing extracted data"""

    frame = None
    """pandas.DataFrame: A table of failing trial-level QC metrics."""

    def __init__(self, qc):
        """
        An interactive display of task QC data.

        Parameters
        ----------
        qc : ibllib.qc.task_metrics.TaskQC
            A TaskQC object containing extracted data for plotting.
        """
        assert qc.extractor and qc.metrics, 'Please run QC before passing to QcFrame'
        self.qc = qc

        # Print the checks that did not pass
        _, _, outcomes = self.qc.compute_session_status()
        self._print_outcomes(outcomes)

        # Make DataFrame from the trial level metrics
        self.frame = self._trial_level_frame(self.qc.metrics, self.n_trials)
        intervals = self.qc.extractor.data['intervals']
        self.frame['intervals_0'] = intervals[:, 0]
        self.frame['intervals_1'] = intervals[:, 1]
        self.frame.insert(loc=0, column='trial_no', value=self.frame.index)

    @staticmethod
    def _group_outcomes(outcomes):
        """
        Group check names by outcome.

        Parameters
        ----------
        outcomes : dict
            Map of check name to outcome string. Check names are expected to carry a
            6-character '_task_' prefix, which is stripped.

        Returns
        -------
        dict
            Map of outcome to the list of (prefix-stripped) check names, with outcomes and
            names in the order they first appear in `outcomes`.
        """
        groups = {}
        for check, outcome in outcomes.items():
            groups.setdefault(outcome, []).append(check[6:])
        return groups

    def _print_outcomes(self, outcomes):
        """Print the non-passing checks grouped by outcome, then the critical checks that did not pass."""
        for outcome, checks in self._group_outcomes(outcomes).items():
            if outcome == 'PASS':
                continue
            print(f'The following checks were labelled {outcome}:')
            print('\n'.join(checks), '\n')

        print('The following *critical* checks did not pass:')
        # Outcome keys use a '_task_' prefix where CRITICAL_CHECKS use 'check_'
        critical_checks = {f'_{x.replace("check", "task")}' for x in CRITICAL_CHECKS}
        for check, outcome in outcomes.items():
            if outcome != 'PASS' and check in critical_checks:
                print(check[6:])

    @staticmethod
    def _trial_level_frame(metrics, n_trials):
        """
        Build a table of the trial-level metrics.

        A metric is treated as trial-level when it is sized and holds exactly one value per
        trial. The 6-character '_task_' prefix is stripped from the column names.

        Parameters
        ----------
        metrics : dict
            Map of metric name to metric value(s).
        n_trials : int
            The number of trials in the session.

        Returns
        -------
        pandas.DataFrame
            One column per trial-level metric.
        """
        trial_metrics = {k[6:]: v for k, v in metrics.items()
                         if isinstance(v, Sized) and len(v) == n_trials}
        return pd.DataFrame.from_dict(trial_metrics)

    @property
    def n_trials(self):
        """int: The number of trials, i.e. the number of rows of the extracted trial intervals."""
        return self.qc.extractor.data['intervals'].shape[0]

    def get_wheel_data(self):
        """
        Return the extracted wheel data.

        Returns
        -------
        dict
            'p' and 're_ts' keys hold the extracted 'wheel_position' and 'wheel_timestamps'
            respectively, or empty arrays when these were not extracted.
        """
        return {'re_pos': self.qc.extractor.data.get('wheel_position', np.array([])),
                're_ts': self.qc.extractor.data.get('wheel_timestamps', np.array([]))}

    def create_plots(self, axes, wheel_axes=None, trial_events=None, color_map=None, linestyle=None):
        """
        Plot the frame2ttl (bnc1), sound (bnc2) and, when available, Bpod TTLs with trial events.

        Parameters
        ----------
        axes : matplotlib.axes.Axes
            An axes handle on which to plot the TTL events.
        wheel_axes : matplotlib.axes.Axes, optional
            An axes handle on which to plot the wheel trace.
        trial_events : list of str, optional
            Trial events to plot as vertical lines, e.g. ['stimFreeze_times']. If None, the go
            cue, feedback and stimulus events are plotted.
        color_map : iterable of str, optional
            Colours to cycle through for the events; default is the tableau colour map.
        linestyle : iterable of str, optional
            One line style per event; default is a random line style per event.

        Returns
        -------
        None
        """
        # Materialise so that both plotting loops below can iterate them
        color_map = list(color_map or TABLEAU_COLORS.keys())
        if trial_events is None:
            # Default trial events to plot as vertical lines
            trial_events = [
                'goCue_times',
                'goCueTrigger_times',
                'feedback_times',
                # handle habituationChoiceWorld exception
                ('stimCenter_times' if 'stimCenter_times' in self.qc.extractor.data else 'stimFreeze_times'),
                'stimOff_times',
                'stimOn_times'
            ]
        linestyle = list(linestyle or random.choices(('-', '--', '-.', ':'), k=len(trial_events)))
        plot_args = {'ymin': 0, 'ymax': 4, 'linewidth': 2, 'ax': axes}
        trial_data = self.qc.extractor.data

        # TTL square waves are stacked: frame2ttl around y=1, sound around y=2, Bpod around y=3
        bnc1 = self.qc.extractor.frame_ttls
        bnc2 = self.qc.extractor.audio_ttls
        if bnc1['times'].size:
            plots.squares(bnc1['times'], bnc1['polarities'] * 0.4 + 1, ax=axes, color='k')
        if bnc2['times'].size:
            plots.squares(bnc2['times'], bnc2['polarities'] * 0.4 + 2, ax=axes, color='k')
        if self.qc.extractor.bpod_ttls is not None:
            bpttls = self.qc.extractor.bpod_ttls
            plots.squares(bpttls['times'], bpttls['polarities'] * 0.4 + 3, ax=axes, color='k')
            plot_args['ymax'] = 4
            ylabels = ['', 'frame2ttl', 'sound', 'bpod', '']
        else:
            plot_args['ymax'] = 3
            ylabels = ['', 'frame2ttl', 'sound', '']

        for event, c, ln in zip(trial_events, cycle(color_map), linestyle):
            if event in trial_data:
                plots.vertical_lines(trial_data[event], label=event, color=c, linestyle=ln, **plot_args)
        axes.legend(loc='upper left', fontsize='xx-small', bbox_to_anchor=(1, 0.5))
        axes.set_yticks(list(range(plot_args['ymax'] + 1)))
        axes.set_yticklabels(ylabels)
        axes.set_ylim([0, plot_args['ymax']])

        if wheel_axes is not None:
            wheel_data = self.get_wheel_data()
            re_pos = wheel_data['re_pos']
            # Event lines on the wheel axes span the range of the wheel position
            wheel_plot_args = {
                **plot_args,
                'ax': wheel_axes,
                'ymin': re_pos.min() if re_pos.size else 0,
                'ymax': re_pos.max() if re_pos.size else 1}
            wheel_axes.plot(wheel_data['re_ts'], re_pos, 'k-x')
            for event, c, ln in zip(trial_events, cycle(color_map), linestyle):
                if event in trial_data:
                    plots.vertical_lines(trial_data[event], label=event, color=c, linestyle=ln,
                                         **wheel_plot_args)
def get_bpod_trials_task(task):
    """
    Return the correct trials task for extracting only the Bpod trials.

    Parameters
    ----------
    task : ibllib.pipes.tasks.Task
        A pipeline task from which to derive the Bpod trials task.

    Returns
    -------
    ibllib.pipes.tasks.Task
        A Bpod choice world trials task instance. `task` itself is returned unchanged when it
        is already a Bpod-only trials task.
    """
    # NB: exact class match for the dynamic pipeline classes (subclasses are not matched)
    bpod_only_classes = (ChoiceWorldTrialsBpod, HabituationTrialsBpod)
    if isinstance(task, TrainingTrials) or task.__class__ in bpod_only_classes:
        return task  # nothing to do; already Bpod only

    if not isinstance(task, BehaviourTask):
        # A legacy pipeline task (should be EphysTrials as there are no other options)
        return TrainingTrials(task.session_path, one=task.one)

    # A dynamic pipeline task: rebuild it as the matching Bpod-only trials task
    bpod_class = HabituationTrialsBpod if 'habituation' in task.protocol else ChoiceWorldTrialsBpod
    return bpod_class(
        task.session_path,
        collection=task.collection,
        protocol_number=task.protocol_number,
        protocol=task.protocol,
        one=task.one,
    )
def _resolve_session_path(qc_or_session, one):
    """
    Return the session path for an experiment ID or session path.

    Parameters
    ----------
    qc_or_session : str, pathlib.Path
        An experiment ID or session path.
    one : one.api.One
        An instance of ONE used to resolve experiment IDs.

    Returns
    -------
    pathlib.Path
        The session path.

    Raises
    ------
    ValueError
        The input could not be resolved to a session path.
    """
    if is_session_path(Path(qc_or_session)):
        return Path(qc_or_session)
    eid = one.to_eid(qc_or_session)
    session_path = one.eid2path(eid)
    if session_path is None:
        # Fail early rather than passing None on to the pipeline task lookup
        raise ValueError(f'Failed to resolve a session path for "{qc_or_session}"')
    return session_path


def _select_trials_task(tasks, protocol_number, session_path):
    """
    Select the behaviour trials task whose QC is to be displayed.

    Parameters
    ----------
    tasks : list of ibllib.pipes.tasks.Task
        The session's trials tasks, indexed by protocol number.
    protocol_number : int, None
        The protocol to select. If None, the first non-passive task is selected.
    session_path : pathlib.Path
        The session path (used in error messages only).

    Returns
    -------
    ibllib.pipes.tasks.Task
        The selected (non-passive) trials task.

    Raises
    ------
    TypeError
        The protocol number is not a non-negative integer.
    ValueError
        No non-passive task was found, the protocol number is out of range or the selected
        protocol is passive.
    """
    if protocol_number is None:
        task = next((t for t in tasks if 'passive' not in t.name.lower()), None)
        if task is None:
            raise ValueError('No non-passive behaviour tasks found for session ' +
                             '/'.join(session_path.parts[-3:]))
        return task
    if not isinstance(protocol_number, int) or protocol_number < 0:
        raise TypeError('Protocol number must be a non-negative integer')
    if protocol_number >= len(tasks):
        raise ValueError('Invalid protocol number')
    task = tasks[protocol_number]
    if 'passive' in task.name.lower():
        raise ValueError('QC display not supported for passive protocols')
    return task


def _compute_qc_frame(task, bpod_only=False, local=False):
    """
    Run the task QC of a trials task and wrap the result in a QcFrame.

    Parameters
    ----------
    task : ibllib.pipes.tasks.Task
        A behaviour trials task implementing `run_qc`.
    bpod_only : bool
        If true, use the equivalent Bpod-only trials task.
    local : bool
        If true, assert all input data are present locally instead of downloading them.

    Returns
    -------
    QcFrame
        The QcFrame built from the computed (non-updated) task QC.
    """
    # If Bpod only and not a dynamic pipeline Bpod behaviour task OR legacy TrainingTrials task
    if bpod_only and 'bpod' not in task.name.lower():
        # Use the dynamic pipeline Bpod behaviour task instead (should work with legacy pipeline too)
        task = get_bpod_trials_task(task)
    _logger.debug('Using %s task', task.name)
    # Ensure required data are present
    task.location = 'server' if local else 'remote'  # affects whether missing data are downloaded
    task.setUp()
    if local:  # currently setUp does not raise on missing data
        task.assert_expected_inputs(raise_error=True)
    # Compute the QC and build the frame
    return QcFrame(task.run_qc(update=False))


def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=None, protocol_number=None):
    """
    Displays the task QC for a given session.

    NB: For this to work, all behaviour trials task classes must implement a `run_qc` method.

    Parameters
    ----------
    qc_or_session : str, pathlib.Path, ibllib.qc.task_metrics.TaskQC, QcFrame
        An experiment ID, session path, or TaskQC object.
    bpod_only : bool
        If true, display Bpod extracted events instead of data from the DAQ.
    local : bool
        If true, asserts all data local (i.e. do not attempt to download missing datasets).
    one : one.api.One
        An instance of ONE.
    protocol_number : int
        If not None, displays the QC for the protocol number provided. Argument is ignored if
        `qc_or_session` is a TaskQC object or QcFrame instance.

    Returns
    -------
    QcFrame
        The QcFrame object.

    Raises
    ------
    TypeError
        No session was provided, or the protocol number is not a non-negative integer.
    ValueError
        The session could not be resolved, no non-passive behaviour task was found, the
        protocol number is out of range, or the selected protocol is passive.
    """
    if isinstance(qc_or_session, QcFrame):
        qc = qc_or_session
    elif isinstance(qc_or_session, TaskQC):
        qc = QcFrame(qc_or_session)
    else:  # assumed to be eid or session path
        if qc_or_session is None:
            raise TypeError('An experiment ID, session path, TaskQC or QcFrame must be provided')
        one = one or ONE(mode='local' if local else 'auto')
        session_path = _resolve_session_path(qc_or_session, one)
        tasks = get_trials_tasks(session_path, one=None if local else one)
        # Get the correct task and ensure not passive
        task = _select_trials_task(tasks, protocol_number, session_path)
        qc = _compute_qc_frame(task, bpod_only=bpod_only, local=local)

    # Handle trial event names in habituationChoiceWorld
    events = list(EVENT_MAP.keys())
    if 'stimCenter_times' in qc.qc.extractor.data:
        events = [event.replace('stimFreeze', 'stimCenter') for event in events]

    # Run QC and plot
    w = ViewEphysQC.viewqc(wheel=qc.get_wheel_data())
    qc.create_plots(w.wplot.canvas.ax, wheel_axes=w.wplot.canvas.ax2,
                    trial_events=events, color_map=cm, linestyle=ls)

    # Update table and callbacks
    w.update_df(qc.frame)
    qt.run_app()
    return qc
def qc_gui_cli():
    """Run TaskQC viewer with wheel data.

    For information on the QC checks see the QC Flags & failures document:
    https://docs.google.com/document/d/1X-ypFEIxqwX6lU9pig4V_zrcR5lITpd8UJQWzW9I9zI/edit#

    Examples
    --------
    Display the QC of a session given its UUID, or of a session on disk::

        ipython task_qc.py c9fec76e-7a20-4da4-93ad-04510a89473b
        ipython task_qc.py ./KS022/2019-12-10/001 --local
    """
    # Parse parameters
    parser = argparse.ArgumentParser(description='Quick viewer to see the behaviour data from '
                                                 'choice world sessions.')
    parser.add_argument('session', help='session uuid or path')
    parser.add_argument('--bpod', action='store_true', help='run QC on Bpod data only (no FPGA)')
    parser.add_argument('--local', action='store_true', help='run from disk location (lab server)')
    parser.add_argument('--protocol-number', type=int, default=None,
                        help='protocol number to display; defaults to the first non-passive protocol')
    args = parser.parse_args()  # returns data from the options specified (echo)
    show_session_task_qc(qc_or_session=args.session, bpod_only=args.bpod, local=args.local,
                         protocol_number=args.protocol_number)


if __name__ == '__main__':
    qc_gui_cli()