Source code for ibl_alignment_gui.plugins.channel_prediction

from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from qtpy import QtWidgets

from ibl_alignment_gui.utils.utils import shank_loop
from iblutil.numerical import ismember
from iblutil.util import Bunch

if TYPE_CHECKING:
    from ibl_alignment_gui.app.app_controller import AlignmentGUIController
    from ibl_alignment_gui.app.shank_controller import ShankController

PLUGIN_NAME = 'Channel Prediction'


[docs] def setup(controller: 'AlignmentGUIController') -> None: """ Set up the Channel Prediction plugin. Adds menu options to select different prediction models. Parameters ---------- controller: AlignmentGUIController The main application controller. """ controller.plugins[PLUGIN_NAME] = Bunch() channel_prediction = ChannelPrediction(controller) controller.plugins[PLUGIN_NAME]['loader'] = channel_prediction # Add menu bar for selecting what to show controller.plugins[PLUGIN_NAME]['activated'] = False # Add a submenu to the main menu plugin_menu = QtWidgets.QMenu(PLUGIN_NAME, controller.view) controller.plugin_options.addMenu(plugin_menu) action_group = QtWidgets.QActionGroup(plugin_menu) action_group.setExclusive(True) # Add the different prediction model options predictions_models = { 'Original': None, 'Cosmos': compute_cosmos_predictions, 'Cumulative': compute_cumulative_distribution, } for model, model_func in predictions_models.items(): action = QtWidgets.QAction(model, controller.view) action.setCheckable(True) action.setChecked(model == 'Original') action.triggered.connect( lambda _, m=model, func=model_func: channel_prediction.plot_regions(_, m, func) ) action_group.addAction(action) plugin_menu.addAction(action)
[docs] class ChannelPrediction: """ Class to handle channel prediction plotting in the alignment GUI. Parameters ---------- controller: AlignmentGUIController The main application controller. """ def __init__(self, controller: 'AlignmentGUIController') -> None: self.controller = controller
[docs] def plot_regions(self, _, model: str, func: Callable) -> None: """ Plot the brain regions based on the selected model. Parameters ---------- model: str The name of the model to use for predictions. func: Callable The function to compute the predictions. """ # Plot the regions based on the action if model == 'Original': plot_original_regions(self.controller) else: plot_predicted_regions(self.controller, model, func)
[docs] @shank_loop def plot_original_regions(_, items: 'ShankController', **kwargs) -> None: """Plot the original histology regions on the reference histology plot.""" items.view.plot_histology(items.view.fig_hist_ref, items.model.hist_data_ref, ax='right')
[docs] @shank_loop def plot_predicted_regions( controller: 'AlignmentGUIController', items: 'ShankController', model: str, func: Callable, **kwargs, ) -> None: """ Plot the model predictions on the reference histology plot. Parameters ---------- model: str The name of the model. func: Callable The function to compute the predictions. """ if not getattr(items.model, 'predictions', None): items.model.predictions = Bunch() results = items.model.predictions.get(model, None) if results is None: items.model.predictions[model] = func(controller, items) if 'probability' in items.model.predictions[model]: items.view.plot_histology_cumulative( items.view.fig_hist_ref, items.model.predictions[model] ) else: items.view.plot_histology( items.view.fig_hist_ref, items.model.predictions[model], ax='right' )
[docs] def compute_cosmos_predictions( controller: 'AlignmentGUIController', items: 'ShankController' ) -> Bunch[str, np.ndarray]: """ Example prediction model that returns cosmos brain regions. Returns ------- Bunch A bunch containing the predicted brain regions. """ # xyz coordinates sampled at 10 um along histology track from bottom or brain to top xyz_samples = items.model.align_handle.xyz_samples # depths of these coordinates along the track depth_samples = items.model.align_handle.ephysalign.sampling_trk region_ids = controller.model.brain_atlas.get_labels(xyz_samples, mapping='Cosmos') regions = controller.model.brain_atlas.regions.get(region_ids) return get_region_boundaries(regions, depth_samples)
[docs] def compute_random_predictions( controller: 'AlignmentGUIController', items: 'ShankController' ) -> Bunch[str, np.ndarray]: """ Example prediction model that uses the spikes data to assign random brain regions. Returns ------- Bunch A bunch containing the predicted brain regions. """ # xyz coordinates sampled at 10 um along histology track from bottom or brain to top xyz_samples = items.model.align_handle.xyz_samples # depths of these coordinates along the track depth_samples = items.model.align_handle.ephysalign.sampling_trk # Spikes and other data can be accessed in this way if needed spikes = items.model.raw_data['spikes'] def random_chunked_array(n, n_vals=20, seed=None): rng = np.random.default_rng(seed) cuts = np.sort(rng.choice(np.arange(1, n), size=n_vals - 1, replace=False)) chunks = np.diff(np.r_[0, cuts, n]) vals = rng.choice(np.arange(1001), size=n_vals, replace=False) chosen = rng.choice(vals, size=len(chunks), replace=True) return np.repeat(chosen, chunks) random = random_chunked_array(len(depth_samples), n_vals=20, seed=42) region_ids = controller.model.brain_atlas.regions.id[random] regions = controller.model.brain_atlas.regions.get(region_ids) return get_region_boundaries(regions, depth_samples)
[docs] def compute_cumulative_distribution( controller: 'AlignmentGUIController', items: 'ShankController' ) -> Bunch[str, np.ndarray]: """ Example prediction model that plots cumulative prediction of brain regions. Returns ------- Bunch A bunch containing the predicted brain regions. """ if Path('/Users/admin/Downloads/ea_features.pqt').exists(): df = pd.read_parquet('/Users/admin/Downloads/ea_features.pqt') int_cols = [c for c in df.columns if c.isdigit()] probas = df[int_cols].to_numpy() cprobas = probas.cumsum(axis=1) depth_samples = df['axial_um'].values region_ids = np.array([int(c) for c in int_cols]) _, region_idxs = ismember(region_ids, controller.model.brain_atlas.regions.id) colours = [controller.model.brain_atlas.regions.rgb[idx] for idx in region_idxs] else: # xyz coordinates sampled at 10 um along histology track from bottom or brain to top xyz_samples = items.model.align_handle.xyz_samples # depths of these coordinates along the track depth_samples = items.model.align_handle.ephysalign.sampling_trk region_ids = controller.model.brain_atlas.get_labels(xyz_samples, mapping='Beryl') region_ids = np.unique(region_ids) _, region_idxs = ismember(region_ids, controller.model.brain_atlas.regions.id) colours = [controller.model.brain_atlas.regions.rgb[idx] for idx in region_idxs] ndepths = depth_samples.size # number of depths nregions = region_ids.size # number of regions # Generate random probabilities probas = np.random.rand(ndepths, nregions) # Normalize each row to sum to 1 probas /= probas.sum(axis=1, keepdims=True) # Cumulative sum across regions (for stacking) cprobas = probas.cumsum(axis=1) data = Bunch(depths=depth_samples, regions=region_ids, colours=colours, probability=cprobas) return data
[docs] def get_region_boundaries(regions: dict, depths: np.ndarray) -> Bunch[str, np.ndarray]: """ Get the boundaries of brain regions along the histology track. Parameters ---------- regions: dict The brain regions along the histology track. depths: The depths along the histology track. Returns ------- Bunch A bunch containing the region boundaries, labels, and colours. """ boundaries = np.where(np.diff(regions.id))[0] n_regions = len(boundaries) + 1 region = np.empty((n_regions, 2)) region_label = np.empty((n_regions, 2), dtype=object) region_colour = np.empty((n_regions, 3), dtype=int) for i in range(n_regions): # Compute start and end indices for this region start = 0 if i == 0 else boundaries[i - 1] + 1 end = boundaries[i] if i < len(boundaries) else regions.id.size - 1 region[i, :] = depths[[start, end]] * 1e6 region_label[i, :] = (np.mean(depths[[start, end]]) * 1e6, regions.acronym[end]) region_colour[i, :] = regions.rgb[end] data = Bunch(region=region, axis_label=region_label, colour=region_colour) return data