Skip to content

Commit

Permalink
Add first batch of roundtrip tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ktsitsi committed Oct 10, 2024
1 parent e84ace6 commit 3745388
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 30 deletions.
Binary file added tests/data/nifti/example4d.nii
Binary file not shown.
Binary file added tests/data/nifti/functional.nii
Binary file not shown.
Binary file added tests/data/nifti/standard.nii
Binary file not shown.
60 changes: 60 additions & 0 deletions tests/integration/converters/test_nifti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import nibabel as nib
import numpy as np
import pytest

import tiledb
from tests import get_path
from tiledb.bioimg.converters.nifti import NiftiConverter


def compare_nifti_images(file1, file2):
    """Report whether two NIfTI files match in header, affine and voxel data.

    Both files are loaded with nibabel. The checks run in the order
    header -> affine -> data and stop at the first mismatch, so the voxel
    data is only read when the header and affine already agree.

    Returns True when all three checks pass, False otherwise.
    """
    first = nib.load(file1)
    second = nib.load(file2)

    # Metadata: bail out as soon as the headers compare as different.
    headers_differ = first.header != second.header
    if headers_differ:
        return False

    # Spatial information: the affine matrices must be element-wise equal.
    if not np.array_equal(first.affine, second.affine):
        return False

    # Voxel data is the last check, so its outcome is the overall result.
    return bool(np.array_equal(first.get_fdata(), second.get_fdata()))


@pytest.mark.parametrize(
    "filename", ["nifti/example4d.nii", "nifti/functional.nii", "nifti/standard.nii"]
)
@pytest.mark.parametrize("preserve_axes", [False, True])
@pytest.mark.parametrize("chunked", [False])
@pytest.mark.parametrize(
    "compressor, lossless",
    [
        (tiledb.ZstdFilter(level=0), True),
        # WEBP is not supported for Grayscale images
    ],
)
def test_nifti_converter_roundtrip(
    tmp_path, preserve_axes, chunked, compressor, lossless, filename
):
    """Convert a NIfTI file to TileDB and back, then check nothing changed.

    The sample file is ingested with ``NiftiConverter.to_tiledb``, exported
    again with ``NiftiConverter.from_tiledb`` and the exported file is
    compared against the input with ``compare_nifti_images``.
    """
    # Real sample files are used as input (not generated data).
    input_path = str(get_path(filename))
    tiledb_path = str(tmp_path / "to_tiledb")
    output_path = str(tmp_path / "from_tiledb.nii")

    NiftiConverter.to_tiledb(
        input_path,
        tiledb_path,
        preserve_axes=preserve_axes,
        chunked=chunked,
        compressor=compressor,
        log=False,
    )
    # Store it back to NIfTI
    NiftiConverter.from_tiledb(tiledb_path, output_path)

    # Every parametrized compressor is lossless, so the round trip must
    # compare equal. The comparison result has to be asserted, otherwise
    # the test passes even when the files differ.
    assert compare_nifti_images(
        input_path, output_path
    ), f"NIfTI round trip changed the image for {filename}"
91 changes: 71 additions & 20 deletions tiledb/bioimg/converters/nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import nibabel as nib
import numpy as np
from nibabel import Nifti1Image
from nibabel.analyze import _dtdefs
from numpy._typing import NDArray

from tiledb import VFS, Config, Ctx
Expand Down Expand Up @@ -56,7 +57,11 @@ def __init__(
self._binary_header = base64.b64encode(
self._nib_image.header.binaryblock
).decode("utf-8")
self._mode = "".join(self._nib_image.dataobj.dtype.names)
self._mode = (
"".join(self._nib_image.dataobj.dtype.names)
if self._nib_image.dataobj.dtype.names is not None
else ""
)

def __enter__(self) -> NiftiReader:
return self
Expand Down Expand Up @@ -100,10 +105,40 @@ def image_metadata(self) -> Dict[str, Any]:

@property
def axes(self) -> Axes:
if self._mode == "L":
axes = Axes(["X", "Y", "Z"])
header_dict = self.nifti1_hdr_2_dict()
# Index 0 of "dim" holds the number of dimensions,
# according to the spec https://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h
dims_number = header_dict["dim"][0]
if dims_number == 4:
# According to the standard the 4th dimension corresponds to 'T' (time),
# but in special cases it can degenerate into channels
if header_dict["dim"][dims_number] == 1:
# The time dimension does not correspond to time
if self._mode == "RGB" or self._mode == "RGBA":
# [..., ..., ..., 1, 3] or [..., ..., ..., 1, 4]
axes = Axes(["X", "Y", "Z", "T", "C"])
else:
# The image is single-channel with 1 value in Temporal dimension
# instead of channel. So we map T to be channel.
# [..., ..., ..., 1]
axes = Axes(["X", "Y", "Z", "C"])
else:
# The time dimension does correspond to time
axes = Axes(["X", "Y", "Z", "T"])
elif dims_number < 4:
# Only spatial dimensions
if self._mode == "RGB" or self._mode == "RGBA":
axes = Axes(["X", "Y", "Z", "C"])
else:
axes = Axes(["X", "Y", "Z"])
else:
axes = Axes(["X", "Y", "Z", "C"])
# Has more dimensions that belong to spatial-temporal unknown attributes
# TODO: investigate sample images of this format.
if self._mode == "RGB" or self._mode == "RGBA":
axes = Axes(["X", "Y", "Z", "C"])
else:
axes = Axes(["X", "Y", "Z"])

self._logger.debug(f"Reader axes: {axes}")
return axes

Expand All @@ -124,7 +159,6 @@ def channels(self) -> Sequence[str]:
"G": "GREEN",
"B": "BLUE",
"A": "ALPHA",
"L": "GRAYSCALE",
}
# Use list comprehension to convert the short form to full form
rgb_full = [color_map[color] for color in self._mode]
Expand All @@ -139,12 +173,11 @@ def level_count(self) -> int:
def level_dtype(self, level: int = 0) -> np.dtype:
header_dict = self.nifti1_hdr_2_dict()

# Check the header first
if (dtype := header_dict["data_type"].dtype) == np.dtype("S10"):
dtype = self.get_dtype_from_code(header_dict["datatype"])
if dtype == np.dtype([("R", "u1"), ("G", "u1"), ("B", "u1")]):
dtype = np.uint8

# TODO: Compare with the dtype of fields
# dict(self._nib_image.dataobj.dtype.fields)

self._logger.debug(f"Level {level} dtype: {dtype}")
return dtype

Expand All @@ -153,15 +186,17 @@ def level_shape(self, level: int = 0) -> Tuple[int, ...]:
return ()

original_shape = self._nib_image.shape
fields = self._nib_image.dataobj.dtype.fields
if len(fields) == 3:
# RGB convert the shape from to stack 3 channels
l_shape = (*original_shape[:-1], 3)
elif len(fields) == 4:
# RGBA
l_shape = (*original_shape[:-1], 4)
if (fields := self._nib_image.dataobj.dtype.fields) is not None:
if len(fields) == 3:
# RGB convert the shape from to stack 3 channels
l_shape = (*original_shape, 3)
elif len(fields) == 4:
# RGBA
l_shape = (*original_shape, 4)
else:
# Grayscale
l_shape = original_shape
else:
# Grayscale
l_shape = original_shape
self._logger.debug(f"Level {level} shape: {l_shape}")
return l_shape
Expand Down Expand Up @@ -221,6 +256,13 @@ def nifti1_hdr_2_dict(self) -> Dict[str, Any]:
for field in structured_header_arr.dtype.names
}

def get_dtype_from_code(self, dtype_code: int) -> Optional[np.dtype]:
    """Look up the dtype registered for a NIfTI datatype code.

    Scans nibabel's ``_dtdefs`` table and, for the first entry whose first
    element equals ``dtype_code``, returns that entry's third element
    (the dtype).

    Returns:
        The matching dtype, or ``None`` when the code is not found.
    """
    return next((entry[2] for entry in _dtdefs if entry[0] == dtype_code), None)

@staticmethod
def _serialize_header(header_dict: Mapping[str, Any]) -> Dict[str, Any]:
serialized_header = {
Expand Down Expand Up @@ -265,9 +307,13 @@ def compute_level_metadata(
def write_group_metadata(self, metadata: Mapping[str, Any]) -> None:
    """Decode the JSON-encoded writer kwargs and keep them on the instance.

    Reads the string stored under ``metadata["json_write_kwargs"]``, parses
    it as JSON and stores the result in ``self._group_metadata``.
    """
    encoded_kwargs = metadata["json_write_kwargs"]
    self._group_metadata = json.loads(encoded_kwargs)

def _structured_dtype(self) -> np.dtype:
def _structured_dtype(self) -> Optional[np.dtype]:
    """Return the structured per-voxel dtype for the original color mode.

    ``"RGB"`` yields three ``u1`` fields named R, G, B; ``"RGBA"`` adds a
    fourth ``u1`` field named A. Any other mode returns ``None``.
    """
    mode = self._original_mode
    if mode == "RGB":
        field_names = ("R", "G", "B")
    elif mode == "RGBA":
        field_names = ("R", "G", "B", "A")
    else:
        return None
    # One unsigned byte per channel, in the order listed above.
    return np.dtype([(name, "u1") for name in field_names])

def write_level_image(
self,
Expand All @@ -278,9 +324,14 @@ def write_level_image(
binaryblock=base64.b64decode(self._group_metadata["binaryblock"])
)
contiguous_image = np.ascontiguousarray(image)
structured_arr = contiguous_image.view(dtype=self._structured_dtype()).reshape(
*image.shape[:-1]
structured_arr = contiguous_image.view(
dtype=self._structured_dtype() if self._structured_dtype() else image.dtype
)
if len(image.shape) > 3:
# If temporal is 1 and extra dim for channels RGB/RGBA
if image.shape[3] == 1 and (image.shape[4] == 3 or 4):
structured_arr = structured_arr.reshape(*image.shape[:4])

nib_image = self._writer(
structured_arr, header=header, affine=header.get_best_affine()
)
Expand Down
13 changes: 3 additions & 10 deletions tiledb/bioimg/openslide.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json

import tiledb
from tiledb import Config, Ctx, TileDBError
from tiledb import Config, Ctx
from tiledb.highlevel import _get_ctx

from . import ATTR_NAME
Expand Down Expand Up @@ -84,7 +84,7 @@ def levels(self) -> Sequence[int]:

@property
def dimensions(self) -> Tuple[int, ...]:
"""A (width, height, depth - (if exists)) tuple for level 0 of the slide."""
"""A (width, height) tuple for level 0 of the slide."""
return self._levels[0].dimensions

@property
Expand Down Expand Up @@ -196,14 +196,7 @@ def dimensions(self) -> Tuple[int, ...]:
dims = list(a.domain)
width = a.shape[dims.index(a.dim("X"))]
height = a.shape[dims.index(a.dim("Y"))]
try:
depth = a.shape[dims.index(a.dim("Z"))]
# The Z dim does not exist
except TileDBError:
depth = None
d1, d2 = width // self._pixel_depth, height
dimensions = (d1, d2) if depth is None else (d1, d2, depth)
return dimensions
return width // self._pixel_depth, height

@property
def properties(self) -> Mapping[str, Any]:
Expand Down

0 comments on commit 3745388

Please sign in to comment.