import re
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import timedelta
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from ibl_alignment_gui.handlers.shank_handler import ShankHandler
from ibl_alignment_gui.loaders.alignment_loader import (
AlignmentLoaderLocal,
AlignmentLoaderOne,
)
from ibl_alignment_gui.loaders.alignment_uploader import (
AlignmentUploaderLocal,
AlignmentUploaderOne,
)
from ibl_alignment_gui.loaders.data_loader import (
DataLoaderLocal,
DataLoaderOne,
FeatureLoaderOne,
SpikeGLXLoaderLocal,
SpikeGLXLoaderOne,
)
from ibl_alignment_gui.loaders.geometry_loader import (
GeometryLoaderLocal,
GeometryLoaderOne,
)
from ibl_alignment_gui.loaders.histology_loader import (
NrrdSliceLoader,
download_histology_data,
)
from ibl_alignment_gui.loaders.plot_loader import PlotLoader
from ibl_alignment_gui.utils.parse_yaml import DatasetPaths, load_alignment_yaml
from iblatlas.atlas import AllenAtlas
from iblutil.util import Bunch
from one import params
from one.api import ONE
class ProbeHandler(ABC):
    """
    Abstract base class for handling alignment and data loading for a probe.

    This class provides access to loader methods that handle different aspects of the
    alignment process. Where applicable the probe is split into shanks and each shank is
    handled separately.

    It can also handle multiple configurations for each shank, for example if two different
    channel maps are used to record data.

    Parameters
    ----------
    brain_atlas: AllenAtlas or None
        An AllenAtlas instance. If None, a new AllenAtlas is created.
    """

    def __init__(self, brain_atlas: AllenAtlas | None = None):
        self.brain_atlas: AllenAtlas = brain_atlas or AllenAtlas()
        # Shank label -> Bunch of {config name: ShankHandler}; filled in by the subclasses'
        # initialise_shanks. NOTE: being a defaultdict, looking up an unknown label inserts
        # an empty Bunch for it.
        self.shanks: dict[str, Bunch] = defaultdict(Bunch)

        # Configuration state
        self.default_config: str = 'default'
        self.non_default_config: str | None = None
        self.configs: list[str] = ['default']
        self.possible_configs: list[str] = ['default']
        self.selected_config: str = 'default'

        # Label of the selected shank and its index in the shank list
        self.selected_shank: str | None = None
        self.selected_idx: int | None = None

    # -------------------------------------------------------------------------
    # Shank access methods
    # -------------------------------------------------------------------------

    def get_current_shank(self, shank: str, config: str) -> ShankHandler:
        """
        Return the handler for a given shank and configuration.

        Parameters
        ----------
        shank : str
            The shank label.
        config : str
            The configuration name.

        Returns
        -------
        ShankHandler
            The handler stored for ``shank`` and ``config``.
        """
        return self.shanks[shank][config]

    def get_selected_shank(self) -> Bunch:
        """
        Return the configurations of the currently selected shank.

        Returns
        -------
        Bunch
            Mapping of configuration name to ShankHandler for ``self.selected_shank``.
        """
        return self.shanks[self.selected_shank]

    def get_config(self, idx: int) -> None:
        """
        Select a configuration by index.

        Parameters
        ----------
        idx : int
            Index in the list of possible configurations.
        """
        self.selected_config = self.possible_configs[idx]

    def _align_loader(self, config: str) -> Any:
        """Return the alignment loader (``loaders['align']``) of the selected shank for ``config``."""
        return self.get_selected_shank()[config].loaders['align']

    # -------------------------------------------------------------------------
    # Alignment methods - methods in loaders['align']
    # -------------------------------------------------------------------------

    def load_previous_alignments(self) -> dict:
        """
        Load previous alignments for the selected shank.

        The alignments are loaded for the default configuration and, if a non-default
        configuration exists, shared with it so both configurations see the same set.

        Returns
        -------
        dict
            Previous alignments of the default configuration for the selected shank.
        """
        default_align = self._align_loader(self.default_config)
        default_align.load_previous_alignments()

        if self.non_default_config is not None:
            # Reuse the alignments loaded for the default configuration
            other_align = self._align_loader(self.non_default_config)
            other_align.alignments = default_align.alignments
            other_align.get_previous_alignments()

        return default_align.get_previous_alignments()

    def get_previous_alignments(self) -> dict:
        """
        Get previous alignments for the selected shank.

        Returns
        -------
        dict
            Previous alignments of the default configuration for the selected shank.
        """
        return self._align_loader(self.default_config).get_previous_alignments()

    def get_starting_alignment(self, idx: int) -> None:
        """
        Set the index of the starting alignment for the selected shank and each configuration.

        Parameters
        ----------
        idx: int
            The index of the previous alignment to load.
        """
        for config in self.configs:
            self._align_loader(config).get_starting_alignment(idx)

    def set_init_alignment(self) -> None:
        """Initialise the alignment for the selected shank and each configuration."""
        for config in self.configs:
            self.get_selected_shank()[config].set_init_alignment()

    # -------------------------------------------------------------------------
    # Alignment handler methods - methods in align_handle
    # -------------------------------------------------------------------------

    def next_idx(self) -> int:
        """
        Return the index of the next available alignment for the selected shank.

        ``next_idx`` is also called on the non-default configuration (its result is
        discarded), presumably to keep both configurations' buffers in step.

        Returns
        -------
        int
            The index of the next available alignment stored in the circular buffer of the
            default configuration.
        """
        if self.non_default_config is not None:
            self.get_selected_shank()[self.non_default_config].align_handle.next_idx()
        return self.get_selected_shank()[self.default_config].align_handle.next_idx()

    def prev_idx(self) -> int:
        """
        Return the index of the previously available alignment for the selected shank.

        ``prev_idx`` is also called on the non-default configuration (its result is
        discarded), presumably to keep both configurations' buffers in step.

        Returns
        -------
        int
            The index of the previous available alignment stored in the circular buffer of
            the default configuration.
        """
        if self.non_default_config is not None:
            self.get_selected_shank()[self.non_default_config].align_handle.prev_idx()
        return self.get_selected_shank()[self.default_config].align_handle.prev_idx()

    @property
    def current_idx(self) -> int:
        """
        Return the current index of the alignment stored in the buffer for the selected shank.

        Returns
        -------
        int
            The index of the current alignment (default configuration).
        """
        return self.get_selected_shank()[self.default_config].align_handle.current_idx

    @property
    def total_idx(self) -> int:
        """
        Return the total index of the alignments stored in the buffer for the selected shank.

        Returns
        -------
        int
            The total number of alignments stored in the circular buffer (default
            configuration).
        """
        return self.get_selected_shank()[self.default_config].align_handle.total_idx

    # -------------------------------------------------------------------------
    # Plot access
    # -------------------------------------------------------------------------

    def get_plot(self, shank: str, plot: str, key: str, config: str | None = None) -> Any:
        """
        Access a specific plot for a specific shank and configuration.

        Parameters
        ----------
        shank: str
            The shank label to access.
        plot: str
            Name of the plot collection attribute on the shank's PlotLoader, e.g.
            'image_plots', 'scatter_plots', 'line_plots', 'probe_plots', 'feature_plots' or
            'slice_plots'.
        key: str
            The plot key to access.
        config: str or None
            The configuration to access. If None, uses the default configuration.

        Returns
        -------
        Any
            The requested plot, or None if ``key`` is not found.
        """
        config = config or self.default_config
        return getattr(self.shanks[shank][config].loaders['plots'], plot).get(key, None)

    def get_plot_keys(self, plot: str) -> list[str]:
        """
        Find the available keys across all shanks and configurations for a given plot type.

        Parameters
        ----------
        plot: str
            Name of the plot collection attribute on the PlotLoader, e.g. 'image_plots',
            'scatter_plots', 'line_plots', 'probe_plots', 'feature_plots' or 'slice_plots'.

        Returns
        -------
        list
            A sorted list of unique plot keys.
        """
        keys: set[str] = set()
        for shank_configs in self.shanks.values():
            for config in self.configs:
                keys.update(getattr(shank_configs[config].loaders['plots'], plot).keys())
        return sorted(keys)

    @property
    def image_keys(self) -> list[str]:
        """Sorted list of unique image plot keys across all shanks and configurations."""
        return self.get_plot_keys('image_plots')

    @property
    def scatter_keys(self) -> list[str]:
        """Sorted list of unique scatter plot keys across all shanks and configurations."""
        return self.get_plot_keys('scatter_plots')

    @property
    def line_keys(self) -> list[str]:
        """Sorted list of unique line plot keys across all shanks and configurations."""
        return self.get_plot_keys('line_plots')

    @property
    def probe_keys(self) -> list[str]:
        """Sorted list of unique probe plot keys across all shanks and configurations."""
        return self.get_plot_keys('probe_plots')

    @property
    def feature_keys(self) -> list[str]:
        """Sorted list of unique feature plot keys across all shanks and configurations."""
        return self.get_plot_keys('feature_plots')

    @property
    def slice_keys(self) -> list[str]:
        """Sorted list of unique slice plot keys across all shanks and configurations."""
        return self.get_plot_keys('slice_plots')

    # -------------------------------------------------------------------------
    # Data loading & upload
    # -------------------------------------------------------------------------

    def load_data(self) -> None:
        """
        Download and load data for all configs and shanks.

        The histology slice loader is created once and shared by every shank/config.
        """
        slice_loader = self.download_histology()
        for shank in self.shanks:
            for config in self.configs:
                self.shanks[shank][config].loaders['hist'] = slice_loader
                self.shanks[shank][config].load_data()

    def load_plots(self) -> None:
        """Load plots for all configs and shanks."""
        for shank in self.shanks:
            for config in self.configs:
                self.shanks[shank][config].load_plots()

    def upload_data(self) -> str:
        """
        Upload data for the selected shank for each configuration.

        Returns
        -------
        str
            Upload result from the default configuration (results of the other
            configurations are discarded).
        """
        results = Bunch()
        for config in self.configs:
            results[config] = self.get_selected_shank()[config].upload_data()
        return results[self.default_config]

    # -------------------------------------------------------------------------
    # Utility
    # -------------------------------------------------------------------------

    @staticmethod
    def normalize_shank_label(shank_label: str) -> str:
        """
        Strip the shank suffix from a label, e.g. 'probe00a' -> 'probe00'.

        Parameters
        ----------
        shank_label : str
            Input shank label.

        Returns
        -------
        str
            The leading 'probe' + digits part of the label, or the label unchanged if it does
            not start with 'probe' followed by digits.
        """
        match = re.match(r'(probe\d+)', shank_label)
        return match.group(1) if match else shank_label

    # -------------------------------------------------------------------------
    # Abstract methods
    # -------------------------------------------------------------------------

    @abstractmethod
    def set_info(self, *args):
        """Set probe information."""

    @abstractmethod
    def download_histology(self):
        """Load histology data."""

    @abstractmethod
    def get_shanks(self, *args):
        """Return shank information."""

    @abstractmethod
    def initialise_shanks(self):
        """Initialize shank data."""
class ProbeHandlerONE(ProbeHandler):
    """
    ONE implementation of ProbeHandler.

    For this ProbeHandler all ephys and alignment data is downloaded and accessed via
    ONE and Alyx. The data for all shanks on a probe will be loaded at once.

    Parameters
    ----------
    one : ONE or None
        An ONE instance used to query Alyx and upload results. If None, a new ONE is created.
    brain_atlas : AllenAtlas or None
        An AllenAtlas object. If None, a new AllenAtlas is created.
    spike_collection : str, optional
        Spike sorting algorithm to load (e.g. 'pykilosort', 'iblsorter').
    """

    def __init__(
        self,
        one: ONE | None = None,
        brain_atlas: AllenAtlas | None = None,
        spike_collection: str | None = None,
    ):
        self.one = one or ONE()
        self.spike_collection = spike_collection
        super().__init__(brain_atlas)

    def get_subjects(self) -> np.ndarray:
        """
        Find all subjects that have probe insertions with spikesorting data.

        The Alyx query result is cached by Alyx for one day (``expires``).

        Returns
        -------
        np.ndarray
            A sorted array of unique subject names.
        """
        self.sess_ins = self.one.alyx.rest(
            'insertions', 'list', dataset_type='spikes.times', expires=timedelta(days=1)
        )
        # One subject entry per insertion, in the same order as self.sess_ins
        self.subj_ins = [ins['session_info']['subject'] for ins in self.sess_ins]
        self.subjects = np.unique(self.subj_ins)
        return self.subjects

    def get_sessions(self, idx: int) -> np.ndarray:
        """
        Find all probes for a given subject.

        Note if multi-shank data it will return probe00 rather than probe00a, the individual
        shank is chosen using the shank dropdown.

        Parameters
        ----------
        idx : int
            The index of the chosen subject in ``self.subjects``.

        Returns
        -------
        np.ndarray
            Sorted unique '<date> <number> <probe>' labels for the chosen subject.
        """
        self.chosen_sess = self.subjects[idx]
        self.sess = [
            ins
            for ins, subject in zip(self.sess_ins, self.subj_ins)
            if subject == self.chosen_sess
        ]
        self.sessions = np.unique([self.get_session_probe_name(ins) for ins in self.sess])
        return self.sessions

    def get_shanks(self, idx: int) -> list:
        """
        Find all shanks for a given probe and initialise the loaders.

        Parameters
        ----------
        idx : int
            The index of the chosen probe in ``self.sessions``.

        Returns
        -------
        list
            The shank (insertion) names of the chosen probe, sorted alphabetically.
        """
        self.chosen_probe = self.sessions[idx]
        matches = [
            ins for ins in self.sess if self.get_session_probe_name(ins) == self.chosen_probe
        ]
        names = [ins['name'] for ins in matches]
        order = np.argsort(names)
        # Keep the insertions and their names in the same (sorted-by-name) order
        self.shank_labels = np.array(matches)[order]
        shanks = np.array(names)[order]
        self.initialise_shanks()
        return list(shanks)

    def get_session_probe_name(self, ins: dict) -> str:
        """
        Make a string combining the session information and probe name.

        Removes the shank identifiers from the probe names.

        Parameters
        ----------
        ins: dict
            A dict containing insertion data.

        Returns
        -------
        str:
            '<first 10 chars of start_time> <session number, 3 digits> <probe name>'.
        """
        info = ins['session_info']
        probe = self.normalize_shank_label(ins['name'])
        return f"{info['start_time'][:10]} {info['number']:03} {probe}"

    def set_info(self, idx: int) -> None:
        """
        Set the information about the selected shank.

        Parameters
        ----------
        idx: int
            The index of the selected shank in ``self.shank_labels``.
        """
        ins = self.shank_labels[idx]
        self.selected_shank = ins['name']
        self.selected_idx = idx
        self.subj = ins['session_info']['subject']
        self.lab = ins['session_info']['lab']
        self.pid = ins['id']

    def download_histology(self) -> NrrdSliceLoader:
        """Download and load in the histology slice data for the selected subject and lab."""
        _, hist_path = download_histology_data(self.subj, self.lab)
        return NrrdSliceLoader(hist_path, self.brain_atlas)

    def initialise_shanks(self) -> None:
        """Initialise each shank with the ONE loaders, under the default configuration."""
        self.shanks = defaultdict(Bunch)
        for ins in self.shank_labels:
            loaders = Bunch()
            loaders['data'] = DataLoaderOne(ins, self.one, spike_collection=self.spike_collection)
            loaders['geom'] = GeometryLoaderOne(
                ins, self.one, probe_collection=loaders['data'].probe_collection
            )
            loaders['align'] = AlignmentLoaderOne(ins, self.one)
            loaders['upload'] = AlignmentUploaderOne(ins, self.one, self.brain_atlas)
            loaders['ephys'] = SpikeGLXLoaderOne(ins, self.one)
            loaders['features'] = FeatureLoaderOne(ins, self.one)
            loaders['plots'] = PlotLoader()
            self.shanks[ins['name']][self.default_config] = ShankHandler(loaders, 0)

    def load_data(self) -> None:
        """Load data for all configs and shanks."""
        print(f'******** Loading session {self.chosen_sess} {self.chosen_probe} ********')
        super().load_data()
class ProbeHandlerCSV(ProbeHandler):
    """
    ProbeHandler where data from two channel maps has been recorded on the shanks.

    The data for the dense configuration is available via ONE whereas the data for the quarter
    configuration is only available on the local file system. Reads in a csv file that
    contains information about where to read the relevant data from.

    Columns read from the csv: session, probe, pid, subject, lab, is_quarter, local_path,
    spike_collection, ephys_collection, meta_collection, task_collection and
    raw_task_collection.

    Parameters
    ----------
    csv_file : str or Path
        Path to the csv file. ``local_path`` entries are resolved against its directory.
    one : ONE or None
        An ONE instance. If None, a new ONE is created.
    brain_atlas : AllenAtlas or None
        An AllenAtlas object. If None, a new AllenAtlas is created.
    """

    def __init__(
        self, csv_file: str | Path, one: ONE | None = None, brain_atlas: AllenAtlas | None = None
    ):
        super().__init__(brain_atlas)
        csv_file = Path(csv_file)
        if not csv_file.exists():
            raise FileNotFoundError(f'CSV file not found: {csv_file}')
        self.root_path = csv_file.parent
        # keep_default_na=False reads empty cells as '' rather than NaN
        self.df = pd.read_csv(csv_file, keep_default_na=False)
        # Session identifier with its last '/'-separated component removed
        self.df['session_strip'] = self.df['session'].str.rsplit('/', n=1).str[0]
        self.one = one or ONE()
        self.possible_configs = ['quarter', 'dense', 'both']
        self.configs = ['quarter', 'dense']
        self.default_config = 'dense'
        self.non_default_config = 'quarter'
        self.selected_config = 'quarter'

    def get_subjects(self) -> np.ndarray:
        """
        Find all sessions with spike sorting data.

        Returns
        -------
        np.ndarray
            The unique ``session_strip`` values (sessions, not subjects, despite the name).
        """
        self.subjects = self.df['session_strip'].unique()
        return self.subjects

    def get_sessions(self, idx: int) -> np.ndarray:
        """
        Find all probes for a given session.

        Note if multi-shank data it will return probe00 rather than probe00a, the individual
        shank is chosen using the shank dropdown.

        Parameters
        ----------
        idx : int
            The index of the chosen session in ``self.subjects``.

        Returns
        -------
        np.ndarray
            All probes with spikesorting data for the chosen session.
        """
        self.session_df = self.df.loc[self.df['session_strip'] == self.subjects[idx]]
        self.sessions = np.unique(
            [self.normalize_shank_label(pr) for pr in self.session_df['probe'].values]
        )
        return self.sessions

    def get_shanks(self, idx: int) -> np.ndarray:
        """
        Find all shanks for a given probe and initialise the loaders.

        Parameters
        ----------
        idx : int
            The index of the chosen probe in ``self.sessions``.

        Returns
        -------
        np.ndarray
            All shank labels for the chosen probe, sorted.
        """
        probe = self.sessions[idx]
        # Match on the normalized label: a substring match would let 'probe1' also select
        # rows of 'probe10a'
        is_probe = self.session_df['probe'].map(self.normalize_shank_label) == probe
        self.shank_df = self.session_df.loc[is_probe].sort_values('probe')
        self.initialise_shanks()
        self.shank_labels = self.shank_df['probe'].unique()
        return self.shank_labels

    def set_info(self, idx: int) -> None:
        """
        Set the information about the selected shank.

        Parameters
        ----------
        idx: int
            The index of the selected shank in ``self.shank_labels``.
        """
        self.selected_shank = self.shank_labels[idx]
        self.selected_idx = idx

    def download_histology(self) -> NrrdSliceLoader:
        """Download and load in the histology slice data for the subject and lab."""
        _, hist_path = download_histology_data(self.subj, self.lab)
        return NrrdSliceLoader(hist_path, self.brain_atlas)

    def initialise_shanks(self) -> None:
        """
        Initialise each shank and config with the selected loaders.

        Rows with a truthy ``is_quarter`` go into the 'quarter' configuration (local data);
        the others into the 'dense' configuration. Afterwards the alignments of the two
        configurations are synchronised and subject/lab are taken from the last row.
        """
        self.shanks = defaultdict(Bunch)
        user = params.get().ALYX_LOGIN
        for _, shank in self.shank_df.iterrows():
            local_path = self.root_path.joinpath(shank.local_path)
            data_paths = self._dataset_paths(local_path, shank)
            ins = self.get_insertion(shank)
            if shank.is_quarter:  # Quarter is offline
                loaders = self._quarter_loaders(ins, data_paths, user)
                self.shanks[shank.probe]['quarter'] = ShankHandler(loaders, 0)
            else:  # Dense is online
                loaders = self._dense_loaders(ins, data_paths, local_path, user)
                self.shanks[shank.probe]['dense'] = ShankHandler(loaders, 0)

        self._sync_alignments()
        if not self.shank_df.empty:
            last = self.shank_df.iloc[-1]
            self.subj = last['subject']
            self.lab = last['lab']

    @staticmethod
    def _dataset_paths(local_path: Path, shank: pd.Series) -> DatasetPaths:
        """Build the DatasetPaths of a csv row, relative to ``local_path``."""
        # Empty collections ('' in the csv) resolve to local_path itself
        return DatasetPaths(
            spike_sorting=local_path.joinpath(shank.spike_collection or ''),
            processed_ephys=local_path.joinpath(shank.ephys_collection or ''),
            raw_ephys=local_path.joinpath(shank.meta_collection or ''),
            task=local_path.joinpath(shank.task_collection or ''),
            raw_task=local_path.joinpath(shank.raw_task_collection or ''),
        )

    def _quarter_loaders(self, ins: dict, data_paths: DatasetPaths, user: str) -> Bunch:
        """Build the loaders for a quarter-configuration row (data read from local disk)."""
        # The insertion json may be null on Alyx; in that case no picks are passed on.
        # Picks are scaled by 1e-6 (presumably um -> m).
        ins_json = ins.get('json') or {}
        xyz_picks = ins_json.get('xyz_picks', None)
        xyz_picks = np.array(xyz_picks) / 1e6 if xyz_picks is not None else None

        loaders = Bunch()
        loaders['data'] = DataLoaderLocal(data_paths)
        loaders['geom'] = GeometryLoaderLocal(data_paths)
        loaders['align'] = AlignmentLoaderLocal(
            data_paths.spike_sorting, 0, 1, user=user, xyz_picks=xyz_picks
        )
        loaders['upload'] = AlignmentUploaderLocal(
            data_paths.spike_sorting, 0, loaders['geom'], self.brain_atlas, user=user
        )
        loaders['ephys'] = SpikeGLXLoaderLocal(data_paths.raw_ephys)
        loaders['plots'] = PlotLoader()
        loaders['features'] = FeatureLoaderOne(ins, self.one)
        return loaders

    def _dense_loaders(
        self, ins: dict, data_paths: DatasetPaths, local_path: Path, user: str
    ) -> Bunch:
        """Build the loaders for a dense-configuration row (alignments via Alyx)."""
        loaders = Bunch()
        # An empty spike_collection resolves to local_path itself: no local spike sorting,
        # so download it via ONE. Otherwise load from disk.
        if data_paths.spike_sorting == local_path:
            loaders['data'] = DataLoaderOne(ins, self.one)
            loaders['geom'] = GeometryLoaderOne(
                ins, self.one, probe_collection=loaders['data'].probe_collection
            )
        else:
            loaders['data'] = DataLoaderLocal(data_paths)
            loaders['geom'] = GeometryLoaderLocal(data_paths)
        loaders['align'] = AlignmentLoaderOne(ins, self.one, user=user)
        loaders['upload'] = AlignmentUploaderOne(ins, self.one, self.brain_atlas)
        loaders['ephys'] = SpikeGLXLoaderOne(ins, self.one)
        loaders['features'] = FeatureLoaderOne(ins, self.one)
        loaders['plots'] = PlotLoader()
        return loaders

    def _sync_alignments(self) -> None:
        """
        Synchronize alignments between dense and quarter loaders.

        If the dense (Alyx) loader has alignments beyond 'original' they overwrite the
        quarter ones; otherwise, if the quarter (local) loader has some, they are added to
        the dense loader and the quarter loader is then synced to the merged set.
        """
        for shank_group in self.shanks.values():
            dense_align = shank_group['dense'].loaders['align']
            quarter_align = shank_group['quarter'].loaders['align']
            if dense_align.alignment_keys != ['original']:
                # Alyx alignment exists: overwrite local
                quarter_align.alignments = dense_align.alignments
                quarter_align.get_previous_alignments()
                quarter_align.get_starting_alignment(0)
            elif quarter_align.alignment_keys != ['original']:
                # Local alignment exists: add to online
                dense_align.add_extra_alignments(quarter_align.alignments)
                dense_align.get_previous_alignments()
                dense_align.get_starting_alignment(0)
                # Ensure consistency by syncing quarter with updated dense
                quarter_align.alignments = dense_align.alignments
                quarter_align.get_previous_alignments()
                quarter_align.get_starting_alignment(0)

    def get_insertion(self, shank: pd.Series) -> dict:
        """
        Get the Alyx probe insertion for the shank (query cached for one day).

        Raises
        ------
        IndexError
            If Alyx returns no insertion for ``shank.pid``.
        """
        ins = self.one.alyx.rest('insertions', 'list', id=shank.pid, expires=timedelta(days=1))
        if len(ins) == 0:
            raise IndexError(f'No Alyx insertion found with id {shank.pid}')
        return ins[0]
class ProbeHandlerLocal(ProbeHandler):
    """
    Local file system implementation of ProbeHandler.

    For this ProbeHandler, all ephys and alignment data must be stored in a single folder
    on disk.

    Parameters
    ----------
    brain_atlas : AllenAtlas or None
        An AllenAtlas object. If None, a new AllenAtlas is created.
    """

    def __init__(self, brain_atlas: AllenAtlas | None = None):
        super().__init__(brain_atlas)

    def get_shanks(self, folder_path: str | Path) -> list[str]:
        """
        Find the number of shanks on the probe and initialise the loaders for each shank.

        Loads the geometry (channels or ap meta data) from the folder.

        Parameters
        ----------
        folder_path : str or Path
            A path to the folder on the local disk that contains the data.

        Returns
        -------
        list of str
            Shank labels of the form '<shank number>/<number of shanks>', e.g. ['1/1'].
        """
        folder_path = Path(folder_path)
        # Every dataset lives in the same folder
        self.data_paths = DatasetPaths(
            spike_sorting=folder_path,
            processed_ephys=folder_path,
            raw_ephys=folder_path,
            histology=folder_path,
            picks=folder_path,
            output=folder_path,
        )
        # Load in the geometry and find the number of shanks
        self.geom = GeometryLoaderLocal(self.data_paths)
        self.geom.get_geometry()
        self.n_shanks = self.geom.channels.n_shanks
        self.shank_labels = [f'{ish + 1}/{self.n_shanks}' for ish in range(self.n_shanks)]
        self.initialise_shanks()
        return self.shank_labels

    def set_info(self, idx: int) -> None:
        """
        Set the information about the selected shank.

        Parameters
        ----------
        idx: int
            The index of the selected shank in ``self.shank_labels``.
        """
        self.selected_shank = f'shank_{self.shank_labels[idx]}'
        self.selected_idx = idx

    def download_histology(self) -> NrrdSliceLoader:
        """Load in the histology slice data from the data folder."""
        return NrrdSliceLoader(self.data_paths.histology, self.brain_atlas)

    def initialise_shanks(self) -> None:
        """Initialise each shank with the loaders, under the default configuration."""
        self.shanks = defaultdict(Bunch)
        for ish, label in enumerate(self.shank_labels):
            loaders = Bunch()
            # The geometry is shared by all shanks
            loaders['geom'] = self.geom
            loaders['data'] = DataLoaderLocal(self.data_paths)
            loaders['align'] = AlignmentLoaderLocal(self.data_paths.picks, ish, self.n_shanks)
            # NOTE(review): the other handlers in this module pass the geometry loader as the
            # third argument of AlignmentUploaderLocal; here the shank count is passed.
            # Confirm against the AlignmentUploaderLocal signature.
            loaders['upload'] = AlignmentUploaderLocal(
                self.data_paths.output, ish, self.n_shanks, self.brain_atlas
            )
            loaders['ephys'] = SpikeGLXLoaderLocal(self.data_paths.raw_ephys)
            loaders['plots'] = PlotLoader()
            self.shanks[f'shank_{label}'][self.default_config] = ShankHandler(loaders, ish)
class ProbeHandlerLocalYaml(ProbeHandler):
    """
    Local file system implementation of ProbeHandler that uses a yaml file.

    The yaml file contains information about where to read the relevant data from. When it
    lists more than one configuration, the first is the default, the second the non-default
    configuration, and a 'both' option is added to the possible configurations.

    Parameters
    ----------
    yaml_file : str or Path
        Path to the yaml file, parsed with ``load_alignment_yaml``.
    brain_atlas : AllenAtlas or None
        An AllenAtlas object. If None, a new AllenAtlas is created.
    """

    def __init__(self, yaml_file: str | Path, brain_atlas: AllenAtlas | None = None):
        super().__init__(brain_atlas)
        # data_paths is indexed as data_paths[config][probe] -> DatasetPaths
        configs, probes, self.data_paths = load_alignment_yaml(yaml_file)
        if len(configs) > 1:
            self.configs = configs
            self.default_config = self.configs[0]
            self.non_default_config = self.configs[1]
            self.possible_configs = self.configs + ['both']
            self.selected_config = self.configs[0]
        # NOTE(review): with a single configuration the base-class name 'default' is kept;
        # this assumes load_alignment_yaml keys data_paths by 'default' in that case.
        self.probes = probes

    def _data_key(self, shank: str) -> str:
        """Return the key into ``self.data_paths[config]`` holding the data for ``shank``."""
        # A single probe in the yaml holds the data for all of its shanks; otherwise each
        # shank label is itself a probe entry of the yaml
        return self.probes[0] if len(self.probes) == 1 else shank

    def get_shanks(self, _) -> list[str]:
        """
        Initialise the shanks based on the yaml file.

        If only one probe label is given we load in the geometry to see if it is a
        multi-shank recording. Otherwise, we assume the yaml has specified all shanks
        and these are treated individually.

        Returns
        -------
        list of str
            The shank labels: the yaml probe names, or 'shank_1', 'shank_2', ... for a
            single multi-shank probe.
        """
        if len(self.probes) == 1:
            # Load in the geometry and find the number of shanks
            data_path = self.data_paths[self.default_config][self.probes[0]]
            self.geom = GeometryLoaderLocal(data_path)
            self.geom.get_geometry()
            self.n_shanks = self.geom.channels.n_shanks
            if self.n_shanks == 1:
                self.shank_labels = self.probes
            else:
                self.shank_labels = [f'shank_{ish + 1}' for ish in range(self.n_shanks)]
        else:
            self.shank_labels = self.probes
            self.n_shanks = 1
        self.initialise_shanks()
        return self.shank_labels

    def set_info(self, idx: int) -> None:
        """
        Set the information about the selected shank.

        Parameters
        ----------
        idx: int
            The index of the selected shank in ``self.shank_labels``.
        """
        self.selected_shank = self.shank_labels[idx]
        self.selected_idx = idx

    def download_histology(self) -> NrrdSliceLoader:
        """
        Load in the histology slice data.

        Uses the histology path of the first yaml probe for the selected configuration, or
        for the default configuration when the selection (e.g. 'both') has no data paths.
        """
        config = (
            self.selected_config
            if self.selected_config in self.data_paths
            else self.default_config
        )
        histology_path = self.data_paths[config][self.probes[0]].histology
        return NrrdSliceLoader(histology_path, self.brain_atlas)

    def initialise_shanks(self) -> None:
        """Initialise each shank and config with the selected loaders."""
        self.shanks = defaultdict(Bunch)
        for ish, shank in enumerate(self.shank_labels):
            # Shank index within the probe; 0 when each yaml entry is its own shank
            ishank = ish if self.n_shanks > 1 else 0
            for config in self.configs:
                data_paths = self.data_paths[config][self._data_key(shank)]
                loaders = Bunch()
                loaders['data'] = DataLoaderLocal(data_paths)
                loaders['geom'] = GeometryLoaderLocal(data_paths)
                loaders['align'] = AlignmentLoaderLocal(data_paths.picks, ishank, self.n_shanks)
                loaders['upload'] = AlignmentUploaderLocal(
                    data_paths.output, ishank, loaders['geom'], self.brain_atlas
                )
                loaders['ephys'] = SpikeGLXLoaderLocal(data_paths.raw_ephys)
                loaders['plots'] = PlotLoader()
                self.shanks[shank][config] = ShankHandler(loaders, ishank)