from types import MethodType
from typing import TYPE_CHECKING, Any
import numpy as np
from ibl_alignment_gui.loaders.plot_loader import ScatterData, skip_missing
from ibl_alignment_gui.utils.utils import shank_loop
from iblutil.util import Bunch
if TYPE_CHECKING:
from ibl_alignment_gui.app.app_controller import AlignmentGUIController
from ibl_alignment_gui.app.shank_controller import ShankController
PLUGIN_NAME = 'Additional Plots'
[docs]
def setup(controller: 'AlignmentGUIController') -> None:
"""
Example to show how to add additional plots to the GUI.
Parameters
----------
controller: AlignmentGUIController
The main application controller.
"""
controller.plugins[PLUGIN_NAME] = Bunch()
controller.plugins[PLUGIN_NAME]['activated'] = True
# Attach callbacks to methods in the controller
controller.plugins[PLUGIN_NAME]['load_data'] = add_plots
[docs]
@shank_loop
def add_plots(_, items: 'ShankController', **kwargs) -> None:
"""
Add additional plots to the plot loader.
Parameters
----------
_
items: ShankController
A ShankController instance.
-------
"""
# Add the additional data that may be required for the plots
items.model.raw_data['clusters']['predicted_region'] = np.random.randint(
0, 500, size=items.model.raw_data['clusters']['peakToTrough'].shape
)
items.model.loaders['plots'].scatter_amp_depth_prediction = MethodType(
scatter_amp_depth_prediction, items.model.loaders['plots']
)
[docs]
@skip_missing(['spikes'])
def scatter_amp_depth_prediction(self) -> dict[str, Any]:
"""
Generate data for a scatter plot of cluster depth vs. cluster amplitude.
Clusters are coloured by their predicted region.
Returns
-------
Dict
A dict containing a ScatterData object with key 'Cluster Amp vs Depth vs Duration'.
"""
levels = np.array([0, 500])
scatter = ScatterData(
x=self.avg_amp[self.cluster_idx],
y=self.avg_depth[self.cluster_idx],
levels=levels,
default_levels=np.copy(levels),
colours=self.data['clusters']['predicted_region'][self.cluster_idx],
pen='k',
size=np.array(8),
symbol=np.array('o'),
xrange=np.array(
[
0.9 * np.nanmin(self.avg_amp[self.cluster_idx]),
1.1 * np.nanmax(self.avg_amp[self.cluster_idx]),
]
),
xaxis='Amplitude (uV)',
title='Prediction',
cmap='Purples',
cluster=True,
)
return {'Cluster Amp vs Depth vs Prediction': scatter}