import numpy as np
import pyqtgraph as pg
from ibl_alignment_gui.app.shank_view import ShankView
from ibl_alignment_gui.handlers.shank_handler import ShankHandler
from ibl_alignment_gui.utils.qt.custom_widgets import ColorBar
from iblutil.util import Bunch
class ShankController:
    """
    Controller for handling data and plots on a shank of a probe for a given configuration.

    Most methods are thin delegations that either forward to the view (:class:`ShankView`),
    forward to the model (:class:`ShankHandler`), or read data from the model and hand it to
    the view for plotting.

    Parameters
    ----------
    model: ShankHandler
        The model containing all the data for the shank and config combination
    name: str
        The name of shank
    index: int
        The index of the shank
    config: str
        The config of the shank

    Attributes
    ----------
    name: str
        The name of shank
    index: int
        The index of the shank
    config: str
        The config of the shank
    model: ShankHandler
        The model containing all the data for the shank and config combination
    view: ShankView
        The view containing all the plots for the shank and config combination
    cluster: bool
        Whether the chosen scatter plot contains cluster data
    cluster_data: np.ndarray | None
        The cluster data for the chosen scatter plot if it exists
    """

    def __init__(self, model: ShankHandler, name: str, index: int, config: str) -> None:
        self.name: str = name
        self.index: int = index
        self.config: str = config
        self.model: ShankHandler = model
        # The view is created here from the same name / index / config as the controller.
        self.view: ShankView = ShankView(self.name, self.index, self.config)
        self.cluster: bool = False
        self.cluster_data: np.ndarray | None = None

    def init_reference_line_arrays(self) -> None:
        """See :meth:`ShankView.init_reference_line_arrays` for details."""
        self.view.init_reference_line_arrays()

    def init_plot_items(self) -> None:
        """See :meth:`ShankView.init_plot_items` for details."""
        self.view.init_plot_items()

    # --------------------------------------------------------------------------------------------
    # Plot functions
    # --------------------------------------------------------------------------------------------
    def plot_histology(self) -> pg.LinearRegionItem:
        """
        Plot the histology plot.

        Returns
        -------
        pg.LinearRegionItem
            A pg.LinearRegionItem corresponding to the second last brain region added to the
            histology plot
        """
        self.view.plot_histology(self.view.fig_hist, self.model.hist_data)
        # NOTE(review): indexing with -2 raises IndexError if the view holds fewer than two
        # regions after plotting - confirm the view always creates at least two.
        return self.view.hist_regions[-2]

    def plot_histology_ref(self) -> None:
        """Plot the histology reference plot."""
        self.view.plot_histology(self.view.fig_hist_ref, self.model.hist_data_ref, ax='right')

    def plot_scale_factor(self) -> ColorBar:
        """
        Plot the scale factor along the probe.

        Returns
        -------
        cbar: ColorBar
            The ColorBar object added to the plot
        """
        data = self.model.scale_data
        # The 'scale_factor' entry is written onto the object returned by the model, i.e. it is
        # (re)computed from 'scale' on every call and stored alongside it.
        data['scale_factor'] = data['scale'] - 0.5
        cbar = self.view.plot_scale_factor(data)
        return cbar

    def plot_slice(self, plot_key: str) -> tuple[pg.ViewBox, pg.ImageItem, ColorBar]:
        """
        Plot a slice plot showing a coronal slice through the brain.

        Parameters
        ----------
        plot_key: str
            The key of the plot to display

        Returns
        -------
        fig: pg.ViewBox
            The viewbox that contains the image
        img: pg.ImageItem
            The image item that was added to the plot
        cbar: ColorBar
            The colorbar object that was added to the plot
        """
        # A missing key yields None, which is passed on to the view unchanged.
        data = self.model.slice_plots.get(plot_key, None)
        # The trajectory drawn on the slice uses columns 0 and 2 of the model's xyz track.
        data_traj = Bunch()
        data_traj['x'] = self.model.xyz_track[:, 0]
        data_traj['y'] = self.model.xyz_track[:, 2]
        img, cbar = self.view.plot_slice(data, data_traj)
        fig = self.view.fig_slice
        return fig, img, cbar

    def plot_fit(self) -> None:
        """Plot fit lines on the fit figure."""
        data = Bunch()
        # NOTE(review): the 1e6 factor presumably converts metres to micrometres - confirm
        # against the units used by the model.
        data['x'] = self.model.feature * 1e6
        data['y'] = self.model.track * 1e6
        data['depth_lin'] = self.model.feature2track_lin(self.view.depth, data['x'], data['y'])
        data['depth'] = self.view.depth
        self.view.plot_fit(data)

    def plot_channels(self, fig_slice: pg.ViewBox, colour: str | None = None) -> None:
        """
        Plot channels on a slice plot.

        Parameters
        ----------
        fig_slice: pg.ViewBox
            The slice fig to add the channel items to
        colour: str | None
            The colour of the scatter points used to plot the channels. Passed through to the
            view unchanged.

        Notes
        -----
        - fig_slice is passed in as a parameter because, for the dual config display, the
          channels are plotted on a different slice figure than the one stored in the view.
        """
        data = Bunch()
        data['xyz_channels'] = self.model.xyz_channels
        data['track_lines'] = self.model.track_lines
        self.view.plot_channels(fig_slice, data, colour=colour)

    def plot_scatter(self, plot_key: str, levels: list | None = None) -> ColorBar | None:
        """
        Plot a scatter plot.

        Parameters
        ----------
        plot_key: str
            The key of the plot to display
        levels: list | None
            The levels used to scale the colorbar on the plot

        Returns
        -------
        cbar: ColorBar | None
            The colorbar object that was added to the plot, as returned by the view
        """
        # A missing key yields None, which is passed on to the view unchanged.
        data = self.model.scatter_plots.get(plot_key, None)
        cbar = self.view.plot_scatter(data, levels=levels)
        if data and data.cluster:
            self.cluster_data = data.x
            self.cluster = True
        else:
            # Only the flag is cleared; cluster_data keeps the value from the last cluster plot.
            self.cluster = False
        return cbar

    def plot_line(self, plot_key: str) -> None:
        """
        Plot a line plot.

        Parameters
        ----------
        plot_key: str
            The key of the plot to display
        """
        data = self.model.line_plots.get(plot_key, None)
        self.view.plot_line(data)

    def plot_probe(self, plot_key: str, levels: list | None = None) -> ColorBar | None:
        """
        Plot a probe plot.

        Parameters
        ----------
        plot_key: str
            The key of the plot to display
        levels: list | None
            The levels used to scale the colorbar on the plot

        Returns
        -------
        cbar: ColorBar | None
            The colorbar object that was added to the plot, as returned by the view
        """
        data = self.model.probe_plots.get(plot_key, None)
        cbar = self.view.plot_probe(data, levels=levels)
        return cbar

    def plot_image(self, plot_key: str, levels: list | None = None) -> ColorBar | None:
        """
        Plot an image plot.

        Parameters
        ----------
        plot_key: str
            The key of the plot to display
        levels: list | None
            The levels used to scale the colorbar on the plot

        Returns
        -------
        cbar: ColorBar | None
            The colorbar object that was added to the plot, as returned by the view
        """
        data = self.model.image_plots.get(plot_key, None)
        cbar = self.view.plot_image(data, levels=levels)
        return cbar

    def plot_feature(self, plot_key: str) -> None:
        """
        Plot a feature plot.

        Parameters
        ----------
        plot_key: str
            The key of the plot to display
        """
        data = self.model.feature_plots.get(plot_key, None)
        self.view.plot_feature(data)

    # --------------------------------------------------------------------------------------------
    # Update displays
    # --------------------------------------------------------------------------------------------
    def toggle_labels(self, *args) -> None:
        """See :meth:`ShankView.toggle_labels` for details."""
        self.view.toggle_labels(*args)

    def toggle_channels(self, *args) -> None:
        """See :meth:`ShankView.toggle_channels` for details."""
        self.view.toggle_channels(*args)

    def set_xaxis_range(self, fig: str) -> None:
        """
        Set the x-axis range of the specified figure.

        Parameters
        ----------
        fig: str
            The attribute name of the figure to set the x-axis for
        """
        # The figure is looked up by attribute name on the view; an unknown name raises
        # AttributeError.
        self.view.set_xaxis_range(getattr(self.view, fig))

    def set_yaxis_range(self, fig: str) -> None:
        """
        Set the y-axis range of the specified figure.

        Parameters
        ----------
        fig: str
            The attribute name of the figure to set the y-axis for
        """
        # The figure is looked up by attribute name on the view; an unknown name raises
        # AttributeError.
        self.view.set_yaxis_range(getattr(self.view, fig))

    def set_probe_lims(self) -> None:
        """Set the limits for the probe tip and probe top based on values stored in model."""
        self.view.set_probe_lims(self.model.chn_min, self.model.chn_max)

    def set_yaxis_lims(self, *args) -> None:
        """See :meth:`ShankView.set_yaxis_lims` for details."""
        self.view.set_yaxis_lims(*args)

    def get_yaxis_lims(self) -> list[float]:
        """
        Get the y-axis limits from the model.

        Returns
        -------
        list[float]
            A two element list containing the ymin and ymax values from the model
        """
        return [self.model.y_min, self.model.y_max]

    def set_scale_title(self, hover_item: pg.LinearRegionItem) -> None:
        """
        Update the title of the scale plot color bar based on the hovered region.

        Parameters
        ----------
        hover_item: pg.LinearRegionItem
            The region item currently hovered over.
        """
        idx = self.view.match_linear_region(hover_item)
        # NOTE(review): idx is used directly as an index into the model's scale values -
        # confirm that match_linear_region always returns a valid index for hover_item.
        self.view.set_fig_scale_title(self.model.scale_data['scale'][idx])

    def set_feature_title(self, feature_title: str | None) -> None:
        """See :meth:`ShankView.set_feature_title` for details."""
        self.view.set_feature_title(feature_title)

    def reset_slice_axis(self) -> None:
        """See :meth:`ShankView.reset_slice_axis` for details."""
        self.view.reset_slice_axis()

    def filter_units(self, filter_type: str) -> None:
        """See :meth:`ShankHandler.filter_units` for details."""
        self.model.filter_units(filter_type)

    def reset_levels(self) -> None:
        """See :meth:`ShankHandler.reset_levels` for details."""
        self.model.reset_levels()

    # --------------------------------------------------------------------------------------------
    # Fitting functions
    # --------------------------------------------------------------------------------------------
    def offset_hist_data(self, *args) -> None:
        """See :meth:`ShankHandler.offset_hist_data` for details."""
        self.model.offset_hist_data(*args)

    def scale_hist_data(self, extend_feature: float, lin_fit: bool) -> None:
        """
        Scale brain regions along the probe track based on reference lines.

        Parameters
        ----------
        extend_feature: float
            Amount to extend for linear fit
        lin_fit: bool
            Whether to use a linear fit or not
        """
        # The view returns (feature, track) but the model is called with (track, feature).
        line_feature, line_track = self.view.get_feature_and_track_coords()
        self.model.scale_hist_data(
            line_track, line_feature, extend_feature=extend_feature, lin_fit=lin_fit
        )

    def get_scaled_histology(self) -> None:
        """See :meth:`ShankHandler.get_scaled_histology` for details."""
        self.model.get_scaled_histology()

    # --------------------------------------------------------------------------------------------
    # Reference lines
    # --------------------------------------------------------------------------------------------
    @property
    def points(self) -> list[pg.ScatterPlotItem]:
        """Return the points stored in the view."""
        return self.view.points

    def match_feature_line(self, *args) -> tuple[int | None, list | np.ndarray | None]:
        """See :meth:`ShankView.match_feature_line` for details."""
        return self.view.match_feature_line(*args)

    def match_track_line(self, *args) -> int | None:
        """See :meth:`ShankView.match_track_line` for details."""
        return self.view.match_track_line(*args)

    def create_reference_line_and_point(self, *args, **kwargs) -> tuple:
        """See :meth:`ShankView.create_reference_line_and_point` for details."""
        return self.view.create_reference_line_and_point(*args, **kwargs)

    def remove_reference_line(self, *args) -> None:
        """See :meth:`ShankView.remove_reference_line` for details."""
        self.view.remove_reference_line(*args)

    def delete_reference_line_and_point(self, *args) -> None:
        """See :meth:`ShankView.delete_reference_line_and_point` for details."""
        self.view.delete_reference_line_and_point(*args)

    def update_feature_reference_line_and_point(self, *args) -> None:
        """See :meth:`ShankView.update_feature_reference_line_and_point` for details."""
        self.view.update_feature_reference_line_and_point(*args)

    def update_track_reference_line_and_point(self, *args) -> None:
        """See :meth:`ShankView.update_track_reference_line_and_point` for details."""
        self.view.update_track_reference_line_and_point(*args)

    def align_reference_lines_and_points(self) -> None:
        """See :meth:`ShankView.align_reference_lines_and_points` for details."""
        self.view.align_reference_lines_and_points()

    def remove_reference_lines_from_display(self) -> None:
        """See :meth:`ShankView.remove_reference_lines_from_display` for details."""
        self.view.remove_reference_lines_from_display()

    def add_reference_lines_to_display(self) -> None:
        """See :meth:`ShankView.add_reference_lines_to_display` for details."""
        self.view.add_reference_lines_to_display()