Source code for ibllib.qc.task_qc_viewer.ViewEphysQC

"""An interactive PyQT QC data frame."""
import logging

from PyQt5 import QtCore, QtWidgets
from matplotlib.figure import Figure
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT
import pandas as pd
import numpy as np

from ibllib.misc import qt

_logger = logging.getLogger(__name__)


class DataFrameModel(QtCore.QAbstractTableModel):
    """
    Qt table model that exposes a pandas DataFrame to item views.

    Rows and columns of the view map positionally onto the rows and columns of the
    underlying frame. Besides the display text, two custom roles give access to the
    raw cell value (`ValueRole`) and to the dtype of the cell's column (`DtypeRole`).
    """

    # Custom item-data roles, offset from Qt.UserRole
    DtypeRole = QtCore.Qt.UserRole + 1000
    ValueRole = QtCore.Qt.UserRole + 1001

    def __init__(self, df=None, parent=None):
        """
        :param df: the data frame to display; a new empty frame is used when omitted
        :param parent: optional parent object, forwarded to QAbstractTableModel
        """
        super().__init__(parent)
        # Build the empty frame per instance: a `pd.DataFrame()` default argument is
        # evaluated once and would be shared (and sorted in place) by every model.
        # NOTE: unlike setDataFrame, a frame passed here is stored by reference, so
        # sort() reorders and re-indexes the caller's object.
        self._dataframe = pd.DataFrame() if df is None else df

    def setDataFrame(self, dataframe):
        """
        Replace the displayed data with a copy of the given frame and reset the model.

        :param dataframe: the new data frame to display
        """
        self.beginResetModel()
        self._dataframe = dataframe.copy()
        self.endResetModel()

    def dataFrame(self):
        """Return the data frame currently held by the model."""
        return self._dataframe

    # Rebinds the `dataFrame` name to a Qt property built from the getter above
    dataFrame = QtCore.pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame)

    @QtCore.pyqtSlot(int, QtCore.Qt.Orientation, result=str)
    def headerData(self, section: int, orientation: QtCore.Qt.Orientation,
                   role: int = QtCore.Qt.DisplayRole):
        """
        Return the header label for the given section.

        :param section: the column number (horizontal) or row number (vertical)
        :param orientation: the header orientation
        :param role: the item-data role; only Qt.DisplayRole yields a label
        :return: the column name or row index label as a string, an empty QVariant
         for any other role
        """
        if role != QtCore.Qt.DisplayRole:
            return QtCore.QVariant()
        if orientation == QtCore.Qt.Horizontal:
            # str() so that non-string column labels are handled like row labels
            return str(self._dataframe.columns[section])
        return str(self._dataframe.index[section])

    def rowCount(self, parent=QtCore.QModelIndex()):
        """
        Return the number of rows in the frame.

        :param parent: parent index; a valid parent has no children in a flat table
        :return: the number of rows, or 0 for a valid parent
        """
        if parent.isValid():
            return 0
        return len(self._dataframe.index)

    def columnCount(self, parent=QtCore.QModelIndex()):
        """
        Return the number of columns in the frame.

        :param parent: parent index; a valid parent has no children in a flat table
        :return: the number of columns, or 0 for a valid parent
        """
        if parent.isValid():
            return 0
        return self._dataframe.columns.size

    def data(self, index, role=QtCore.Qt.DisplayRole):
        """
        Return the data stored at the given model index for the given role.

        :param index: the model index of the requested cell
        :param role: Qt.DisplayRole (value as string), ValueRole (raw value) or
         DtypeRole (dtype of the cell's column)
        :return: the requested data, an empty QVariant for an invalid or out of range
         index or an unsupported role
        """
        if not index.isValid():
            return QtCore.QVariant()
        row, column = index.row(), index.column()
        if not (0 <= row < self.rowCount() and 0 <= column < self.columnCount()):
            return QtCore.QVariant()
        # The view's row/column numbers are positions, not labels: look the cell up
        # positionally so frames with a non-default index or duplicate column names
        # are handled, and only touch the frame for the roles that need it.
        if role == QtCore.Qt.DisplayRole:
            return str(self._dataframe.iat[row, column])
        if role == DataFrameModel.ValueRole:
            return self._dataframe.iat[row, column]
        if role == DataFrameModel.DtypeRole:
            return self._dataframe.dtypes.iloc[column]
        return QtCore.QVariant()

    def roleNames(self):
        """Return the mapping of the supported role numbers to their names."""
        roles = {
            QtCore.Qt.DisplayRole: b'display',
            DataFrameModel.DtypeRole: b'dtype',
            DataFrameModel.ValueRole: b'value'
        }
        return roles

    def sort(self, col, order):
        """
        Sort table by given column number.

        The frame is sorted in place and its index is reset to 0..n-1.

        :param col: the column number selected (between 0 and self._dataframe.columns.size)
        :param order: the order to be sorted, 0 (Qt.AscendingOrder) is ascending;
         1 (Qt.DescendingOrder), descending
        :return:
        """
        # Nothing to sort for an empty model or when no column is selected (-1)
        if not 0 <= col < self._dataframe.columns.size:
            return
        col_name = self._dataframe.columns.values[col]
        self.layoutAboutToBeChanged.emit()
        try:
            self._dataframe.sort_values(by=col_name, ascending=not order, inplace=True)
            self._dataframe.reset_index(inplace=True, drop=True)
        finally:
            # Always pair the two layout signals, even if sorting raised
            self.layoutChanged.emit()
class PlotCanvas(FigureCanvasQTAgg):
    """
    Matplotlib canvas widget holding either one or two stacked axes.

    The main axes are available as `ax`. When `wheel` is truthy a second, smaller
    axes sharing the x-axis is created below it and stored as `ax2`.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None):
        """
        :param parent: the parent widget of the canvas
        :param width: the figure width, passed to `figsize`
        :param height: the figure height, passed to `figsize`
        :param dpi: the figure resolution
        :param wheel: if truthy, create the additional `ax2` axes
        """
        figure = Figure(figsize=(width, height), dpi=dpi)
        super().__init__(figure)
        self.setParent(parent)

        # Let the canvas grow in both directions with its container
        expanding = QtWidgets.QSizePolicy.Expanding
        self.setSizePolicy(expanding, expanding)
        self.updateGeometry()

        if not wheel:
            self.ax = figure.add_subplot(111)
        else:
            # Two rows sharing the x-axis, the upper one twice as tall as the lower
            grid = {'height_ratios': [2, 1]}
            self.ax, self.ax2 = figure.subplots(2, 1, gridspec_kw=grid, sharex=True)
        self.draw()
class PlotWindow(QtWidgets.QWidget):
    """Widget stacking a `PlotCanvas` above its matplotlib navigation toolbar."""

    def __init__(self, parent=None, wheel=None):
        """
        :param parent: accepted but not forwarded to QWidget (see note below)
        :param wheel: forwarded to `PlotCanvas`
        """
        # NOTE(review): the widget is always created without a parent, whatever
        # `parent` is - confirm this is intentional before forwarding it.
        super().__init__(parent=None)
        self.canvas = PlotCanvas(wheel=wheel)

        # Vertical box: canvas on top, toolbar underneath
        self.vbl = QtWidgets.QVBoxLayout()
        self.vbl.addWidget(self.canvas)
        self.setLayout(self.vbl)
        toolbar = NavigationToolbar2QT(self.canvas, self)
        self.vbl.addWidget(toolbar)
class GraphWindow(QtWidgets.QWidget):
    """
    Main window showing a data frame in a sortable table, next to a plot window.

    A CSV file can be loaded into the table through the "Select File" button.
    Double-clicking a table row zooms the plot window onto the interval given by the
    `intervals_0` and `intervals_1` columns of that row.
    """

    def __init__(self, parent=None, wheel=None):
        """
        :param parent: the parent widget
        :param wheel: optional mapping with 're_ts' and 're_pos' arrays; when truthy
         the plot window gets a second axes that is zoomed together with the first
        """
        QtWidgets.QWidget.__init__(self, parent=parent)
        vLayout = QtWidgets.QVBoxLayout(self)
        hLayout = QtWidgets.QHBoxLayout()
        # Top row: file path box and file selection button
        self.pathLE = QtWidgets.QLineEdit(self)
        hLayout.addWidget(self.pathLE)
        self.loadBtn = QtWidgets.QPushButton("Select File", self)
        hLayout.addWidget(self.loadBtn)
        vLayout.addLayout(hLayout)
        # Below: the table view of the data frame
        self.pandasTv = QtWidgets.QTableView(self)
        vLayout.addWidget(self.pandasTv)
        self.loadBtn.clicked.connect(self.load_file)
        self.pandasTv.setSortingEnabled(True)
        self.pandasTv.doubleClicked.connect(self.tv_double_clicked)
        # The plot lives in its own window, shown straight away
        self.wplot = PlotWindow(wheel=wheel)
        self.wplot.show()
        self.wheel = wheel

    def load_file(self):
        """Prompt for a CSV file and display its content in the table."""
        fileName, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Open File", "", "CSV Files (*.csv)")
        if not fileName:
            # Dialog cancelled: keep the current path and table untouched
            return
        self.pathLE.setText(fileName)
        df = pd.read_csv(fileName)
        self.update_df(df)

    def update_df(self, df):
        """
        Display the given data frame in the table and redraw the plot canvas.

        :param df: the data frame to display
        """
        model = DataFrameModel(df)
        self.pandasTv.setModel(model)
        self.wplot.canvas.draw()

    def tv_double_clicked(self):
        """
        Zoom the plot onto the interval of the currently selected table row.

        The x-limits are set to the row's `intervals_0` - `intervals_1` span, padded
        by a tenth of its duration on each side. When wheel data are present, the
        y-limits of the second axes are fitted to the wheel position over that span.
        """
        df = self.pandasTv.model()._dataframe
        ind = self.pandasTv.currentIndex()
        if not ind.isValid():
            return
        # The view row is a position in the frame, not an index label
        row = ind.row()
        start = df['intervals_0'].iloc[row]
        finish = df['intervals_1'].iloc[row]
        dt = finish - start
        x_min, x_max = start - dt / 10, finish + dt / 10
        if self.wheel:
            idx = np.searchsorted(self.wheel['re_ts'], np.array([x_min, x_max]))
            period = self.wheel['re_pos'][idx[0]:idx[1]]
            if period.size == 0:
                _logger.warning('No wheel data during trial #%i', row)
            else:
                min_val, max_val = np.min(period), np.max(period)
                self.wplot.canvas.ax2.set_ylim(min_val - 1, max_val + 1)
            self.wplot.canvas.ax2.set_xlim(x_min, x_max)
        self.wplot.canvas.ax.set_xlim(x_min, x_max)
        self.wplot.canvas.draw()
def viewqc(qc=None, title=None, wheel=None):
    """
    Create and show the QC table window.

    :param qc: optional data frame to display in the table
    :param title: optional window title
    :param wheel: optional wheel data, forwarded to `GraphWindow`
    :return: the `GraphWindow` instance
    """
    qt.create_app()  # presumably ensures a Qt application exists - confirm in ibllib.misc.qt
    qcw = GraphWindow(wheel=wheel)
    # setWindowTitle takes a string: leave the default title alone when none is given
    if title is not None:
        qcw.setWindowTitle(title)
    if qc is not None:
        qcw.update_df(qc)
    qcw.show()
    return qcw