"""
Module that produces figures, usually for the extraction pipeline
"""
import logging
import time
from pathlib import Path
import traceback
from string import ascii_uppercase
import numpy as np
import pandas as pd
import scipy.signal
import matplotlib.pyplot as plt
from ibldsp import voltage
from ibllib.plots.snapshot import ReportSnapshotProbe, ReportSnapshot
from one.api import ONE
import one.alf.io as alfio
from one.alf.exceptions import ALFObjectNotFound
from ibllib.io.video import get_video_frame, url_from_eid
from ibllib.oneibl.data_handlers import ExpectedDataset
import spikeglx
import neuropixel
from brainbox.plot import driftmap
from brainbox.io.spikeglx import Streamer
from brainbox.behavior.dlc import SAMPLING, plot_trace_on_frame, plot_wheel_position, plot_lick_hist, \
plot_lick_raster, plot_motion_energy_hist, plot_speed_hist, plot_pupil_diameter_hist
from brainbox.ephys_plots import image_lfp_spectrum_plot, image_rms_plot, plot_brain_regions
from brainbox.io.one import load_spike_sorting_fast
from brainbox.behavior import training
from iblutil.numerical import ismember
from ibllib.plots.misc import Density
logger = logging.getLogger(__name__)
def set_axis_label_size(ax, labels=14, ticklabels=12, title=14, cmap=False):
"""
    Normalise the font sizes of the labels, tick labels and title of a matplotlib axis.

    :param ax: matplotlib axis
    :param labels: font size of the x and y axis labels
    :param ticklabels: font size of the tick labels
    :param title: font size of the axis title
    :param cmap: if True, also resize the labels of the colorbar attached to the last image on the axis
    :return:
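
    Example (illustrative; assumes an axis with an image and a colorbar already drawn)::

        fig, ax = plt.subplots()
        im = ax.imshow(np.random.rand(10, 10))
        fig.colorbar(im, ax=ax)
        set_axis_label_size(ax, labels=14, ticklabels=12, cmap=True)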
"""
ax.xaxis.get_label().set_fontsize(labels)
ax.yaxis.get_label().set_fontsize(labels)
ax.tick_params(labelsize=ticklabels)
ax.title.set_fontsize(title)
if cmap:
cbar = ax.images[-1].colorbar
cbar.ax.tick_params(labelsize=ticklabels)
cbar.ax.yaxis.get_label().set_fontsize(labels)
def remove_axis_outline(ax):
"""
    Hide the outline (ticks and spines) of an empty axis.

    :param ax: matplotlib axis
    :return:
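
    Example (illustrative; hides the spare colorbar axis used throughout this module)::

        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]})
        remove_axis_outline(axs[1])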
"""
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
class BehaviourPlots(ReportSnapshot):
"""Behavioural plots."""
@property
def signature(self):
signature = {
'input_files': [
('*trials.table.pqt', self.trials_collection, True),
],
'output_files': [
('psychometric_curve.png', 'snapshot/behaviour', True),
('chronometric_curve.png', 'snapshot/behaviour', True),
('reaction_time_with_trials.png', 'snapshot/behaviour', True)
]
}
return signature
def __init__(self, eid, session_path=None, one=None, **kwargs):
"""
Generate and upload behaviour plots.
Parameters
----------
eid : str, uuid.UUID
An experiment UUID.
session_path : pathlib.Path
A session path.
one : one.api.One
An instance of ONE for registration to Alyx.
        task_collection : str
            The location of the trials data (default: 'alf').
kwargs
Arguments for ReportSnapshot constructor.
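
        Examples
        --------
        A minimal sketch, assuming a configured ONE instance (`one`) and an experiment UUID
        (`eid`); the figures are written to the session's 'snapshot/behaviour' folder:

        >>> task = BehaviourPlots(eid, one=one)
        >>> task.run()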
"""
self.one = one
self.eid = eid
self.session_path = session_path or self.one.eid2path(self.eid)
self.trials_collection = kwargs.pop('task_collection', 'alf')
super(BehaviourPlots, self).__init__(self.session_path, self.eid, one=self.one,
**kwargs)
# Output directory should mirror trials collection, sans 'alf' part
self.output_directory = self.session_path.joinpath(
'snapshot', 'behaviour', self.trials_collection.removeprefix('alf').strip('/'))
self.output_directory.mkdir(exist_ok=True, parents=True)
def _run(self):
output_files = []
trials = alfio.load_object(self.session_path.joinpath(self.trials_collection), 'trials')
if self.one:
title = self.one.path2ref(self.session_path, as_dict=False)
else:
title = '_'.join(list(self.session_path.parts[-3:]))
fig, ax = training.plot_psychometric(trials, title=title, figsize=(8, 6))
set_axis_label_size(ax)
save_path = Path(self.output_directory).joinpath("psychometric_curve.png")
output_files.append(save_path)
fig.savefig(save_path)
plt.close(fig)
fig, ax = training.plot_reaction_time(trials, title=title, figsize=(8, 6))
set_axis_label_size(ax)
save_path = Path(self.output_directory).joinpath("chronometric_curve.png")
output_files.append(save_path)
fig.savefig(save_path)
plt.close(fig)
fig, ax = training.plot_reaction_time_over_trials(trials, title=title, figsize=(8, 6))
set_axis_label_size(ax)
save_path = Path(self.output_directory).joinpath("reaction_time_with_trials.png")
output_files.append(save_path)
fig.savefig(save_path)
plt.close(fig)
return output_files
# TODO put into histology and alignment pipeline
class HistologySlices(ReportSnapshotProbe):
"""Plots coronal and sagittal slice showing electrode locations."""
def _run(self):
assert self.pid
assert self.brain_atlas
output_files = []
self.histology_status = self.get_histology_status()
electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
if self.hist_lookup[self.histology_status] > 0:
fig = plt.figure(figsize=(12, 9))
gs = fig.add_gridspec(2, 2, width_ratios=[.95, .05])
ax1 = fig.add_subplot(gs[0, 0])
self.brain_atlas.plot_tilted_slice(electrodes['mlapdv'], 1, ax=ax1)
ax1.scatter(electrodes['mlapdv'][:, 0] * 1e6, electrodes['mlapdv'][:, 2] * 1e6, s=8, c='r')
ax1.set_title(f"{self.pid_label}")
ax2 = fig.add_subplot(gs[1, 0])
self.brain_atlas.plot_tilted_slice(electrodes['mlapdv'], 0, ax=ax2)
ax2.scatter(electrodes['mlapdv'][:, 1] * 1e6, electrodes['mlapdv'][:, 2] * 1e6, s=8, c='r')
ax3 = fig.add_subplot(gs[:, 1])
plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=ax3,
title=self.histology_status)
save_path = Path(self.output_directory).joinpath("histology_slices.png")
output_files.append(save_path)
fig.savefig(save_path)
plt.close(fig)
return output_files
def get_probe_signature(self):
input_signature = [('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
output_signature = [('histology_slices.png', f'snapshot/{self.pname}', True)]
self.signature = {'input_files': input_signature, 'output_files': output_signature}
class LfpPlots(ReportSnapshotProbe):
"""
    Plots the LFP power spectrum and LFP RMS along the probe.
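
    Example (illustrative; assumes a configured ONE instance and a probe insertion UUID `pid`)::

        task = LfpPlots(pid, one=one)
        task.run()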
"""
def _run(self):
assert self.pid
output_files = []
if self.location != 'server':
self.histology_status = self.get_histology_status()
electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
# lfp spectrum
fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
lfp = alfio.load_object(self.session_path.joinpath(f'raw_ephys_data/{self.pname}'), 'ephysSpectralDensityLF',
namespace='iblqc')
_, _, _ = image_lfp_spectrum_plot(lfp.power, lfp.freqs, clim=[-65, -95], fig_kwargs={'figsize': (8, 6)}, ax=axs[0],
display=True, title=f"{self.pid_label}")
set_axis_label_size(axs[0], cmap=True)
if self.histology_status:
plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=axs[1],
title=self.histology_status)
set_axis_label_size(axs[1])
else:
remove_axis_outline(axs[1])
save_path = Path(self.output_directory).joinpath("lfp_spectrum.png")
output_files.append(save_path)
fig.savefig(save_path)
plt.close(fig)
# lfp rms
# TODO need to figure out the clim range
fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
lfp = alfio.load_object(self.session_path.joinpath(f'raw_ephys_data/{self.pname}'), 'ephysTimeRmsLF', namespace='iblqc')
_, _, _ = image_rms_plot(lfp.rms, lfp.timestamps, median_subtract=False, band='LFP', clim=[-35, -45], ax=axs[0],
cmap='inferno', fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
set_axis_label_size(axs[0], cmap=True)
if self.histology_status:
plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=axs[1],
title=self.histology_status)
set_axis_label_size(axs[1])
else:
remove_axis_outline(axs[1])
save_path = Path(self.output_directory).joinpath("lfp_rms.png")
output_files.append(save_path)
fig.savefig(save_path)
plt.close(fig)
return output_files
def get_probe_signature(self):
input_signature = [('_iblqc_ephysTimeRmsLF.rms.npy', f'raw_ephys_data/{self.pname}', True),
('_iblqc_ephysTimeRmsLF.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
('_iblqc_ephysSpectralDensityLF.freqs.npy', f'raw_ephys_data/{self.pname}', True),
('_iblqc_ephysSpectralDensityLF.power.npy', f'raw_ephys_data/{self.pname}', True),
('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
output_signature = [('lfp_spectrum.png', f'snapshot/{self.pname}', True),
('lfp_rms.png', f'snapshot/{self.pname}', True)]
self.signature = {'input_files': input_signature, 'output_files': output_signature}
class ApPlots(ReportSnapshotProbe):
"""
    Plots the AP band RMS along the probe.
"""
def _run(self):
assert self.pid
output_files = []
if self.location != 'server':
self.histology_status = self.get_histology_status()
electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
# TODO need to figure out the clim range
fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
ap = alfio.load_object(self.session_path.joinpath(f'raw_ephys_data/{self.pname}'), 'ephysTimeRmsAP', namespace='iblqc')
_, _, _ = image_rms_plot(ap.rms, ap.timestamps, median_subtract=False, band='AP', ax=axs[0],
fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
set_axis_label_size(axs[0], cmap=True)
if self.histology_status:
plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=axs[1],
title=self.histology_status)
set_axis_label_size(axs[1])
else:
remove_axis_outline(axs[1])
save_path = Path(self.output_directory).joinpath("ap_rms.png")
output_files.append(save_path)
fig.savefig(save_path)
plt.close(fig)
return output_files
def get_probe_signature(self):
input_signature = [('_iblqc_ephysTimeRmsAP.rms.npy', f'raw_ephys_data/{self.pname}', True),
('_iblqc_ephysTimeRmsAP.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
output_signature = [('ap_rms.png', f'snapshot/{self.pname}', True)]
self.signature = {'input_files': input_signature, 'output_files': output_signature}
class SpikeSorting(ReportSnapshotProbe):
"""
    Plots a spike sorting raster (driftmap) for each spike sorting collection of the probe insertion
:param session_path: session path
:param probe_id: str, UUID of the probe insertion for which to create the plot
:param **kwargs: keyword arguments passed to tasks.Task
"""
def _run(self, collection=None):
"""runs for initiated PID, streams data, destripe and check bad channels"""
def plot_driftmap(self, spikes, clusters, channels, collection):
fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
driftmap(spikes.times, spikes.depths, t_bin=0.007, d_bin=10, vmax=0.5, ax=axs[0])
title_str = f"{self.pid_label}, {collection}, {self.pid} \n " \
f"{spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters"
ylim = (0, np.max(channels['axial_um']))
axs[0].set(ylim=ylim, title=title_str)
run_label = str(Path(collection).relative_to(f'alf/{self.pname}'))
run_label = "ks2matlab" if run_label == '.' else run_label
outfile = self.output_directory.joinpath(f"spike_sorting_raster_{run_label}.png")
set_axis_label_size(axs[0])
if self.histology_status:
plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
brain_regions=self.brain_regions, display=True, ax=axs[1], title=self.histology_status)
axs[1].set(ylim=ylim)
set_axis_label_size(axs[1])
else:
remove_axis_outline(axs[1])
fig.savefig(outfile)
plt.close(fig)
return outfile, fig, axs
output_files = []
if self.location == 'server':
assert collection
spikes = alfio.load_object(self.session_path.joinpath(collection), 'spikes')
clusters = alfio.load_object(self.session_path.joinpath(collection), 'clusters')
channels = alfio.load_object(self.session_path.joinpath(collection), 'channels')
channels['axial_um'] = channels['localCoordinates'][:, 1]
out, fig, axs = plot_driftmap(self, spikes, clusters, channels, collection)
output_files.append(out)
else:
self.histology_status = self.get_histology_status()
all_here, output_files = self.assert_expected(self.output_files, silent=True)
spike_sorting_runs = self.one.list_datasets(self.eid, filename='spikes.times.npy', collection=f'alf/{self.pname}*')
if all_here and len(output_files) == len(spike_sorting_runs):
return output_files
logger.info(self.output_directory)
for run in spike_sorting_runs:
collection = str(Path(run).parent.as_posix())
spikes, clusters, channels = load_spike_sorting_fast(
eid=self.eid, probe=self.pname, one=self.one, nested=False, collection=collection,
dataset_types=['spikes.depths'], brain_regions=self.brain_regions)
if 'atlas_id' not in channels.keys():
channels = self.get_channels('channels', collection)
out, fig, axs = plot_driftmap(self, spikes, clusters, channels, collection)
output_files.append(out)
return output_files
def get_probe_signature(self):
input_signature = [('spikes.times.npy', f'alf/{self.pname}*', True),
('spikes.amps.npy', f'alf/{self.pname}*', True),
('spikes.depths.npy', f'alf/{self.pname}*', True),
('clusters.depths.npy', f'alf/{self.pname}*', True),
('channels.localCoordinates.npy', f'alf/{self.pname}*', False),
('channels.mlapdv.npy', f'alf/{self.pname}*', False),
('channels.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}*', False)]
output_signature = [('spike_sorting_raster*.png', f'snapshot/{self.pname}', True)]
self.signature = {'input_files': input_signature, 'output_files': output_signature}
def get_signatures(self, **kwargs):
files_spikes = Path(self.session_path).joinpath('alf').rglob('spikes.times.npy')
folder_probes = [f.parent for f in files_spikes]
full_input_files = []
for sig in self.signature['input_files']:
for folder in folder_probes:
full_input_files.append((sig[0], str(folder.relative_to(self.session_path)), sig[2]))
if len(full_input_files) != 0:
self.input_files = full_input_files
else:
self.input_files = self.signature['input_files']
self.output_files = self.signature['output_files']
self.input_files = [ExpectedDataset.input(*i) for i in self.input_files]
self.output_files = [ExpectedDataset.output(*i) for i in self.output_files]
class BadChannelsAp(ReportSnapshotProbe):
"""
Plots raw electrophysiology AP band
    task = BadChannelsAp(pid, one=one)
:param session_path: session path
:param probe_id: str, UUID of the probe insertion for which to create the plot
:param **kwargs: keyword arguments passed to tasks.Task
"""
def get_probe_signature(self):
pname = self.pname
input_signature = [('*ap.meta', f'raw_ephys_data/{pname}', True),
('*ap.ch', f'raw_ephys_data/{pname}', False)]
        output_signature = [('raw_ephys_bad_channels.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_highpass.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_destripe.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_difference.png', f'snapshot/{pname}', True),
                            ]
self.signature = {'input_files': input_signature, 'output_files': output_signature}
def _run(self):
"""runs for initiated PID, streams data, destripe and check bad channels"""
assert self.pid
self.eqcs = []
T0 = 60 * 30
SNAPSHOT_LABEL = "raw_ephys_bad_channels"
output_files = list(self.output_directory.glob(f'{SNAPSHOT_LABEL}*'))
if len(output_files) == 4:
return output_files
self.output_directory.mkdir(exist_ok=True, parents=True)
if self.location != 'server':
self.histology_status = self.get_histology_status()
electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
if 'atlas_id' in electrodes.keys():
electrodes['ibr'] = ismember(electrodes['atlas_id'], self.brain_regions.id)[1]
electrodes['acronym'] = self.brain_regions.acronym[electrodes['ibr']]
electrodes['name'] = self.brain_regions.name[electrodes['ibr']]
electrodes['title'] = self.histology_status
else:
electrodes = None
nsecs = 1
sr = Streamer(pid=self.pid, one=self.one, remove_cached=False, typ='ap')
s0 = T0 * sr.fs
tsel = slice(int(s0), int(s0) + int(nsecs * sr.fs))
# Important: remove sync channel from raw data, and transpose
raw = sr[tsel, :-sr.nsync].T
else:
electrodes = None
ap_file = next(self.session_path.joinpath('raw_ephys_data', self.pname).glob('*ap.*bin'), None)
if ap_file is not None:
sr = spikeglx.Reader(ap_file)
# If T0 is greater than recording length, take 500 sec before end
if sr.rl < T0:
T0 = int(sr.rl - 500)
raw = sr[int((sr.fs * T0)):int((sr.fs * (T0 + 1))), :-sr.nsync].T
else:
return []
if sr.meta.get('NP2.4_shank', None) is not None:
h = neuropixel.trace_header(sr.major_version, nshank=4)
h = neuropixel.split_trace_header(h, shank=int(sr.meta.get('NP2.4_shank')))
else:
h = neuropixel.trace_header(sr.major_version, nshank=np.unique(sr.geometry['shank']).size)
channel_labels, channel_features = voltage.detect_bad_channels(raw, sr.fs)
_, eqcs, output_files = ephys_bad_channels(
raw=raw, fs=sr.fs, channel_labels=channel_labels, channel_features=channel_features, h=h, channels=electrodes,
title=SNAPSHOT_LABEL, destripe=True, save_dir=self.output_directory, br=self.brain_regions, pid_info=self.pid_label)
self.eqcs = eqcs
return output_files
def ephys_bad_channels(raw, fs, channel_labels, channel_features, h=None, channels=None, title="ephys_bad_channels",
save_dir=None, destripe=False, eqcs=None, br=None, pid_info=None, plot_backend='matplotlib'):
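    """
    Display raw voltage snippets with detected bad channels highlighted, alongside the per-channel
    detection features (RMS, high-frequency PSD, neighbour correlation and a full PSD image).

    :param raw: [nc, ns] raw voltage snippet (V), channels x samples
    :param fs: sampling frequency (Hz); >= 2600 Hz is treated as AP band, otherwise as LF band
    :param channel_labels: per-channel labels from voltage.detect_bad_channels (0: ok, 1: dead, 2: noisy, 3: outside of the brain)
    :param channel_features: per-channel features dictionary from voltage.detect_bad_channels
    :param h: probe geometry header (neuropixel.trace_header), forwarded to voltage.destripe
    :param channels: optional dictionary of channel information ('atlas_id', optional 'title') used to
     display brain regions next to the voltage panels
    :param title: label used for the features figure and the output file names
    :param save_dir: if given, figures are saved in this directory and their paths returned
    :param destripe: if True, also display the destriped data and its difference with the high-passed data
    :param eqcs: optional list of existing displays to append to
    :param br: brain regions object used for the brain region side-plots
    :param pid_info: probe label prepended to the voltage panel titles
    :param plot_backend: 'matplotlib' (default) or any other value to use the viewephys GUI backend
    :return: (fig, eqcs, output_files) if save_dir is given, else (fig, eqcs)
    """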
nc, ns = raw.shape
rl = ns / fs
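    # converts a gain in dB into symmetric voltage limits, used as vmin/vmax for the voltage density displays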
def gain2level(gain):
return 10 ** (gain / 20) * 4 * np.array([-1, 1])
if fs >= 2600: # AP band
ylim_rms = [0, 100]
ylim_psd_hf = [0, 0.1]
eqc_xrange = [450, 500]
butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
eqc_gain = - 90
eqc_levels = gain2level(eqc_gain)
else:
# we are working with the LFP
ylim_rms = [0, 1000]
ylim_psd_hf = [0, 1]
eqc_xrange = [450, 950]
butter_kwargs = {'N': 3, 'Wn': np.array([2, 125]) / fs * 2, 'btype': 'bandpass'}
eqc_gain = - 78
eqc_levels = gain2level(eqc_gain)
inoisy = np.where(channel_labels == 2)[0]
idead = np.where(channel_labels == 1)[0]
ioutside = np.where(channel_labels == 3)[0]
# display voltage traces
eqcs = [] if eqcs is None else eqcs
# butterworth, for display only
sos = scipy.signal.butter(**butter_kwargs, output='sos')
butt = scipy.signal.sosfiltfilt(sos, raw)
if plot_backend == 'matplotlib':
_, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
eqcs.append(Density(butt, fs=fs, taxis=1, ax=axs[0], title='highpass', vmin=eqc_levels[0], vmax=eqc_levels[1]))
if destripe:
dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
_, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
eqcs.append(Density(
dest, fs=fs, taxis=1, ax=axs[0], title='destripe', vmin=eqc_levels[0], vmax=eqc_levels[1]))
_, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
eqcs.append(Density((butt - dest), fs=fs, taxis=1, ax=axs[0], title='difference', vmin=eqc_levels[0],
vmax=eqc_levels[1]))
for eqc in eqcs:
y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
eqc.ax.scatter(x.flatten(), y.flatten(), c='goldenrod', s=4)
y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
eqc.ax.scatter(x.flatten(), y.flatten(), c='r', s=4)
y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
eqc.ax.scatter(x.flatten(), y.flatten(), c='b', s=4)
eqc.ax.set_xlim(*eqc_xrange)
eqc.ax.set_ylim(0, nc)
eqc.ax.set_ylabel('Channel index')
eqc.ax.set_title(f'{pid_info}_{eqc.title}')
set_axis_label_size(eqc.ax)
ax = eqc.figure.axes[1]
if channels is not None:
chn_title = channels.get('title', None)
plot_brain_regions(channels['atlas_id'], brain_regions=br, display=True, ax=ax,
title=chn_title)
set_axis_label_size(ax)
else:
remove_axis_outline(ax)
else:
from viewspikes.gui import viewephys # noqa
eqcs.append(viewephys(butt, fs=fs, channels=channels, title='highpass', br=br))
if destripe:
dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
eqcs.append(viewephys(dest, fs=fs, channels=channels, title='destripe', br=br))
eqcs.append(viewephys((butt - dest), fs=fs, channels=channels, title='difference', br=br))
for eqc in eqcs:
y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(164, 142, 35), label='outside')
y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(255, 0, 0), label='noisy')
y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(0, 0, 255), label='dead')
eqcs[0].ctrl.set_gain(eqc_gain)
eqcs[0].resize(1960, 1200)
eqcs[0].viewBox_seismic.setXRange(*eqc_xrange)
eqcs[0].viewBox_seismic.setYRange(0, nc)
eqcs[0].ctrl.propagate()
# display features
fig, axs = plt.subplots(2, 2, sharex=True, figsize=[16, 9], tight_layout=True)
fig.suptitle(title)
axs[0, 0].plot(channel_features['rms_raw'] * 1e6)
axs[0, 0].set(title='rms', xlabel='channel number', ylabel='rms (uV)', ylim=ylim_rms)
axs[1, 0].plot(channel_features['psd_hf'])
axs[1, 0].plot(inoisy, np.minimum(channel_features['psd_hf'][inoisy], 0.0999), 'xr')
axs[1, 0].set(title='PSD above 80% Nyquist', xlabel='channel number', ylabel='PSD (uV ** 2 / Hz)', ylim=ylim_psd_hf)
    axs[1, 0].legend(['psd', 'noisy'])
axs[0, 1].plot(channel_features['xcor_hf'])
axs[0, 1].plot(channel_features['xcor_lf'])
axs[0, 1].plot(idead, channel_features['xcor_hf'][idead], 'xb')
axs[0, 1].plot(ioutside, channel_features['xcor_lf'][ioutside], 'xy')
axs[0, 1].set(title='Similarity', xlabel='channel number', ylabel='', ylim=[-1.5, 0.5])
axs[0, 1].legend(['detrend', 'trend', 'dead', 'outside'])
fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs) # units; uV ** 2 / Hz
axs[1, 1].imshow(20 * np.log10(psd).T, extent=[0, nc - 1, fscale[0], fscale[-1]], origin='lower', aspect='auto',
vmin=-50, vmax=-20)
axs[1, 1].set(title='PSD', xlabel='channel number', ylabel="Frequency (Hz)")
axs[1, 1].plot(idead, idead * 0 + fs / 4, 'xb')
axs[1, 1].plot(inoisy, inoisy * 0 + fs / 4, 'xr')
axs[1, 1].plot(ioutside, ioutside * 0 + fs / 4, 'xy')
if save_dir is not None:
output_files = [Path(save_dir).joinpath(f"{title}.png")]
fig.savefig(output_files[0])
for eqc in eqcs:
if plot_backend == 'matplotlib':
output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.title}.png"))
eqc.figure.savefig(str(output_files[-1]))
else:
output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.windowTitle()}.png"))
eqc.grab().save(str(output_files[-1]))
return fig, eqcs, output_files
else:
return fig, eqcs
def raw_destripe(raw, fs, t0, i_plt, n_plt,
fig=None, axs=None, savedir=None, detect_badch=True,
SAMPLE_SKIP=200, DISPLAY_TIME=0.05, N_CHAN=384,
MIN_X=-0.00011, MAX_X=0.00011):
'''
    :param raw: raw ephys data, Nc x Ns (channels x samples)
    :param fs: sampling freq (Hz) of the raw ephys data
    :param t0: time (s) of ephys sample beginning from session start
    :param i_plt: index of the subplot in which to display the image (starts at 0, has to be < n_plt)
    :param n_plt: total number of subplots on the figure
    :param fig: figure handle
    :param axs: axis handle
    :param savedir: filename, including directory, to save figure to
    :param detect_badch: boolean, whether to detect and mark bad channels
    :param SAMPLE_SKIP: number of samples to skip at the start of the ephys snippet for display
    :param DISPLAY_TIME: time (s) to display
    :param N_CHAN: number of expected channels on the probe
    :param MIN_X: min voltage for color range
    :param MAX_X: max voltage for color range
:return: fig, axs
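
    Example (illustrative sketch; assumes `ap_file` points to an AP band spikeglx file)::

        sr = spikeglx.Reader(ap_file)
        fig, axs = None, None
        for i, t0 in enumerate([100, 500, 900]):  # three 100 ms snippets (seconds from session start)
            sl = slice(int(t0 * sr.fs), int((t0 + 0.1) * sr.fs))
            raw = sr[sl, :-sr.nsync].T
            fig, axs = raw_destripe(raw, fs=sr.fs, t0=t0, i_plt=i, n_plt=3, fig=fig, axs=axs)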
'''
# Import
from ibldsp import voltage
from ibllib.plots import Density
# Init fig
if fig is None or axs is None:
        fig, axs = plt.subplots(nrows=1, ncols=n_plt, figsize=(14, 5), gridspec_kw={'width_ratios': [4] * n_plt})
if i_plt > len(axs) - 1: # Error
raise ValueError(f'The given increment of subplot ({i_plt+1}) '
f'is larger than the total number of subplots ({len(axs)})')
[nc, ns] = raw.shape
if nc == N_CHAN:
destripe = voltage.destripe(raw, fs=fs)
X = destripe[:, :int(DISPLAY_TIME * fs)].T
Xs = X[SAMPLE_SKIP:].T # Remove artifact at beginning
Tplot = Xs.shape[1] / fs
# PLOT RAW DATA
d = Density(-Xs, fs=fs, taxis=1, ax=axs[i_plt], vmin=MIN_X, vmax=MAX_X) # noqa
axs[i_plt].set_ylabel('')
axs[i_plt].set_xlim((0, Tplot * 1e3))
axs[i_plt].set_ylim((0, nc))
# Init title
title_plt = f't0 = {int(t0 / 60)} min'
if detect_badch:
            # Detect bad channels so they can be marked on the plot
labels, xfeats = voltage.detect_bad_channels(raw, fs)
idx_badchan = np.where(labels != 0)[0]
# Plot bad channels on raw data
x, y = np.meshgrid(idx_badchan, np.linspace(0, Tplot * 1e3, 20))
axs[i_plt].plot(y.flatten(), x.flatten(), '.k', markersize=1)
# Append title
title_plt += f', n={len(idx_badchan)} bad ch'
# Set title
axs[i_plt].title.set_text(title_plt)
else:
axs[i_plt].title.set_text(f'CANNOT DESTRIPE, N CHAN = {nc}')
# Amend some axis style
if i_plt > 0:
axs[i_plt].set_yticklabels('')
# Fig layout
fig.tight_layout()
if savedir is not None:
fig.savefig(fname=savedir)
return fig, axs
def dlc_qc_plot(session_path, one=None, device_collection='raw_video_data',
cameras=('left', 'right', 'body'), trials_collection='alf'):
"""
Creates DLC QC plot.
Data is searched first locally, then on Alyx. Panels that lack required data are skipped.
Required data to create all panels
'raw_video_data/_iblrig_bodyCamera.raw.mp4',
'raw_video_data/_iblrig_leftCamera.raw.mp4',
'raw_video_data/_iblrig_rightCamera.raw.mp4',
'alf/_ibl_bodyCamera.dlc.pqt',
'alf/_ibl_leftCamera.dlc.pqt',
'alf/_ibl_rightCamera.dlc.pqt',
'alf/_ibl_bodyCamera.times.npy',
'alf/_ibl_leftCamera.times.npy',
'alf/_ibl_rightCamera.times.npy',
'alf/_ibl_leftCamera.features.pqt',
'alf/_ibl_rightCamera.features.pqt',
'alf/rightROIMotionEnergy.position.npy',
'alf/leftROIMotionEnergy.position.npy',
'alf/bodyROIMotionEnergy.position.npy',
'alf/_ibl_trials.choice.npy',
'alf/_ibl_trials.feedbackType.npy',
'alf/_ibl_trials.feedback_times.npy',
'alf/_ibl_trials.stimOn_times.npy',
'alf/_ibl_wheel.position.npy',
'alf/_ibl_wheel.timestamps.npy',
'alf/licks.times.npy',
    :param session_path: Path to session data on disk
    :param one: ONE instance, if None is given, default ONE is instantiated
    :param device_collection: collection containing the raw video data (default: 'raw_video_data')
    :param cameras: camera names to include in the plot (default: ('left', 'right', 'body'))
    :param trials_collection: collection containing the trials and wheel data (default: 'alf')
    :return: matplotlib figure
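
    Example (illustrative; `eid` is a placeholder experiment UUID and assumes the datasets
    listed above are available)::

        from one.api import ONE
        one = ONE()
        session_path = one.eid2path(eid)
        fig = dlc_qc_plot(session_path, one=one)
        fig.savefig('dlc_qc_plot.png')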
"""
one = one or ONE()
# hack for running on cortexlab local server
if one.alyx.base_url == 'https://alyx.cortexlab.net':
one = ONE(base_url='https://alyx.internationalbrainlab.org')
data = {}
session_path = Path(session_path)
# Load data for each camera
for cam in cameras:
        # Load a single frame for each video: check if the video is available locally first
video_path = session_path.joinpath(device_collection, f'_iblrig_{cam}Camera.raw.mp4')
if video_path.exists():
data[f'{cam}_frame'] = get_video_frame(video_path, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0]
# If not, try to stream a frame (try three times)
else:
try:
video_url = url_from_eid(one.path2eid(session_path), one=one)[cam]
for tries in range(3):
try:
data[f'{cam}_frame'] = get_video_frame(video_url, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0]
break
except Exception:
if tries < 2:
tries += 1
logger.info(f"Streaming {cam} video failed, retrying x{tries}")
time.sleep(30)
else:
logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.")
data[f'{cam}_frame'] = None
except KeyError:
logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.")
data[f'{cam}_frame'] = None
# Other camera associated data
for feat in ['dlc', 'times', 'features', 'ROIMotionEnergy']:
# Check locally first, then try to load from alyx, if nothing works, set to None
if feat == 'features' and cam == 'body': # this doesn't exist for body cam
continue
local_file = list(session_path.joinpath('alf').glob(f'*{cam}Camera.{feat}*'))
if len(local_file) > 0:
data[f'{cam}_{feat}'] = alfio.load_file_content(local_file[0])
else:
alyx_ds = [ds for ds in one.list_datasets(one.path2eid(session_path)) if f'{cam}Camera.{feat}' in ds]
if len(alyx_ds) > 0:
data[f'{cam}_{feat}'] = one.load_dataset(one.path2eid(session_path), alyx_ds[0])
else:
logger.warning(f"Could not load _ibl_{cam}Camera.{feat} some plots have to be skipped.")
data[f'{cam}_{feat}'] = None
# Sometimes there is a file but the object is empty, set to None
if data[f'{cam}_{feat}'] is not None and len(data[f'{cam}_{feat}']) == 0:
logger.warning(f"Object loaded from _ibl_{cam}Camera.{feat} is empty, some plots have to be skipped.")
data[f'{cam}_{feat}'] = None
# If we have no frame and/or no DLC and/or no times for all cams, raise an error, something is really wrong
assert any(data[f'{cam}_frame'] is not None for cam in cameras), "No camera data could be loaded, aborting."
assert any(data[f'{cam}_dlc'] is not None for cam in cameras), "No DLC data could be loaded, aborting."
assert any(data[f'{cam}_times'] is not None for cam in cameras), "No camera times data could be loaded, aborting."
# Load session level data
for alf_object, collection in zip(['trials', 'wheel', 'licks'], [trials_collection, trials_collection, 'alf']):
try:
data[f'{alf_object}'] = alfio.load_object(session_path.joinpath(collection), alf_object) # load locally
continue
except ALFObjectNotFound:
pass
try:
# then try from alyx
data[f'{alf_object}'] = one.load_object(one.path2eid(session_path), alf_object, collection=collection)
except ALFObjectNotFound:
logger.warning(f"Could not load {alf_object} object, some plots have to be skipped.")
data[f'{alf_object}'] = None
# Simplify and clean up trials data
if data['trials']:
data['trials'] = pd.DataFrame(
{k: data['trials'][k] for k in ['stimOn_times', 'feedback_times', 'choice', 'feedbackType']})
# Discard nan events and too long trials
data['trials'] = data['trials'].dropna()
data['trials'] = data['trials'].drop(
data['trials'][(data['trials']['feedback_times'] - data['trials']['stimOn_times']) > 10].index)
    # Make a list of panels; if inputs are missing, add a placeholder text to display instead
panels = []
# Panel A, B, C: Trace on frame
for cam in cameras:
if data[f'{cam}_frame'] is not None and data[f'{cam}_dlc'] is not None:
panels.append((plot_trace_on_frame,
{'frame': data[f'{cam}_frame'], 'dlc_df': data[f'{cam}_dlc'], 'cam': cam}))
else:
panels.append((None, f'Data missing\n{cam.capitalize()} cam trace on frame'))
# If trials data is not there, we cannot plot any of the trial average plots, skip all remaining panels
if data['trials'] is None:
panels.extend([(None, 'No trial data,\ncannot compute trial avgs')] * 7)
else:
# Panel D: Motion energy
camera_dict = {}
        for cam in cameras:  # Keep only cameras for which we have both motion energy and camera times
d = {'motion_energy': data.get(f'{cam}_ROIMotionEnergy'), 'times': data.get(f'{cam}_times')}
if not any(x is None for x in d.values()):
camera_dict[cam] = d
if len(camera_dict) > 0:
panels.append((plot_motion_energy_hist, {'camera_dict': camera_dict, 'trials_df': data['trials']}))
else:
panels.append((None, 'Data missing\nMotion energy'))
# Panel E: Wheel position
if data['wheel']:
panels.append((plot_wheel_position, {'wheel_position': data['wheel'].position,
'wheel_time': data['wheel'].timestamps,
'trials_df': data['trials']}))
else:
panels.append((None, 'Data missing\nWheel position'))
# Panel F, G: Paw speed and nose speed
        # Use the left cam if all its required data is there, otherwise fall back to the right cam
for cam in ['left', 'right']:
fail = False
if (data[f'{cam}_dlc'] is not None and data[f'{cam}_times'] is not None
and len(data[f'{cam}_times']) >= len(data[f'{cam}_dlc'])):
break
fail = True
if not fail:
paw = 'r' if cam == 'left' else 'l'
panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'],
'trials_df': data['trials'], 'feature': f'paw_{paw}', 'cam': cam}))
panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'],
'trials_df': data['trials'], 'feature': 'nose_tip', 'legend': False,
'cam': cam}))
else:
panels.extend([(None, 'Data missing or corrupt\nSpeed histograms')] * 2)
# Panel H and I: Lick plots
if data['licks'] and data['licks'].times.shape[0] > 0:
panels.append((plot_lick_hist, {'lick_times': data['licks'].times, 'trials_df': data['trials']}))
panels.append((plot_lick_raster, {'lick_times': data['licks'].times, 'trials_df': data['trials']}))
else:
panels.extend([(None, 'Data missing\nLicks plots') for i in range(2)])
# Panel J: pupil plot
        # Use the left cam if all its required data is there, otherwise fall back to the right cam
for cam in ['left', 'right']:
fail = False
if (data.get(f'{cam}_times') is not None and data.get(f'{cam}_features') is not None
and len(data[f'{cam}_times']) >= len(data[f'{cam}_features'])
and not np.all(np.isnan(data[f'{cam}_features'].pupilDiameter_smooth))):
break
fail = True
if not fail:
panels.append((plot_pupil_diameter_hist,
{'pupil_diameter': data[f'{cam}_features'].pupilDiameter_smooth,
'cam_times': data[f'{cam}_times'], 'trials_df': data['trials'], 'cam': cam}))
else:
panels.append((None, 'Data missing or corrupt\nPupil diameter'))
# Plotting
plt.rcParams.update({'font.size': 10})
fig = plt.figure(figsize=(17, 10))
for i, panel in enumerate(panels):
ax = plt.subplot(2, 5, i + 1)
ax.text(-0.1, 1.15, ascii_uppercase[i], transform=ax.transAxes, fontsize=16, fontweight='bold')
        # Check if there was an issue with the inputs; if yes, print the respective text
if panel[0] is None:
ax.text(.5, .5, panel[1], color='r', fontweight='bold', fontsize=12, horizontalalignment='center',
verticalalignment='center', transform=ax.transAxes)
plt.axis('off')
else:
try:
panel[0](**panel[1])
except Exception:
logger.error(f'Error in {panel[0].__name__}\n' + traceback.format_exc())
ax.text(.5, .5, f'Error while plotting\n{panel[0].__name__}', color='r', fontweight='bold',
fontsize=12, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
plt.axis('off')
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
return fig