"""Techniques to project the brain volume onto 2D images for visualisation purposes."""
from functools import lru_cache
import logging
import json

import nrrd
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt

from iblutil.util import Bunch
from import md5
import as aws

from iblatlas.atlas import AllenAtlas

_logger = logging.getLogger(__name__)

class FlatMap(AllenAtlas):
    """
    A 2D flattened view of the Allen atlas.

    After construction, ``self.flatmap`` holds an integer array of flat indices into the atlas
    volumes (``self.label``, ``self.image``): one per pixel of the 2D map, optionally with a
    trailing depth axis.
    """

    def __init__(self, flatmap='dorsal_cortex', res_um=25):
        """
        Load the atlas and build (or fetch) the requested flatmap.

        :param flatmap: one of 'dorsal_cortex' (read from an NRRD file in the atlas cache,
            downloaded if missing), 'circles' or 'pyramid' (both computed by ``circles``)
        :param res_um: atlas resolution in um; 'circles' and 'pyramid' only support 25
        :raises ValueError: if ``flatmap`` is not one of the names above
        :raises NotImplementedError: if 'circles' or 'pyramid' is requested with res_um != 25
        """
        # validate before super().__init__, which loads the whole atlas
        if flatmap not in ('dorsal_cortex', 'circles', 'pyramid'):
            raise ValueError(
                f'Flatmap "{flatmap}" not supported, available flatmaps are '
                f'"dorsal_cortex", "circles" and "pyramid"')
        if flatmap in ('circles', 'pyramid') and res_um != 25:
            raise NotImplementedError(f'{flatmap} flatmap not implemented for resolution other than 25um')
        super().__init__(res_um=res_um)
        self.name = flatmap
        if flatmap == 'dorsal_cortex':
            self._get_flatmap_from_file()
        else:
            # 'circles' and 'pyramid' only differ by the layout circles() produces
            display = 'flat' if flatmap == 'circles' else 'pyramid'
            self.flatmap, self.ml_scale, self.ap_scale = circles(N=5, atlas=self, display=display)

    def _get_flatmap_from_file(self):
        """
        Read the flatmap named ``self.name`` from the atlas cache, downloading it first if absent.

        The cached file is ``<cache dir>/flatmaps/{name}_{res_um}.nrrd``; only the NRRD data is
        kept, the header is discarded.
        """
        file_flatmap = self._get_cache_dir().joinpath('flatmaps', f'{self.name}_{self.res_um}.nrrd')
        if not file_flatmap.exists():
            file_flatmap.parent.mkdir(exist_ok=True, parents=True)
            aws.s3_download_file(f'atlas/{file_flatmap.name}', file_flatmap)
        self.flatmap, _ = nrrd.read(file_flatmap)

    def plot_flatmap(self, depth=0, volume='annotation', mapping='Allen', region_values=None, ax=None,
                     **kwargs):
        """
        Displays the 2D image corresponding to the flatmap.

        If there are several depths, by default it will display the first one.

        Parameters
        ----------
        depth : int
            Index of the depth to display in the flatmap volume (the last dimension). Ignored when
            the flatmap is 2D.
        volume : {'image', 'annotation', 'boundary', 'value'}
            - 'image' - Allen image volume.
            - 'annotation' - Allen annotation volume.
            - 'boundary' - outline of boundaries between all regions.
            - 'value' - custom values, one per region, passed in as the `region_values` argument.
        mapping : str, default='Allen'
            The brain region mapping to use.
        region_values : numpy.array
            Values indexed by (mapped) region index: each pixel shows ``region_values[region]``.
            Required when `volume` is 'value'.
        ax : matplotlib.pyplot.Axes, optional
            A set of axes to plot to; defaults to the current axes.
        **kwargs
            See matplotlib.pyplot.imshow.

        Returns
        -------
        matplotlib.pyplot.Axes
            The plotted image axes.

        Raises
        ------
        ValueError
            If `volume` is not supported, or is 'value' while `region_values` is None.
        """
        flat = self.flatmap[:, :, depth] if self.flatmap.ndim == 3 else self.flatmap
        inds = np.int32(flat)
        # region index of every flatmap pixel, after applying the requested mapping
        regions = self._get_mapping(mapping=mapping)[self.label.flat[inds]]
        if volume == 'annotation':
            im = self._label2rgb(regions)
        elif volume == 'value':
            if region_values is None:
                raise ValueError('region_values must be provided when volume is "value"')
            im = region_values[regions]
        elif volume == 'boundary':
            im = self.compute_boundaries(regions)
        elif volume == 'image':
            im = self.image.flat[inds]
        else:
            raise ValueError(f'Volume type "{volume}" not supported')
        if ax is None:
            ax = plt.gca()
        return self._plot_slice(im, self.extent_flmap(), ax=ax, volume=volume, **kwargs)

    def extent_flmap(self):
        """
        Returns the boundary coordinates of the flat map.

        Returns
        -------
        numpy.array
            The bounding coordinates of the flat map image, specified as (left, right, bottom, top),
            i.e. (0, number of columns, 0, number of rows).
        """
        n_rows, n_cols = self.flatmap.shape[:2]
        return np.r_[0, n_cols, 0, n_rows]
def _circle_column_inds(atlas, xs, ys):
    """
    Look up the atlas volume index of every DV voxel beneath a set of (ML, AP) points.

    :param atlas: brain atlas providing ``bc.nz`` and ``_lookup_inds``
    :param xs: ML positions of the points in voxel index units (rounded to the nearest voxel)
    :param ys: AP positions of the points in voxel index units (rounded to the nearest voxel)
    :return: array of shape (atlas.bc.nz, xs.size); row j holds the indices at DV index j
    """
    nz = atlas.bc.nz
    iml, _ = np.meshgrid(np.round(xs).astype(np.int32), np.arange(nz))
    iap, idv = np.meshgrid(np.round(ys).astype(np.int32), np.arange(nz))
    i3d = atlas._lookup_inds(np.c_[iml.flat, iap.flat, idv.flat])
    return np.reshape(i3d, [nz, xs.size])


@lru_cache(maxsize=1, typed=False)
def circles(N=5, atlas=None, display='flat'):
    """
    Build a flatmap by sampling the atlas along N concentric rings.

    The rings are centred on, and scaled to the extent of, the area where ``atlas.top`` is defined
    (not NaN). They are resampled at unit spacing along their cumulative length. At each sample,
    every DV voxel of the column below it is looked up.

    :param N: number of circles (at least 1)
    :param atlas: brain atlas at 25 um; an AllenAtlas is instantiated when None. It must provide
        compute_surface(), top, bc (nz, i2x, i2y) and _lookup_inds().
    :param display: "flat" lays the rings end to end in a (nz, n_samples) image; "pyramid" stacks
        ring i in rows [i * nz, (i + 1) * nz), horizontally centred
    :return: 2D map of volume indices, ml_coordinate, ap_coordinate of the ring points
    :raises ValueError: if display is neither "flat" nor "pyramid", or if N < 1
    """
    # validate before the (expensive) default atlas is instantiated
    if display not in ('flat', 'pyramid'):
        raise ValueError(f'display must be "flat" or "pyramid", got {display!r}')
    if N < 1:
        raise ValueError(f'N must be at least 1, got {N}')
    atlas = atlas if atlas is not None else AllenAtlas()

    # rings in the complex plane: ring k has radius .4 * (k + 1) / N and a number of points
    # proportional to that radius
    nlast = 2000  # 25 um for 5mm diameter
    rings, levels = [], []
    for k in range(N):
        n = int((k + 1) * nlast / N)
        r = .4 * (k + 1) / N
        theta = np.linspace(0, 2 * np.pi, n) + np.pi / 2
        rings.append(r * np.exp(1j * theta))
        levels.append(np.full(n, k, dtype=float))
    sz = np.concatenate(rings)
    level = np.concatenate(levels)

    # centre and scale the rings on the area where the top surface is defined (voxel index units)
    atlas.compute_surface()
    iy, ix = np.where(~np.isnan(atlas.top))
    centroid = np.array([np.mean(iy), np.mean(ix)])
    xlim = np.array([np.min(ix), np.max(ix)])
    ylim = np.array([np.min(iy), np.max(iy)])
    sx = np.real(sz) * np.diff(xlim) + centroid[1]
    sy = np.imag(sz) * np.diff(ylim) + centroid[0]

    # overall linear distance along the rings, each ring continuing where the previous one ended
    distance = np.zeros_like(level)
    d0 = 0
    for lev in np.unique(level):
        ind = level == lev
        steps = np.abs(np.diff(sx[ind] + 1j * sy[ind]))
        distance[ind] = np.cumsum(np.r_[0, steps]) + d0
        d0 = distance[ind][-1]

    # resample at unit spacing; d stays within [0, distance[-1]] so no extrapolation actually occurs
    d = np.arange(0, np.ceil(distance[-1]))
    path = interp1d(distance, sx + 1j * sy, fill_value='extrapolate')(d)
    px, py = np.real(path), np.imag(path)
    plevel = interp1d(distance, level, kind='nearest')(d)

    nz = atlas.bc.nz
    if display == 'flat':
        # one column per resampled point, one row per DV index
        image_map = _circle_column_inds(atlas, px, py).astype(np.int32)
    else:
        # the outermost ring is processed first and fixes the image width
        for i in range(N - 1, -1, -1):
            ind = plevel == i
            dtot = d[ind] - np.mean(d[ind])
            if i == N - 1:
                ipx = np.arange(np.floor(dtot[0]), np.ceil(dtot[-1]) + 1)
                X0 = int(ipx[-1])
                image_map = np.zeros((nz * N, ipx.size), dtype=np.int32)
            iw = np.arange(np.sum(ind))
            iw = np.int32(iw - np.mean(iw) + X0)
            ih = nz * i + np.arange(nz)
            iw, ih = np.meshgrid(iw, ih)
            image_map[ih, iw] = _circle_column_inds(atlas, px[ind], py[ind])

    # NOTE(review): these are the ring points before resampling, so their length differs from
    # the image width; confirm this is what callers expect
    x, y = atlas.bc.i2x(sx), atlas.bc.i2y(sy)
    return image_map, x, y
# if display == 'flat': # fig, ax = plt.subplots(2, 1, figsize=(16, 5)) # elif display == 'pyramid': # fig, ax = plt.subplots(1, 2, figsize=(14, 12)) # ax[0].imshow(ba._label2rgb(ba.label.flat[image_map]), origin='upper') # ax[1].imshow( # ax[1].plot(centroid[1], centroid[0], '*') # ax[1].plot(s.x, s.y)
def swanson(filename="swanson2allen.npz"):
    """
    A rasterized rendition of the Swanson projection of the mouse brain, which is a 2D
    representation of the mouse brain.

    Each pixel in the image corresponds to a region index in the Allen CCFv2 annotation volume.

    [1] J. D. Hahn et al., "An open access mouse brain flatmap and upgraded rat and human brain
        flatmaps based on current reference atlases," J Comp Neurol, vol. 529, no. 3, pp. 576-594,
        Feb. 2021, doi: 10.1002/cne.24966.

    Parameters
    ----------
    filename : str, default="swanson2allen.npz"
        Name of the ``.npz`` file in the ``flatmaps`` folder of the atlas cache, also used as the
        key under ``atlas/`` on the S3 bucket. It could be "swanson2allen_original.npz", or
        "swanson2allen.npz" for indices remapped to match existing labels in the brain atlas.
        The file is downloaded when missing or when the cached copy matches a known outdated md5.

    Returns
    -------
    numpy.ndarray
        The 2D Swanson image, each value being a region index.
    """
    OLD_MD5 = [
        'bb0554ecc704dd4b540151ab57f73822',  # version 2022-05-02 (remapped)
        '7722c1307cf9a6f291ad7632e5dcc88b',  # version 2022-05-09 (removed wolf pixels and 2 artefact regions)
    ]
    npz_file = AllenAtlas._get_cache_dir().joinpath('flatmaps', filename)
    if not npz_file.exists() or md5(npz_file) in OLD_MD5:
        npz_file.parent.mkdir(exist_ok=True, parents=True)
        _logger.info(f'downloading swanson image from {aws.S3_BUCKET_IBL} s3 bucket...')
        # overwrite=True (as in swanson_json) so an outdated cached copy gets replaced
        aws.s3_download_file(f'atlas/{npz_file.name}', npz_file, overwrite=True)
    # the NpzFile returned by np.load keeps the file open: read the array, then close it
    with np.load(npz_file) as npz:
        s2a = npz['swanson2allen']  # inds contains regions ids
    return s2a
def swanson_json(filename="swansonpaths.json", remap=True):
    """
    Vectorized version of the swanson bitmap file.

    The vectorized version was generated from swanson() using matlab contour to find the paths for
    each region. The paths for each region were then simplified using the Ramer Douglas Peucker
    algorithm.

    Parameters
    ----------
    filename : str, default="swansonpaths.json"
        Name of the JSON file in the ``flatmaps`` folder of the atlas cache, also used as the key
        under ``atlas/`` on the S3 bucket. The file is downloaded when missing or when the cached
        copy matches a known outdated md5.
    remap : bool, default=True
        If True, the 'thisID' of Swanson regions that are children of regions contained within
        the Allen annotation volume is replaced by the id of that parent.

    Returns
    -------
    list of dict
        The decoded JSON content (presumably a list of region paths); each entry carries a
        'thisID' key.
    """
    OLD_MD5 = ['97ccca2b675b28ba9b15ca8af5ba4111',  # errored map with FOTU and CUL4, 5 mixed up
               '56daa7022b5e03080d8623814cda6f38',  # old md5 of swanson json without CENT and PTLp
               # and CUL4 split (on s3 called swansonpaths_56daa.json)
               'f848783954883c606ca390ceda9e37d2']

    json_file = AllenAtlas._get_cache_dir().joinpath('flatmaps', filename)
    if not json_file.exists() or md5(json_file) in OLD_MD5:
        json_file.parent.mkdir(exist_ok=True, parents=True)
        _logger.info(f'downloading swanson paths from {aws.S3_BUCKET_IBL} s3 bucket...')
        aws.s3_download_file(f'atlas/{json_file.name}', json_file, overwrite=True)

    # JSON is UTF-8 by specification: do not depend on the platform's default encoding
    with open(json_file, encoding='utf-8') as f:
        sw_json = json.load(f)

    # The swanson contains regions that are children of regions contained within the Allen
    # annotation volume. Here we remap these regions to the parent that is contained with the
    # annotation volume
    if remap:
        id_map = {391: [392, 393, 394, 395, 396],
                  474: [483, 487],
                  536: [537, 541],
                  601: [602, 603, 604, 608],
                  622: [624, 625, 626, 627, 628, 629, 630, 631, 632, 634, 635, 636, 637, 638],
                  686: [687, 688, 689],
                  708: [709, 710],
                  721: [723, 724, 726, 727, 729, 730, 731],
                  740: [741, 742, 743],
                  758: [759, 760, 761, 762],
                  771: [772, 773],
                  777: [778, 779, 780],
                  788: [789, 790, 791, 792],
                  835: [836, 837, 838],
                  891: [894, 895, 896, 897, 898, 900, 901, 902],
                  926: [927, 928],
                  949: [950, 951, 952, 953, 954],
                  957: [958, 959, 960, 961, 962],
                  999: [1000, 1001],
                  578: [579, 580]}
        # child id -> parent id
        rev_map = {child: parent for parent, children in id_map.items() for child in children}
        for sw in sw_json:
            sw['thisID'] = rev_map.get(sw['thisID'], sw['thisID'])

    return sw_json
@lru_cache(maxsize=None)
def _swanson_labels_positions(thres=20000):
    """
    Computes label positions to overlay on the Swanson flatmap.

    Parameters
    ----------
    thres : int, default=20000
        The number of pixels above which a region is labeled.

    Returns
    -------
    dict of int
        A map of region index (a value of the ``swanson()`` image) to a tuple of x, y pixel
        coordinates where its label should be drawn.
    """
    s2a = swanson()
    h, w = s2a.shape
    iw, ih = np.meshgrid(np.arange(w), np.arange(h))
    flat = s2a.flatten()
    # compute the center of mass of all regions (fast enough to do on the fly);
    # counts are clamped to 1 so that absent indices do not divide by zero
    bc = np.maximum(1, np.bincount(flat))
    cmw = np.bincount(flat, weights=iw.flatten()) / bc
    cmh = np.bincount(flat, weights=ih.flatten()) / bc
    bc[0] = 1  # index 0 never passes the threshold below, so it is never labeled
    NWH, NWW = (200, 600)  # half height / half width of the search window, in pixels
    labels = {}
    for ilabel in np.where(bc > thres)[0]:
        x, y = (cmw[ilabel], cmh[ilabel])
        # the region is not convex and its center of mass falls outside of it
        if s2a[int(y), int(x)] != ilabel:
            # find the region pixel nearest to the center of mass
            rh, rw = np.where(s2a == ilabel)
            iimin = np.argmin(np.abs((x - rw) + 1j * (y - rh)))
            # get the median position of the region within a window around this pixel
            sh = np.arange(np.maximum(0, rh[iimin] - NWH), np.minimum(rh[iimin] + NWH, h))
            sw = np.arange(np.maximum(0, rw[iimin] - NWW), np.minimum(rw[iimin] + NWW, w))
            roi = s2a[np.ix_(sh, sw)] == ilabel
            roi = roi / np.sum(roi)
            # clamp at 0: an index of -1 would wrap around to the opposite edge of the window
            x = sw[max(np.searchsorted(np.cumsum(np.sum(roi, axis=0)), .5) - 1, 0)]
            y = sh[max(np.searchsorted(np.cumsum(np.sum(roi, axis=1)), .5) - 1, 0)]
        labels[ilabel] = (x, y)
    return labels