"""Techniques to project the brain volume onto 2D images for visualisation purposes."""
from functools import lru_cache
import logging
import json
import nrrd
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
from iblutil.util import Bunch
from iblutil.io.hashfile import md5
import one.remote.aws as aws
from iblatlas.atlas import AllenAtlas
_logger = logging.getLogger(__name__)
[docs]
class FlatMap(AllenAtlas):
    """
    Allen atlas augmented with a 2D (or 2D x depth) flat map of the brain volume.

    The flat map is stored in `self.flatmap` as an array of flat indices into the atlas
    volumes: each pixel of the map points to one voxel of `self.label` / `self.image`.
    """

    def __init__(self, flatmap='dorsal_cortex', res_um=25):
        """
        Available flatmaps are currently 'dorsal_cortex', 'circles' and 'pyramid'.

        Parameters
        ----------
        flatmap : {'dorsal_cortex', 'circles', 'pyramid'}
            Name of the flat map. 'dorsal_cortex' is read from a cached file (downloaded if
            missing); 'circles' and 'pyramid' are computed with `circles`.
        res_um : float, default=25
            Resolution passed to the parent atlas constructor. The 'circles' and 'pyramid'
            maps are only available at 25.

        Raises
        ------
        ValueError
            If `flatmap` is not one of the supported names.
        NotImplementedError
            If 'circles' or 'pyramid' is requested with `res_um` other than 25.
        """
        # fail early: otherwise an unknown name would leave `self.flatmap` unset and only
        # surface later as an AttributeError when plotting
        if flatmap not in ('dorsal_cortex', 'circles', 'pyramid'):
            raise ValueError(
                f'Flatmap "{flatmap}" not supported, expected "dorsal_cortex", "circles" or "pyramid"')
        super().__init__(res_um=res_um)
        self.name = flatmap
        if flatmap == 'dorsal_cortex':
            self._get_flatmap_from_file()
        else:
            if res_um != 25:
                raise NotImplementedError('Pyramid circles not implemented for resolution other than 25um')
            display = 'flat' if flatmap == 'circles' else 'pyramid'
            # note: ml_scale and ap_scale are only set for the computed maps
            self.flatmap, self.ml_scale, self.ap_scale = circles(N=5, atlas=self, display=display)

    def _get_flatmap_from_file(self):
        """
        Loads the flat map named `self.name` into `self.flatmap`.

        The nrrd file is looked up in the 'flatmaps' folder of the atlas cache directory and
        downloaded from s3 if it does not exist yet.
        """
        file_flatmap = self._get_cache_dir().joinpath('flatmaps', f'{self.name}_{self.res_um}.nrrd')
        if not file_flatmap.exists():
            file_flatmap.parent.mkdir(exist_ok=True, parents=True)
            aws.s3_download_file(f'atlas/{file_flatmap.name}', file_flatmap)
        self.flatmap, _ = nrrd.read(file_flatmap)

    def plot_flatmap(self, depth=0, volume='annotation', mapping='Allen', region_values=None, ax=None, **kwargs):
        """
        Displays the 2D image corresponding to the flatmap.

        If there are several depths, by default it will display the first one.

        Parameters
        ----------
        depth : int
            Index of the depth to display in the flatmap volume (the last dimension).
            Ignored when the flatmap is 2D.
        volume : {'image', 'annotation', 'boundary', 'value'}
            - 'image' - Allen image volume.
            - 'annotation' - Allen annotation volume.
            - 'boundary' - outline of boundaries between all regions.
            - 'value' - custom values per region, must pass in the `region_values` argument.
        mapping : str, default='Allen'
            The brain region mapping to use.
        region_values : numpy.array
            An array of custom values indexed by region index (each pixel displays
            `region_values[region_index]`). Required when `volume` is 'value'.
        ax : matplotlib.pyplot.Axes, optional
            A set of axes to plot to. Defaults to the current axes.
        **kwargs
            See matplotlib.pyplot.imshow.

        Returns
        -------
        matplotlib.pyplot.Axes
            The plotted image axes.

        Raises
        ------
        ValueError
            If `volume` is not one of the supported types.
        """
        if self.flatmap.ndim == 3:
            inds = np.int32(self.flatmap[:, :, depth])
        else:
            inds = np.int32(self.flatmap[:, :])
        # the flat map holds flat voxel indices: look up the label of each pixel and remap it
        regions = self._get_mapping(mapping=mapping)[self.label.flat[inds]]
        if volume == 'annotation':
            im = self._label2rgb(regions)
        elif volume == 'value':
            im = region_values[regions]
        elif volume == 'boundary':
            im = self.compute_boundaries(regions)
        elif volume == 'image':
            im = self.image.flat[inds]
        else:
            raise ValueError(f'Volume type "{volume}" not supported')
        if ax is None:
            ax = plt.gca()
        return self._plot_slice(im, self.extent_flmap(), ax=ax, volume=volume, **kwargs)

    def extent_flmap(self):
        """
        Returns the boundary coordinates of the flat map.

        Returns
        -------
        numpy.array
            The bounding coordinates of the flat map image in pixels, specified as
            (left, right, bottom, top), i.e. (0, n_columns, 0, n_rows).
        """
        extent = np.r_[0, self.flatmap.shape[1], 0, self.flatmap.shape[0]]
        return extent
[docs]
@lru_cache(maxsize=1, typed=False)
def circles(N=5, atlas=None, display='flat'):
    """
    Computes a flat map by sampling the atlas along N concentric circles.

    The circles are centred on the centroid of the non-NaN pixels of `atlas.top` and scaled
    by the extent of those pixels. Each sample along a circle contributes one column of the
    map, made of the `atlas.bc.nz` voxels found along the third index at that location.

    The result is cached (last call only): repeated calls with the same arguments return the
    very same arrays, so callers should not modify them in place.

    :param N: number of circles
    :param atlas: brain atlas at 25 um, a default AllenAtlas is instantiated if not provided
    :param display: "flat" (all circles side by side in a single band of nz rows) or
     "pyramid" (one band of nz rows per circle, stacked and horizontally centred)
    :return: 2D map of flat voxel indices, ml coordinates, ap coordinates of the circles
     (x then y, as converted by atlas.bc.i2x and atlas.bc.i2y)
    :raises ValueError: if display is neither "flat" nor "pyramid"
    """
    # validate before any costly computation; an unknown display would otherwise only fail
    # at the very end with an unbound `image_map`
    if display not in ('flat', 'pyramid'):
        raise ValueError(f'Display type "{display}" not supported, expected "flat" or "pyramid"')
    atlas = atlas if atlas else AllenAtlas()
    nlast = 2000  # 25 um for 5mm diameter
    sz = np.array([])
    level = np.array([])
    for k in np.arange(N):
        # the number of samples and the radius both grow linearly with the circle index
        n = int((k + 1) * nlast / N)
        r = .4 * (k + 1) / N
        theta = (np.linspace(0, 2 * np.pi, n) + np.pi / 2)
        sz = np.r_[sz, r * np.exp(1j * theta)]
        level = np.r_[level, theta * 0 + k]
    atlas.compute_surface()
    iy, ix = np.where(~np.isnan(atlas.top))
    centroid = np.array([np.mean(iy), np.mean(ix)])
    xlim = np.array([np.min(ix), np.max(ix)])
    ylim = np.array([np.min(iy), np.max(iy)])
    # circle samples expressed in (fractional) pixel indices of atlas.top
    s = Bunch(
        x=np.real(sz) * np.diff(xlim) + centroid[1],
        y=np.imag(sz) * np.diff(ylim) + centroid[0],
        level=level,
        distance=level * 0,
    )
    # compute the overall linear distance for each circle
    d0 = 0
    for lev in np.unique(s['level']):
        ind = s['level'] == lev
        diff = np.abs(np.diff(s['x'][ind] + 1j * s['y'][ind]))
        s['distance'][ind] = np.cumsum(np.r_[0, diff]) + d0
        d0 = s['distance'][ind][-1]
    # resample the circles so that consecutive samples are one pixel apart
    # NOTE(review): scipy documents the extrapolation keyword as 'extrapolate'; 'extrap' is
    # kept unchanged here, all query points below lie within the data range - confirm intent
    fcn = interp1d(s['distance'], s['x'] + 1j * s['y'], fill_value='extrap')
    d = np.arange(0, np.ceil(s['distance'][-1]))
    s_ = Bunch({
        'x': np.real(fcn(d)),
        'y': np.imag(fcn(d)),
        'level': interp1d(s['distance'], level, kind='nearest')(d),
        'distance': d
    })
    if display == 'flat':
        nz = atlas.bc.nz
        image_map = np.zeros((nz, s_['distance'].size), dtype=np.int32)
        iml, _ = np.meshgrid(np.round(s_.x).astype(np.int32), np.arange(nz))
        iap, idv = np.meshgrid(np.round(s_.y).astype(np.int32), np.arange(nz))
        i3d = atlas._lookup_inds(np.c_[iml.flat, iap.flat, idv.flat])
        i3d = np.reshape(i3d, [nz, s_['x'].size])
        # the lookup covers the full (nz, n_samples) grid: fill the whole map at once
        image_map[:] = i3d
    else:  # display == 'pyramid'
        # start from the outermost (longest) circle as it defines the width of the image
        for i in np.flipud(np.arange(N)):
            ind = s_['level'] == i
            dtot = s_['distance'][ind]
            dtot = dtot - np.mean(dtot)
            if i == N - 1:
                ipx = np.arange(np.floor(dtot[0]), np.ceil(dtot[-1]) + 1)
                nh = atlas.bc.nz * N
                X0 = int(ipx[-1])
                image_map = np.zeros((nh, ipx.size), dtype=np.int32)
            # columns of this circle, centred around X0
            iw = np.arange(np.sum(ind))
            iw = np.int32(iw - np.mean(iw) + X0)
            # rows of this circle: one band of nz rows per level
            ih = atlas.bc.nz * i + np.arange(atlas.bc.nz)
            iw, ih = np.meshgrid(iw, ih)
            iml, _ = np.meshgrid(np.round(s_.x[ind]).astype(np.int32), np.arange(atlas.bc.nz))
            iap, idv = np.meshgrid(np.round(s_.y[ind]).astype(np.int32), np.arange(atlas.bc.nz))
            i3d = atlas._lookup_inds(np.c_[iml.flat, iap.flat, idv.flat])
            i3d = np.reshape(i3d, [atlas.bc.nz, s_['x'][ind].size])
            image_map[ih, iw] = i3d
    x, y = (atlas.bc.i2x(s.x), atlas.bc.i2y(s.y))
    return image_map, x, y
# if display == 'flat':
# fig, ax = plt.subplots(2, 1, figsize=(16, 5))
# elif display == 'pyramid':
# fig, ax = plt.subplots(1, 2, figsize=(14, 12))
# ax[0].imshow(ba._label2rgb(ba.label.flat[image_map]), origin='upper')
# ax[1].imshow(ba.top)
# ax[1].plot(centroid[1], centroid[0], '*')
# ax[1].plot(s.x, s.y)
[docs]
def swanson(filename="swanson2allen.npz"):
    """
    A rasterized rendition of the Swanson projection of the mouse brain, which is a 2D
    representation of the mouse brain.

    Each pixel in the image corresponds to a region index in the Allen CCFv2 annotation volume.

    [1] J. D. Hahn et al., "An open access mouse brain flatmap and upgraded rat and human
    brain flatmaps based on current reference atlases,"
    J Comp Neurol, vol. 529, no. 3, pp. 576-594, Feb. 2021, doi: 10.1002/cne.24966.

    Parameters
    ----------
    filename : str, default='swanson2allen.npz'
        Name of the npz file, looked up in the 'flatmaps' folder of the atlas cache directory.
        Could be "swanson2allen_original.npz", or "swanson2allen.npz" for remapped indices to
        match existing labels in the brain atlas.

    Returns
    -------
    numpy.array
        The 'swanson2allen' array stored in the file: a 2D image of region indices.
    """
    # md5 of superseded versions of the file: a cached copy matching one of those is replaced
    OLD_MD5 = [
        'bb0554ecc704dd4b540151ab57f73822',  # version 2022-05-02 (remapped)
        '7722c1307cf9a6f291ad7632e5dcc88b',  # version 2022-05-09 (removed wolf pixels and 2 artefact regions)
    ]
    npz_file = AllenAtlas._get_cache_dir().joinpath('flatmaps', filename)
    if not npz_file.exists() or md5(npz_file) in OLD_MD5:
        npz_file.parent.mkdir(exist_ok=True, parents=True)
        _logger.info(f'downloading swanson image from {aws.S3_BUCKET_IBL} s3 bucket...')
        # force the download as done in swanson_json, so that an outdated cached file is
        # replaced rather than kept
        aws.s3_download_file(f'atlas/{npz_file.name}', npz_file, overwrite=True)
    # the context manager closes the underlying file handle once the array is loaded
    with np.load(npz_file) as npz:
        s2a = npz['swanson2allen']  # inds contains regions ids
    return s2a
[docs]
def swanson_json(filename="swansonpaths.json", remap=True):
    """
    Vectorized version of the swanson bitmap file.

    The vectorized version was generated from swanson() using matlab contour to find the paths
    for each region. The paths for each region were then simplified using the
    Ramer Douglas Peucker algorithm https://rdp.readthedocs.io/en/latest/

    Parameters
    ----------
    filename : str, default='swansonpaths.json'
        Name of the json file, looked up in the 'flatmaps' folder of the atlas cache directory
        and downloaded from s3 when missing or outdated.
    remap : bool, default=True
        If True, the 'thisID' of regions that are children of a region of the Allen annotation
        volume is replaced by the id of that parent.

    Returns
    -------
    The decoded json content. When `remap` is True it is iterated as a sequence of mappings
    that each have a 'thisID' key.
    """
    outdated_md5 = (
        '97ccca2b675b28ba9b15ca8af5ba4111',  # errored map with FOTU and CUL4, 5 mixed up
        # old md5 of swanson json without CENT and PTLp and CUL4 split
        # (on s3 called swansonpaths_56daa.json)
        '56daa7022b5e03080d8623814cda6f38',
        'f848783954883c606ca390ceda9e37d2',
    )
    path = AllenAtlas._get_cache_dir().joinpath('flatmaps', filename)
    # the cached file is usable only if it exists and is not one of the superseded versions
    if not (path.exists() and md5(path) not in outdated_md5):
        path.parent.mkdir(exist_ok=True, parents=True)
        _logger.info(f'downloading swanson paths from {aws.S3_BUCKET_IBL} s3 bucket...')
        aws.s3_download_file(f'atlas/{path.name}', path, overwrite=True)

    with open(path) as fid:
        regions = json.load(fid)

    if not remap:
        return regions

    # The swanson contains regions that are children of regions contained within the Allen
    # annotation volume. Here we remap these regions to the parent that is contained with the
    # annotation volume
    parent_to_children = {
        391: [392, 393, 394, 395, 396],
        474: [483, 487],
        536: [537, 541],
        601: [602, 603, 604, 608],
        622: [624, 625, 626, 627, 628, 629, 630, 631, 632, 634, 635, 636, 637, 638],
        686: [687, 688, 689],
        708: [709, 710],
        721: [723, 724, 726, 727, 729, 730, 731],
        740: [741, 742, 743],
        758: [759, 760, 761, 762],
        771: [772, 773],
        777: [778, 779, 780],
        788: [789, 790, 791, 792],
        835: [836, 837, 838],
        891: [894, 895, 896, 897, 898, 900, 901, 902],
        926: [927, 928],
        949: [950, 951, 952, 953, 954],
        957: [958, 959, 960, 961, 962],
        999: [1000, 1001],
        578: [579, 580],
    }
    child_to_parent = {
        child: parent
        for parent, children in parent_to_children.items()
        for child in children
    }
    for region in regions:
        current = region['thisID']
        # ids that are not listed as children are left untouched
        region['thisID'] = child_to_parent.get(current, current)
    return regions
@lru_cache(maxsize=None)
def _swanson_labels_positions(thres=20000):
    """
    Computes label positions to overlay on the Swanson flatmap.

    The position of a region is its centre of mass. When the centre of mass falls outside of
    the region, the position is instead the median row and column of the region pixels found
    in a window around the region pixel nearest to the centre of mass.

    The result is cached per value of `thres`: the returned dictionary is shared between
    calls and should not be modified in place.

    Parameters
    ----------
    thres : int, default=20000
        The number of pixels above which a region is labeled.

    Returns
    -------
    dict
        A map of region index (the pixel values of the Swanson image) to a tuple of x y
        coordinates, in pixels.
    """
    s2a = swanson()
    h, w = s2a.shape
    iw, ih = np.meshgrid(np.arange(w), np.arange(h))
    s2a_flat = s2a.flatten()
    # compute the center of mass of all regions (fast enough to do on the fly)
    bc = np.maximum(1, np.bincount(s2a_flat))  # at least 1 to avoid dividing by zero
    cmw = np.bincount(s2a_flat, weights=iw.flatten()) / bc
    cmh = np.bincount(s2a_flat, weights=ih.flatten()) / bc
    # index 0 is never labeled (presumably the background - confirm)
    bc[0] = 1
    NWH, NWW = (200, 600)  # half height and half width of the search window, in pixels
    labels = {}
    for ilabel in np.where(bc > thres)[0]:
        x, y = (cmw[ilabel], cmh[ilabel])
        # the polygon is not convex and the center of mass lies outside of the region
        if s2a[int(y), int(x)] != ilabel:
            # find the nearest point to the center of mass
            rows, cols = np.where(s2a == ilabel)
            iimin = np.argmin(np.abs((x - cols) + 1j * (y - rows)))
            # get the center of mass of a window around this point
            sh = np.arange(np.maximum(0, rows[iimin] - NWH), np.minimum(rows[iimin] + NWH, h))
            sw = np.arange(np.maximum(0, cols[iimin] - NWW), np.minimum(cols[iimin] + NWW, w))
            roi = s2a[sh][:, sw] == ilabel
            roi = roi / np.sum(roi)
            # median column and row of the region within the window; the index is clamped at
            # 0 as searchsorted may return 0, in which case -1 would wrap to the opposite
            # edge of the window
            ix = np.searchsorted(np.cumsum(np.sum(roi, axis=0)), .5) - 1
            iy = np.searchsorted(np.cumsum(np.sum(roi, axis=1)), .5) - 1
            x = sw[max(ix, 0)]
            y = sh[max(iy, 0)]
        labels[ilabel] = (x, y)
    return labels