from typing import TYPE_CHECKING, Any
import numpy as np
from qtpy import QtWidgets
from ibl_alignment_gui.utils.qt.custom_widgets import CheckBoxGroup, PopupWindow, SliderWidget
from ibl_alignment_gui.utils.utils import shank_loop
from iblutil.util import Bunch
if TYPE_CHECKING:
from ibl_alignment_gui.app.app_controller import AlignmentGUIController
from ibl_alignment_gui.app.shank_controller import ShankController
PLUGIN_NAME = 'Range Controller'
[docs]
def setup(controller: 'AlignmentGUIController') -> None:
    """
    Register the Range Controller plugin with the main application controller.

    Creates the plugin entry in ``controller.plugins``, stores a
    :class:`RangeController` instance in it, registers the plugin's hooks, and adds
    a menu action that opens the plugin window.

    Parameters
    ----------
    controller: AlignmentGUIController
        The main application controller.
    """
    controller.plugins[PLUGIN_NAME] = Bunch()
    plugin = controller.plugins[PLUGIN_NAME]

    range_controller = RangeController(PLUGIN_NAME, controller)
    plugin['loader'] = range_controller
    # The plugin starts inactive; the menu action flips this flag through `callback`
    plugin['activated'] = False

    # Every plot hook is routed to the same handler
    plot_hooks = (
        'plot_image_panels',
        'plot_probe_panels',
        'plot_scatter_panels',
        'plot_line_panels',
    )
    for hook in plot_hooks:
        plugin[hook] = range_controller.on_plot_changed
    plugin['on_view_changed'] = range_controller.disable_sliders

    # Expose the plugin through an entry in the plugin options menu
    action = QtWidgets.QAction(PLUGIN_NAME, controller.view)
    action.triggered.connect(lambda: callback(controller))
    controller.plugin_options.addAction(action)
[docs]
def callback(controller: 'AlignmentGUIController') -> None:
    """
    Activate the Range Controller plugin and set up its window.

    Marks the plugin entry as activated and then calls ``setup()`` on the stored
    loader object.

    Parameters
    ----------
    controller: AlignmentGUIController
        The main application controller.
    """
    plugin = controller.plugins[PLUGIN_NAME]
    plugin['activated'] = True
    plugin['loader'].setup()
[docs]
class RangeControllerView(PopupWindow):
    """
    Popup window that holds the widgets of the Range Controller plugin.

    The window contains one labelled slider per plot family (image, line, probe) and
    two check box groups used to choose the shanks and configurations.

    Parameters
    ----------
    title : str
        The title of the popup window.
    controller : AlignmentGUIController
        The main application controller.
    """

    def __init__(self, title: str, controller: 'AlignmentGUIController'):
        # These attributes are assigned before the base initialiser runs.
        # NOTE(review): presumably PopupWindow.__init__ invokes setup(), which reads
        # them - confirm against PopupWindow before reordering.
        self.steps: int = 100
        self.controller: 'AlignmentGUIController' = controller
        super().__init__(title, controller.view, size=(500, 600), graphics=False)

    def setup(self):
        """Create the sliders, labels and check box groups and add them to the layout."""
        self.sliders = Bunch()
        self.labels = Bunch()
        for name in ('image', 'probe', 'line'):
            self.sliders[name] = SliderWidget(steps=self.steps, slider_type=name)
            self.labels[name] = QtWidgets.QLabel()

        # NOTE(review): on_shank_button_clicked / on_config_button_clicked are not
        # defined in the visible part of this class - confirm where they come from.
        self.shank_options = shank_group = CheckBoxGroup()
        shank_group.add_options(['All'] + self.controller.all_shanks)
        shank_group.set_checked([self.controller.model.selected_shank])
        shank_group.setup_callback(self.on_shank_button_clicked)

        self.config_options = config_group = CheckBoxGroup()
        config_group.add_options(['both'] + self.controller.model.configs)
        config_group.set_checked([self.controller.model.selected_config])
        config_group.setup_callback(self.on_config_button_clicked)

        # Layout order: image, line, probe sliders followed by the two check box groups
        for name in ('image', 'line', 'probe'):
            self.layout.addWidget(self.labels[name])
            self.layout.addWidget(self.sliders[name])
        for caption, group in (('Shanks:', shank_group), ('Configurations:', config_group)):
            self.layout.addWidget(QtWidgets.QLabel(caption))
            self.layout.addWidget(group)

    def get_selected_shanks(self) -> list[str]:
        """
        Return the checked shank options, leaving out the 'All' entry.

        Returns
        -------
        list
            The checked entries of the shank group that do not contain 'All'.
        """
        checked = self.shank_options.get_checked()
        return [name for name in checked if 'All' not in name]

    def get_selected_configs(self) -> list[str]:
        """
        Return the checked configuration options, leaving out the 'both' entry.

        Returns
        -------
        list
            The checked entries of the configuration group that do not contain 'both'.
        """
        checked = self.config_options.get_checked()
        return [name for name in checked if 'both' not in name]
[docs]
class RangeController:
"""
The Range Controller plugin for the alignment GUI.
Parameters
----------
title : str
The title of the plugin window.
controller : AlignmentGUIController
The main application controller.
Attributes
----------
title : str
The title of the plugin window.
controller : AlignmentGUIController
The main application controller.
view : RangeControllerView
The GUI view for the Range Controller plugin.
disable : bool
A flag to disable callbacks during updates.
plot_keys : Bunch
A collection of current plot keys for image, probe, and line plots.
"""
def __init__(self, title: str, controller: 'AlignmentGUIController'):
self.title = title
self.controller = controller
self.view = None
self.disable = False
self.plot_keys = Bunch()
self.plot_keys['image'] = None
self.plot_keys['probe'] = None
self.plot_keys['line'] = None
[docs]
def setup(self) -> None:
"""Set up the plugin GUI and connect callbacks."""
self.view = RangeControllerView._get_or_create(self.title, self.controller)
self.view.closed.connect(self.on_close)
for slider_widget in self.view.sliders.values():
slider_widget.released.connect(self.on_slider_moved)
slider_widget.reset.connect(self.on_reset_button_pressed)
self.set_init_levels()
if self.controller.model.selected_config == 'both':
self.view.on_config_button_clicked(True, 'both')
[docs]
def on_close(self) -> None:
"""
Triggered when the plugin window is closed.
Deactivate the plugin and callbacks.
"""
self.controller.plugins[PLUGIN_NAME]['activated'] = False
[docs]
def disable_sliders(self) -> None:
"""Disable sliders when the view is changed to feature view."""
if self.controller.show_feature:
for slider_widget in self.view.sliders.values():
slider_widget.slider.setEnabled(False)
slider_widget.reset_button.setEnabled(False)
else:
for slider_widget in self.view.sliders.values():
slider_widget.slider.setEnabled(True)
slider_widget.reset_button.setEnabled(True)
self.set_init_levels()
[docs]
def set_init_levels(self) -> None:
"""Set the initial levels for image, probe, and line plots based on controller settings."""
self.plot_keys['image'] = self.controller.img_init
self.view.labels['image'].setText(f'Image: {self.plot_keys["image"]}')
self.on_plot_changed(
self.plot_keys['image'], self.get_image_plot_type(self.plot_keys['image'])
)
self.plot_keys['probe'] = self.controller.probe_init
self.view.labels['probe'].setText(f'Probe: {self.plot_keys["probe"]}')
self.on_plot_changed(self.plot_keys['probe'], 'probe')
self.plot_keys['line'] = self.controller.line_init
self.view.labels['line'].setText(f'Line: {self.plot_keys["line"]}')
self.on_plot_changed(self.plot_keys['line'], 'line')
[docs]
def get_image_plot_type(self, plot_key: str) -> str | None:
"""
Determine the plot type for a given plot key.
Parameters
----------
plot_key: str
Returns
-------
str | None
The plot type ('image' or 'scatter') or None if not found.
"""
if plot_key in self.controller.model.image_keys:
return 'image'
elif plot_key in self.controller.model.scatter_keys:
return 'scatter'
else:
return None
[docs]
def on_plot_changed(self, plot_key: str, plot_type: str) -> None:
"""
Triggered when a plot is changed in the main controller.
If no view is set or callbacks are disabled, it returns immediately.
Parameters
----------
plot_key: str
The key of the plot that was changed.
plot_type: str
The type of the plot ('image', 'probe', 'line', or 'scatter').
"""
if self.view is None or self.disable:
return
key = plot_type if plot_type != 'scatter' else 'image'
self.plot_keys[key] = plot_key
# Find the extremes across all shanks and configs
data = get_levels(self.controller, plot_key, plot_type)
max_levels = np.array([dat['data'].levels for dat in data if dat['data'] is not None])
self.view.sliders[key].set_slider_intervals([np.nanmin(max_levels), np.nanmax(max_levels)])
# Find the level for the currently selected shank and configs
data = get_levels(
self.controller,
plot_key,
plot_type,
shanks=self.view.get_selected_shanks(),
configs=self.view.get_selected_configs(),
)
levels = np.array([dat['data'].levels for dat in data if dat['data'] is not None])
if len(levels) == 0:
levels = np.nanquantile(max_levels, [0.1, 0.9])
levels = np.nanquantile(levels, [0.1, 0.9])
# Update the slider and label
self.view.sliders[key].set_slider_values(levels)
self.view.labels[key].setText(f'{key.capitalize()}: {self.plot_keys[key]}')
[docs]
def plot_panels(self, plot_key: str, plot_type: str) -> None:
"""
Re-plot the panels in the main view for the given plot key and type.
Parameters
----------
plot_key: str
The key of the plot to update.
plot_type: str
The type of the plot ('image', 'probe', 'line', 'scatter' or 'feature')
"""
if plot_type == 'probe':
self.controller.plot_probe_panels(plot_key)
elif plot_type == 'line':
self.controller.plot_line_panels(plot_key)
elif plot_type == 'scatter':
self.controller.plot_scatter_panels(plot_key)
elif plot_type == 'image':
self.controller.plot_image_panels(plot_key)
[docs]
def on_slider_moved(self, slider_widget: SliderWidget, plot_type: str) -> None:
"""
Triggered when a slider is moved.
Updates the slider values and updates the plots.
Parameters
----------
slider_widget: SliderWidget
The slider widget that was moved.
plot_type: str
The type of the plot
"""
# If not feature plot shown and feature slider moved, return
if self.controller.show_feature:
return
# Ensure we don't go into on_plot_changed
self.disable = True
levels = slider_widget.get_slider_values()
plot_key = self.plot_keys[plot_type]
if plot_type == 'image':
plot_type = self.get_image_plot_type(plot_key)
set_levels(
self.controller,
plot_key,
plot_type,
levels,
shanks=self.view.get_selected_shanks(),
configs=self.view.get_selected_configs(),
)
slider_widget.set_slider_values(levels)
self.plot_panels(plot_key, plot_type)
self.disable = False
[docs]
@shank_loop
def set_levels(
    controller: 'AlignmentGUIController',
    items: 'ShankController',
    plot_key: str,
    plot_type: str,
    levels: tuple[float, float],
    **kwargs,
) -> None:
    """
    Store new levels on the plot data found for a plot key and type.

    Nothing happens when the plot group of `items` has no entry for `plot_key`.
    NOTE(review): presumably `shank_loop` calls this once per shank/config and
    supplies `items` - confirm in `ibl_alignment_gui.utils.utils`.

    Parameters
    ----------
    controller: AlignmentGUIController
        The main application controller (not used in the body).
    items: ShankController
        The object whose model holds the plot groups.
    plot_key: str
        The key of the plot to update.
    plot_type: str
        The type of the plot ('image', 'probe', 'line', or 'scatter').
    levels: list or tuple
        The new levels to set; a copy is stored.
    """
    plot_data = get_plot_group(items, plot_type).get(plot_key)
    if plot_data is not None:
        plot_data.levels = np.copy(levels)
[docs]
@shank_loop
def reset_default_levels(
    controller: 'AlignmentGUIController',
    items: 'ShankController',
    plot_key: str,
    plot_type: str,
    **kwargs,
) -> None:
    """
    Restore the default levels on the plot data found for a plot key and type.

    Nothing happens when the plot group of `items` has no entry for `plot_key`.

    Parameters
    ----------
    controller: AlignmentGUIController
        The main application controller (not used in the body).
    items: ShankController
        The object whose model holds the plot groups.
    plot_key: str
        The key of the plot to reset.
    plot_type: str
        The type of the plot ('image', 'probe', 'line', or 'scatter').
    """
    plot_data = get_plot_group(items, plot_type).get(plot_key)
    if plot_data is not None:
        # Copy so that later edits of `levels` cannot alter `default_levels`
        plot_data.levels = np.copy(plot_data.default_levels)
[docs]
@shank_loop
def get_levels(
    controller: 'AlignmentGUIController',
    items: 'ShankController',
    plot_key: str,
    plot_type: str,
    **kwargs,
) -> dict[str, Any]:
    """
    Look up the plot data for a plot key and type.

    Parameters
    ----------
    controller: AlignmentGUIController
        The main application controller (not used in the body).
    items: ShankController
        The object whose model holds the plot groups.
    plot_key: str
        The key of the plot.
    plot_type: str
        The type of the plot ('image', 'probe', 'line', or 'scatter').
    **kwargs
        Must contain 'shank' and 'config'.
        NOTE(review): presumably injected by `shank_loop` - confirm.

    Returns
    -------
    dict
        A dictionary with the keys 'data' (the plot data, or None when the plot key
        is not in the plot group), 'shank' and 'config'.
    """
    plot_data = get_plot_group(items, plot_type).get(plot_key)
    return {'data': plot_data, 'shank': kwargs['shank'], 'config': kwargs['config']}
[docs]
def get_plot_group(items: 'ShankController', plot_type: str) -> Bunch:
    """
    Return the collection of plots on ``items.model`` that matches a plot type.

    Parameters
    ----------
    items: ShankController
        The object whose model holds the plot groups.
    plot_type: str
        The type of the plot ('image', 'probe', 'line', or 'scatter').

    Returns
    -------
    Bunch
        The plot group corresponding to the plot type.

    Raises
    ------
    KeyError
        If `plot_type` is not one of the four supported types.
    """
    model = items.model
    groups_by_type = dict(
        probe=model.probe_plots,
        line=model.line_plots,
        scatter=model.scatter_plots,
        image=model.image_plots,
    )
    return groups_by_type[plot_type]