"""An interactive PyQT QC data frame."""
import logging
from PyQt5 import QtCore, QtWidgets
from matplotlib.figure import Figure
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT
import pandas as pd
import numpy as np
from ibllib.misc import qt
# Module-level logger; used below to warn when a selected trial has no wheel samples.
_logger = logging.getLogger(__name__)
class DataFrameModel(QtCore.QAbstractTableModel):
    """Qt table model that exposes a pandas DataFrame to item views.

    Besides the standard display role, two custom roles are provided:
    ``ValueRole`` returns the raw cell value and ``DtypeRole`` returns the
    dtype of the cell's column.

    The model always works on its own copy of the frame, so sorting the view
    never reorders or re-indexes the DataFrame owned by the caller.
    """

    # Custom item-data roles, offset from Qt.UserRole.
    DtypeRole = QtCore.Qt.UserRole + 1000
    ValueRole = QtCore.Qt.UserRole + 1001

    def __init__(self, df=None, parent=None):
        """
        :param df: the DataFrame to display; an empty frame is used when None
        :param parent: optional parent QObject
        """
        super(DataFrameModel, self).__init__(parent)
        # A fresh frame per instance (no shared mutable default), and a copy of
        # any caller-supplied frame, consistent with setDataFrame().
        self._dataframe = pd.DataFrame() if df is None else df.copy()

    def setDataFrame(self, dataframe):
        """
        Replace the displayed data with a copy of the given frame and reset the model.
        :param dataframe: the new DataFrame to display
        :return:
        """
        self.beginResetModel()
        self._dataframe = dataframe.copy()
        self.endResetModel()

    def dataFrame(self):
        """Return the DataFrame currently held by the model."""
        return self._dataframe

    # Expose the getter/setter pair above as a Qt property named `dataFrame`.
    dataFrame = QtCore.pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame)

    def rowCount(self, parent=QtCore.QModelIndex()):
        """Return the number of rows; 0 for any valid parent (flat table, no children)."""
        if parent.isValid():
            return 0
        return len(self._dataframe.index)

    def columnCount(self, parent=QtCore.QModelIndex()):
        """Return the number of columns; 0 for any valid parent (flat table, no children)."""
        if parent.isValid():
            return 0
        return self._dataframe.columns.size

    def data(self, index, role=QtCore.Qt.DisplayRole):
        """
        Return the data stored under the given role for the cell at `index`.
        :param index: QModelIndex of the requested cell
        :param role: Qt.DisplayRole (string form of the value), ValueRole (raw value)
         or DtypeRole (dtype of the column)
        :return: the requested data, or an empty QVariant for an invalid/out-of-range
         index or an unsupported role
        """
        if not index.isValid():
            return QtCore.QVariant()
        row, col = index.row(), index.column()
        if not (0 <= row < self.rowCount() and 0 <= col < self.columnCount()):
            return QtCore.QVariant()
        if role == DataFrameModel.DtypeRole:
            # Positional lookup keeps working when column names are duplicated.
            return self._dataframe.dtypes.iloc[col]
        if role == QtCore.Qt.DisplayRole or role == DataFrameModel.ValueRole:
            # View rows/columns are positions, not index labels: use positional
            # scalar access. This is also correct for a non-default index and
            # avoids building (and dtype-upcasting) a whole row Series per cell.
            val = self._dataframe.iat[row, col]
            return str(val) if role == QtCore.Qt.DisplayRole else val
        return QtCore.QVariant()

    def roleNames(self):
        """Return the mapping from the roles handled by this model to their names."""
        roles = {
            QtCore.Qt.DisplayRole: b'display',
            DataFrameModel.DtypeRole: b'dtype',
            DataFrameModel.ValueRole: b'value'
        }
        return roles

    def sort(self, col, order=QtCore.Qt.AscendingOrder):
        """
        Sort table by given column number.
        :param col: the column number selected (between 0 and self._dataframe.columns.size)
        :param order: the order to be sorted: Qt.AscendingOrder (0) is ascending,
         Qt.DescendingOrder (1) is descending
        :return:
        """
        # Ignore out-of-range columns (e.g. -1) instead of letting a negative
        # number silently index the columns from the end.
        if not 0 <= col < self.columnCount():
            return
        self.layoutAboutToBeChanged.emit()
        col_name = self._dataframe.columns[col]
        self._dataframe.sort_values(
            by=col_name, ascending=(order == QtCore.Qt.AscendingOrder), inplace=True)
        # Renumber rows so that the index matches the displayed order.
        self._dataframe.reset_index(inplace=True, drop=True)
        self.layoutChanged.emit()
class PlotCanvas(FigureCanvasQTAgg):
    """Matplotlib canvas widget holding one axes, or two stacked axes when wheel data is given."""

    def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None):
        """
        :param parent: parent widget of the canvas
        :param width: figure width passed to matplotlib's `figsize`
        :param height: figure height passed to matplotlib's `figsize`
        :param dpi: figure resolution
        :param wheel: when truthy, a second axes (`self.ax2`) sharing the x axis is
         created below the main one (`self.ax`)
        """
        figure = Figure(figsize=(width, height), dpi=dpi)
        super().__init__(figure)
        self.setParent(parent)

        # Let the canvas grow with its container in both directions.
        expanding = QtWidgets.QSizePolicy.Expanding
        self.setSizePolicy(expanding, expanding)
        self.updateGeometry()

        if not wheel:
            self.ax = figure.add_subplot(111)
        else:
            # Two rows with a 2:1 height ratio and a shared x axis.
            self.ax, self.ax2 = figure.subplots(
                2, 1, gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
        self.draw()
class PlotWindow(QtWidgets.QWidget):
    """Widget stacking a PlotCanvas above a matplotlib navigation toolbar."""

    def __init__(self, parent=None, wheel=None):
        """
        :param parent: accepted but not forwarded to QWidget (see note below)
        :param wheel: forwarded to PlotCanvas to select the one- or two-axes layout
        """
        # NOTE(review): the widget is always created without a parent, whatever
        # `parent` is - presumably so that it stays a separate window; confirm.
        super().__init__(parent=None)
        self.canvas = PlotCanvas(wheel=wheel)

        # Vertical layout: canvas on top, toolbar underneath.
        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(self.canvas)
        self.vbl = layout
        self.setLayout(layout)
        toolbar = NavigationToolbar2QT(self.canvas, self)
        layout.addWidget(toolbar)
class GraphWindow(QtWidgets.QWidget):
    """Window showing a sortable table of a DataFrame, linked to a separate plot window.

    Double-clicking a table row sets the x limits of the plot window to the
    interval read from the `intervals_0` and `intervals_1` columns of that row.
    """

    def __init__(self, parent=None, wheel=None):
        """
        :param parent: parent widget
        :param wheel: optional wheel data, indexed with the keys 're_ts' and 're_pos';
         when truthy the plot window gets a second axes for it
        """
        QtWidgets.QWidget.__init__(self, parent=parent)
        vLayout = QtWidgets.QVBoxLayout(self)
        # Top row: path display and file selection button.
        hLayout = QtWidgets.QHBoxLayout()
        self.pathLE = QtWidgets.QLineEdit(self)
        hLayout.addWidget(self.pathLE)
        self.loadBtn = QtWidgets.QPushButton("Select File", self)
        hLayout.addWidget(self.loadBtn)
        vLayout.addLayout(hLayout)
        # Table view of the data frame.
        self.pandasTv = QtWidgets.QTableView(self)
        vLayout.addWidget(self.pandasTv)
        self.loadBtn.clicked.connect(self.load_file)
        self.pandasTv.setSortingEnabled(True)
        self.pandasTv.doubleClicked.connect(self.tv_double_clicked)
        # The plot lives in its own window, shown straight away.
        self.wplot = PlotWindow(wheel=wheel)
        self.wplot.show()
        self.wheel = wheel

    def load_file(self):
        """Ask the user for a CSV file and display its content in the table."""
        fileName, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Open File", "", "CSV Files (*.csv)")
        # An empty file name means the dialog was cancelled: keep the current
        # table and path instead of asking pandas to read ''.
        if not fileName:
            return
        self.pathLE.setText(fileName)
        df = pd.read_csv(fileName)
        self.update_df(df)

    def update_df(self, df):
        """
        Display the given frame in the table and redraw the plot canvas.
        :param df: the DataFrame to display
        :return:
        """
        model = DataFrameModel(df)
        self.pandasTv.setModel(model)
        self.wplot.canvas.draw()

    def tv_double_clicked(self):
        """Zoom the plot(s) onto the interval of the row currently selected in the table."""
        df = self.pandasTv.model()._dataframe
        ind = self.pandasTv.currentIndex()
        if not ind.isValid():
            return
        # The view row is a position in the frame, not an index label.
        row = ind.row()
        start = df['intervals_0'].iloc[row]
        finish = df['intervals_1'].iloc[row]
        dt = finish - start
        # Pad the displayed window by a tenth of the interval on each side.
        x_min, x_max = start - dt / 10, finish + dt / 10
        if self.wheel:
            # NOTE(review): searchsorted assumes wheel['re_ts'] is sorted - confirm.
            idx = np.searchsorted(self.wheel['re_ts'], np.array([x_min, x_max]))
            period = self.wheel['re_pos'][idx[0]:idx[1]]
            if period.size == 0:
                _logger.warning('No wheel data during trial #%i', row)
            else:
                # Fit the y axis to the wheel values within the window, with a margin of 1.
                min_val, max_val = np.min(period), np.max(period)
                self.wplot.canvas.ax2.set_ylim(min_val - 1, max_val + 1)
            self.wplot.canvas.ax2.set_xlim(x_min, x_max)
        self.wplot.canvas.ax.set_xlim(x_min, x_max)
        self.wplot.canvas.draw()
def viewqc(qc=None, title=None, wheel=None):
    """
    Open a GraphWindow, optionally pre-filled with a QC data frame.
    :param qc: DataFrame to display in the table; nothing is loaded when None
    :param title: title given to the window
    :param wheel: wheel data forwarded to GraphWindow
    :return: the GraphWindow instance that was shown
    """
    # presumably makes sure a Qt application exists before widgets are built - confirm
    qt.create_app()
    window = GraphWindow(wheel=wheel)
    window.setWindowTitle(title)
    if qc is not None:
        window.update_df(qc)
    window.show()
    return window