# Source code for ibllib.qc.qcplots

"""Plots for trial QC

Example:
    one = ONE()
    # Load data
    eid = 'c8ef527b-6f7f-4f08-8b99-5aeb9d2b3740'
    # Run QC
    qc = TaskQC(eid, one=one)
    plot_results(qc)
    plt.show()

"""
from collections import Counter
from collections.abc import Sized
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from ibllib.qc.task_metrics import TaskQC


def plot_results(qc_obj, save_path=None):
    """Plot a summary of the task QC outcomes and the failed trial-level metrics.

    The function prints the check names grouped by outcome, draws a bar chart of the
    number of checks per outcome on the current pyplot figure, then creates a second
    two-panel figure: a box plot of the trial-level metric values for the checks that
    failed (top) and a bar plot of the corresponding per-trial pass values (bottom).

    :param qc_obj: A TaskQC object. If its `passed` attribute is empty, `compute()`
        is called before plotting.
    :param save_path: Optional location for saving the two-panel figure. An existing
        directory receives a file named `<ref>_QC.png`; any other path is handed to
        `Figure.savefig` as the file name. Nothing is saved (and a message is printed)
        if the folder that should contain the file does not exist.
    :raises ValueError: If `qc_obj` is not a TaskQC instance.
    """
    if not isinstance(qc_obj, TaskQC):
        raise ValueError('Input must be TaskQC object')

    if not qc_obj.passed:
        qc_obj.compute()
    # Only the per-check outcomes are needed here; the first two returned values are unused
    _, _, outcomes = qc_obj.compute_session_status()

    # Sort checks by outcome and print.
    # NOTE(review): the first six characters of each key are dropped for display;
    # presumably this strips a '_task_' prefix - confirm against TaskQC.
    checks_by_outcome = {}
    for check, status in outcomes.items():
        checks_by_outcome.setdefault(status, []).append(check[6:])
    for status, names in checks_by_outcome.items():
        print(f'The following checks were labelled {status}:')
        print('\n'.join(names), '\n')

    # Collect some session details
    n_trials = qc_obj.extractor.data['intervals'].shape[0]
    det = qc_obj.one.get_details(qc_obj.eid)
    ref = f"{datetime.fromisoformat(det['start_time']).date()}_{det['number']:d}_{det['subject']}"
    title = ref + (' (Bpod data only)' if qc_obj.extractor.bpod_only else '')

    # Number of checks in each outcome category, drawn on the current pyplot figure
    counts = Counter(outcomes.values())
    plt.bar(range(len(counts)), list(counts.values()), align='center',
            tick_label=list(counts.keys()))
    plt.gcf().suptitle(title)
    plt.ylabel('# QC checks')
    plt.xlabel('outcome')

    a4_dims = (11.7, 8.27)
    fig, (ax0, ax1) = plt.subplots(2, 1, figsize=a4_dims, constrained_layout=True)
    fig.suptitle(title)

    def get_trial_level_failed(d):
        """Return a DataFrame of the entries of `d` whose check failed and which hold
        one value per trial (sized, with length equal to `n_trials`)."""
        new_dict = {
            k[6:]: v for k, v in d.items()
            if outcomes[k] == 'FAIL' and isinstance(v, Sized) and len(v) == n_trials
        }
        return pd.DataFrame.from_dict(new_dict)

    # Plot failed trial level metrics
    sns.boxplot(data=get_trial_level_failed(qc_obj.metrics), orient='h', ax=ax0)
    ax0.set_yticklabels(ax0.get_yticklabels(), rotation=30, fontsize=8)
    ax0.set(xscale='symlog', title='Metrics (failed)', xlabel='metric values (units vary)')

    # Plot the per-trial pass values of the same failed checks
    sns.barplot(data=get_trial_level_failed(qc_obj.passed), orient='h', ax=ax1)
    ax1.set_yticklabels(ax1.get_yticklabels(), rotation=30, fontsize=8)
    ax1.set(title='Counts', xlabel='proportion of trials that passed')

    if save_path is not None:
        save_path = Path(save_path)
        if save_path.is_dir():
            fig.savefig(save_path.joinpath(f"{ref}_QC.png"))
        elif not save_path.parent.exists():
            # Path.is_dir() is already False for a missing path, so the existence
            # check has to be made on the folder that would contain the file.
            print(f"Folder {save_path.parent} does not exist, not saving...")
        else:
            fig.savefig(save_path)