Source code for ibl_alignment_gui.plugins.cluster_features

from typing import TYPE_CHECKING, Any

import numpy as np
import pyqtgraph as pg
from qtpy import QtCore, QtGui, QtWidgets

from brainbox.population.decode import xcorr
from ibl_alignment_gui.utils.qt.custom_widgets import PopupWindow, set_axis

if TYPE_CHECKING:
    from ibl_alignment_gui.app.app_controller import AlignmentGUIController, AlignmentGUIView
    from ibl_alignment_gui.app.shank_controller import ShankController
    from ibl_alignment_gui.loaders.plot_loader import PlotLoader

PLUGIN_NAME = 'Cluster Features'

AUTOCORR_BIN_SIZE = 0.5 / 1000
AUTOCORR_WIN_SIZE = 20 / 1000
FS = 30000


[docs] def setup(controller: 'AlignmentGUIController') -> None: """ Set up the Cluster Features plugin. Adds a submenu to the main GUI for managing the cluster popups. Parameters ---------- controller: AlignmentGUIController The main application controller. """ controller.plugins[PLUGIN_NAME] = dict() controller.plugins[PLUGIN_NAME]['loader'] = ClusterPopupManager(controller) controller.plugins[PLUGIN_NAME]['callback'] = callback controller.plugins[PLUGIN_NAME]['activate'] = True # Add a submenu to the main menu plugin_menu = QtWidgets.QMenu(PLUGIN_NAME, controller.view) controller.plugin_options.addMenu(plugin_menu) # Add the actions to the submenu # Minimise cluster popups action = QtWidgets.QAction(f'Minimise/Show {PLUGIN_NAME}', controller.view) action.setShortcut('M') action.setShortcutContext(QtCore.Qt.ApplicationShortcut) action.triggered.connect(controller.plugins[PLUGIN_NAME]['loader'].minimise_popups) plugin_menu.addAction(action) # Close cluster popups action = QtWidgets.QAction(f'Close {PLUGIN_NAME}', controller.view) action.setShortcut('Alt+X') action.setShortcutContext(QtCore.Qt.ApplicationShortcut) action.triggered.connect(controller.plugins[PLUGIN_NAME]['loader'].close_popups) plugin_menu.addAction(action)
[docs] def callback( controller: 'AlignmentGUIController', items: 'ShankController', _, point: pg.ScatterPlotItem ) -> None: """ Triggered when a cluster in a scatter plot is clicked. Computes autocorrelation and template waveform for the selected cluster and opens a popup showing the plots. Parameters ---------- controller: AlignmentGUIController The main application controller. items: ShankController The shank controller containing cluster data. point: pg.ScatterPlotItem The clicked scatter item in the scatter plot. """ point_pos = point[0].pos() clust_idx = np.argwhere(items.cluster_data == point_pos.x())[0][0] plot_loader = items.model.loaders['plots'] t_autocorr, t_template = compute_timescales(plot_loader) data = {} data['t_autocorr'] = t_autocorr data['autocorr'], clust_no = get_autocorr(plot_loader, clust_idx) data['t_template'] = t_template data['template_wf'] = get_template_wf(plot_loader, clust_idx) controller.plugins[PLUGIN_NAME]['loader'].add_popup(items.name, clust_no, items.config, data)
[docs] class ClusterPopup(PopupWindow): """ A popup qt window per cluster. Shows plots of the cluster autocorrelogram and template waveform. Parameters ---------- title: str The title of the popup window. data: dict or None A dictionary containing data to be plotted in the popup. parent: QtWidgets.QMainWindow or None The parent window of the popup. """ def __init__(self, title: str, view: 'AlignmentGUIView', data: dict | None = None): self.data = data super().__init__(title, parent=view, size=(300, 300), graphics=True)
[docs] def setup(self) -> None: """Configure the plots inside the popup window.""" autocorr_plot = pg.PlotItem() autocorr_plot.setXRange( min=np.min(self.data['t_autocorr']), max=np.max(self.data['t_autocorr']) ) autocorr_plot.setYRange(min=0, max=1.05 * np.max(self.data['autocorr'])) set_axis(autocorr_plot, 'bottom', label='T (ms)') set_axis(autocorr_plot, 'left', label='Number of spikes') plot = pg.BarGraphItem( x=self.data['t_autocorr'], height=self.data['autocorr'], width=0.24, brush=QtGui.QColor(160, 160, 160), ) autocorr_plot.addItem(plot) template_plot = pg.PlotItem() plot = pg.PlotCurveItem() template_plot.setXRange( min=np.min(self.data['t_template']), max=np.max(self.data['t_template']) ) set_axis(template_plot, 'bottom', label='T (ms)') set_axis(template_plot, 'left', label='Amplitude (a.u.)') plot.setData( x=self.data['t_template'], y=self.data['template_wf'], pen=pg.mkPen(color='k', style=QtCore.Qt.SolidLine, width=2), ) template_plot.addItem(plot) self.popup_widget.addItem(autocorr_plot, 0, 0) self.popup_widget.addItem(template_plot, 1, 0)
[docs] class ClusterPopupManager: """ Manager for multiple cluster popups in the GUI. Attributes ---------- parent_view : QtWidgets.QMainWindow The main window of the application. cluster_popups : list A list of currently open cluster popups. popup_status : bool Status indicating whether popups are minimised or shown. """ def __init__(self, controller: 'AlignmentGUIController'): self.view = controller.view self.cluster_popups = [] self.popup_status = True
[docs] def add_popup(self, shank: str, clust_no: int, config: str, data: dict[str, Any]) -> None: """ Add a new cluster popup to the manager and set up its signals. Parameters ---------- shank: str The name of shank clust_no: int The cluster number. config: str The config of the shank data: dict A dict containing data to be plotted in the popup. """ name = f'{shank}_{config}' if config else shank clust_popup = ClusterPopup._get_or_create( f'{name}: cluster {clust_no}', self.view, data=data ) clust_popup.closed.connect(self.popup_closed) clust_popup.leave.connect(self.popup_left) clust_popup.enter.connect(self.popup_entered) self.cluster_popups.append(clust_popup)
[docs] def minimise_popups(self) -> None: """Toggle between minimizing and restoring all cluster popups.""" self.popup_status = not self.popup_status if self.popup_status: for pop in self.cluster_popups: pop.showNormal() else: for pop in self.cluster_popups: pop.showMinimized()
[docs] def close_popups(self) -> None: """Close all cluster popups and reset the list.""" for pop in self.cluster_popups: pop.blockSignals(True) pop.close() self.cluster_popups = []
[docs] def popup_closed(self, popup: ClusterPopup) -> None: """ Triggered when a popup is closed by the user. Parameters ---------- popup: ClusterPopup The cluster popup that was closed. """ if len(self.cluster_popups) > 0: popup_idx = [iP for iP, pop in enumerate(self.cluster_popups) if pop == popup][0] self.cluster_popups.pop(popup_idx)
[docs] def popup_left(self) -> None: """Triggered when the mouse leaves a popup.""" self.view.raise_() self.view.activateWindow()
[docs] def popup_entered(self, popup: ClusterPopup) -> None: """Triggered when the mouse enters a popup.""" popup.raise_() popup.activateWindow()
[docs] def reset(self) -> None: """Triggered when the main GUI is closed. Closes all popups and resets the manager.""" self.close_popups()
[docs] def compute_timescales(plot_loader: 'PlotLoader') -> tuple[np.ndarray, np.ndarray]: """ Compute time vectors for autocorrelogram and template waveform plots. Parameters ---------- plot_loader: PlotLoader The plot loader object containing spike and cluster data. Returns ------- t_autocorr: np.ndarray The time vector for autocorrelogram (ms). t_template: np.ndarray The time vector for template waveform (ms). """ t_autocorr = 1e3 * np.arange( (AUTOCORR_WIN_SIZE / 2) - AUTOCORR_WIN_SIZE, (AUTOCORR_WIN_SIZE / 2) + AUTOCORR_BIN_SIZE, AUTOCORR_BIN_SIZE, ) n_template = plot_loader.data['clusters']['waveforms'][0, :, 0].size t_template = 1e3 * (np.arange(n_template)) / FS return t_autocorr, t_template
[docs] def get_autocorr(plot_loader: 'PlotLoader', clust_idx: int) -> tuple[np.ndarray, int]: """ Compute the autocorrelogram for a specific cluster. Parameters ---------- plot_loader: PlotLoader The plot loader object containing spike and cluster data. clust_idx: int Index of the cluster Returns ------- autocorr: np.ndarray The autocorrelogram of the selected cluster clust_id: int The cluster id of the selected cluster """ idx = plot_loader.spike_clusters == plot_loader.cluster_idx[clust_idx] autocorr = xcorr( plot_loader.spike_times[idx], plot_loader.spike_clusters[idx], AUTOCORR_BIN_SIZE, AUTOCORR_WIN_SIZE, ) if plot_loader.data['clusters'].get('metrics', {}).get('cluster_id', None) is None: clust_id = plot_loader.cluster_idx[clust_idx] else: clust_id = plot_loader.data['clusters'].metrics.cluster_id[ plot_loader.cluster_idx[clust_idx] ] return autocorr[0, 0, :], clust_id
[docs] def get_template_wf(plot_loader: 'PlotLoader', clust_idx: int) -> np.ndarray: """ Retrieve the template waveform for a specific cluster. Parameters ---------- plot_loader: PlotLoader The plot loader object containing spike and cluster data. clust_idx: int Index of the cluster Returns ------- template_wf: np.ndarray The template waveform of the selected cluster """ template_wf = plot_loader.data['clusters']['waveforms'][ plot_loader.cluster_idx[clust_idx], :, 0 ] return template_wf * 1e6