Skip to content

Commit 835d598

Browse files
committed
Reintroduce regional mask caching for feature queries
1 parent a91a7b4 commit 835d598

File tree

1 file changed

+29
-14
lines changed

1 file changed

+29
-14
lines changed

siibra/core/region.py

+29-14
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ class Region(anytree.NodeMixin, concept.AtlasConcept, structure.BrainStructure):
5353
_regex_re = re.compile(r'^\/(?P<expression>.+)\/(?P<flags>[a-zA-Z]*)$')
5454
_accepted_flags = "aiLmsux"
5555

56-
_GETMAP_CACHE = {}
57-
_GETMAP_CACHE_MAX_ENTRIES = 1
56+
_GETMASK_CACHE = {}
57+
_GETMASK_CACHE_MAX_ENTRIES = 1
5858

5959
def __init__(
6060
self,
@@ -441,12 +441,21 @@ def get_regional_mask(
441441
maptype = MapType[maptype.upper()]
442442

443443
threshold_info = "" if maptype == MapType.LABELLED else f"(threshold: {threshold}) "
444+
# check cache
445+
getmap_hash = hash(f"{self.id} - {space} - {maptype}{threshold_info}")
446+
if getmap_hash in self._GETMASK_CACHE:
447+
return self._GETMASK_CACHE[getmap_hash]
448+
444449
name = f"Mask {threshold_info}of '{self.name} ({self.parcellation})' in "
445450
try:
446451
regional_map = self.get_regional_map(space=space, maptype=maptype)
447452
if maptype == MapType.LABELLED:
448453
assert threshold == 0.0, f"threshold can only be set for {MapType.STATISTICAL} maps."
449-
result = regional_map
454+
result = volume.FilteredVolume(
455+
parent_volume=regional_map,
456+
label=regional_map.label,
457+
fragment=regional_map.fragment
458+
)
450459
result._boundingbox = None
451460
if maptype == MapType.STATISTICAL:
452461
result = volume.FilteredVolume(
@@ -459,18 +468,25 @@ def get_regional_mask(
459468
except NoMapAvailableError:
460469
# This region is not mapped directly in any map in the registry.
461470
# Try building a map from the child regions
462-
if (len(self.children) > 0) and all(c.mapped_in_space(space, recurse=True) for c in self.children):
463-
logger.info(f"{self.name} is not mapped in {space}. Merging the masks of its {len(self.children)} child regions.")
464-
child_volumes = [
465-
child.get_regional_mask(space=space, maptype=maptype, threshold=threshold)
466-
for child in self.children
471+
if (len(self.children) > 0) and self.mapped_in_space(space, recurse=True):
472+
mapped_descendants: List[Region] = [
473+
d for d in self.descendants if d.mapped_in_space(space, recurse=False)
474+
]
475+
logger.info(f"{self.name} is not mapped in {space}. Merging the masks of its {len(mapped_descendants)} map descendants.")
476+
descendant_volumes = [
477+
descendant.get_regional_mask(space=space, maptype=maptype, threshold=threshold)
478+
for descendant in mapped_descendants
467479
]
468480
result = volume.FilteredVolume(
469-
volume.merge(child_volumes),
481+
volume.merge(descendant_volumes),
470482
label=1
471483
)
472484
name += f"'{result.space}' (built by merging the mask {threshold_info} of its decendants)"
473485
result._name = name
486+
487+
while len(self._GETMASK_CACHE) > self._GETMASK_CACHE_MAX_ENTRIES:
488+
self._GETMASK_CACHE.pop(next(iter(self._GETMASK_CACHE)))
489+
self._GETMASK_CACHE[getmap_hash] = result
474490
return result
475491

476492
def get_regional_map(
@@ -515,11 +531,10 @@ def get_regional_map(
515531
and self.name in m.regions
516532
):
517533
return m.get_volume(region=self)
518-
else:
519-
raise NoMapAvailableError(
520-
f"{self.name} is not mapped in {space} as a {str(maptype)} map."
521-
" Please try getting the children or getting the mask."
522-
)
534+
raise NoMapAvailableError(
535+
f"{self.name} is not mapped in {space} as a {str(maptype)} map."
536+
" Please try getting the children or getting the mask."
537+
)
523538

524539
def mapped_in_space(self, space, recurse: bool = False) -> bool:
525540
"""

0 commit comments

Comments
 (0)