diff --git a/tests/data/nifti/example4d.nii b/tests/data/nifti/example4d.nii new file mode 100644 index 0000000..1401e91 Binary files /dev/null and b/tests/data/nifti/example4d.nii differ diff --git a/tests/data/nifti/functional.nii b/tests/data/nifti/functional.nii new file mode 100644 index 0000000..2768d4d Binary files /dev/null and b/tests/data/nifti/functional.nii differ diff --git a/tests/data/nifti/standard.nii b/tests/data/nifti/standard.nii new file mode 100644 index 0000000..d685a25 Binary files /dev/null and b/tests/data/nifti/standard.nii differ diff --git a/tests/integration/converters/test_nifti.py b/tests/integration/converters/test_nifti.py new file mode 100644 index 0000000..cbaaabf --- /dev/null +++ b/tests/integration/converters/test_nifti.py @@ -0,0 +1,60 @@ +import nibabel as nib +import numpy as np +import pytest + +import tiledb +from tests import get_path +from tiledb.bioimg.converters.nifti import NiftiConverter + + +def compare_nifti_images(file1, file2): + img1 = nib.load(file1) + img2 = nib.load(file2) + + # Compare the headers (metadata) + if img1.header != img2.header: + return False + + # Compare the affine matrices (spatial information) + if not np.array_equal(img1.affine, img2.affine): + return False + + # Compare the image data (voxel data) + data1 = img1.get_fdata() + data2 = img2.get_fdata() + if not np.array_equal(data1, data2): + return False + return True + + +@pytest.mark.parametrize( + "filename", ["nifti/example4d.nii", "nifti/functional.nii", "nifti/standard.nii"] +) +@pytest.mark.parametrize("preserve_axes", [False, True]) +@pytest.mark.parametrize("chunked", [False]) +@pytest.mark.parametrize( + "compressor, lossless", + [ + (tiledb.ZstdFilter(level=0), True), + # WEBP is not supported for Grayscale images + ], +) +def test_nifti_converter_roundtrip( + tmp_path, preserve_axes, chunked, compressor, lossless, filename +): + # For lossy WEBP we cannot use random generated images as they have so much noise + input_path = str(get_path(filename)) + tiledb_path = str(tmp_path / "to_tiledb") + output_path = str(tmp_path / "from_tiledb.nii") + + NiftiConverter.to_tiledb( + input_path, + tiledb_path, + preserve_axes=preserve_axes, + chunked=chunked, + compressor=compressor, + log=False, + ) + # Store it back to PNG + NiftiConverter.from_tiledb(tiledb_path, output_path) + compare_nifti_images(input_path, output_path) diff --git a/tiledb/bioimg/converters/nifti.py b/tiledb/bioimg/converters/nifti.py index 4212e23..5b416a5 100644 --- a/tiledb/bioimg/converters/nifti.py +++ b/tiledb/bioimg/converters/nifti.py @@ -17,6 +17,7 @@ import nibabel as nib import numpy as np from nibabel import Nifti1Image +from nibabel.analyze import _dtdefs from numpy._typing import NDArray from tiledb import VFS, Config, Ctx @@ -56,7 +57,11 @@ def __init__( self._binary_header = base64.b64encode( self._nib_image.header.binaryblock ).decode("utf-8") - self._mode = "".join(self._nib_image.dataobj.dtype.names) + self._mode = ( + "".join(self._nib_image.dataobj.dtype.names) + if self._nib_image.dataobj.dtype.names is not None + else "" + ) def __enter__(self) -> NiftiReader: return self @@ -100,10 +105,40 @@ def image_metadata(self) -> Dict[str, Any]: @property def axes(self) -> Axes: - if self._mode == "L": - axes = Axes(["X", "Y", "Z"]) + header_dict = self.nifti1_hdr_2_dict() + # The 0-index holds information about the number of dimensions + # according the spec https://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h + dims_number = header_dict["dim"][0] + if dims_number == 4: + # According to standard the 4th dimension corresponds to 'T' time + # but in special cases can be degnerate into channels + if header_dict["dim"][dims_number] == 1: + # The time dimension does not correspond to time + if self._mode == "RGB" or self._mode == "RGBA": + # [..., ..., ..., 1, 3] or [..., ..., ..., 1, 4] + axes = Axes(["X", "Y", "Z", "T", "C"]) + else: + # The image is single-channel with 1 value in Temporal dimension + # instead of channel. So we map T to be channel. + # [..., ..., ..., 1] + axes = Axes(["X", "Y", "Z", "C"]) + else: + # The time dimension does correspond to time + axes = Axes(["X", "Y", "Z", "T"]) + elif dims_number < 4: + # Only spatial dimensions + if self._mode == "RGB" or self._mode == "RGBA": + axes = Axes(["X", "Y", "Z", "C"]) + else: + axes = Axes(["X", "Y", "Z"]) else: - axes = Axes(["X", "Y", "Z", "C"]) + # Has more dimensions that belong to spatial-temporal unknown attributes + # TODO: investigate sample images of this format. + if self._mode == "RGB" or self._mode == "RGBA": + axes = Axes(["X", "Y", "Z", "C"]) + else: + axes = Axes(["X", "Y", "Z"]) + self._logger.debug(f"Reader axes: {axes}") return axes @@ -124,7 +159,6 @@ def channels(self) -> Sequence[str]: "G": "GREEN", "B": "BLUE", "A": "ALPHA", - "L": "GRAYSCALE", } # Use list comprehension to convert the short form to full form rgb_full = [color_map[color] for color in self._mode] @@ -139,12 +173,11 @@ def level_count(self) -> int: def level_dtype(self, level: int = 0) -> np.dtype: header_dict = self.nifti1_hdr_2_dict() - # Check the header first - if (dtype := header_dict["data_type"].dtype) == np.dtype("S10"): + dtype = self.get_dtype_from_code(header_dict["datatype"]) + if dtype == np.dtype([("R", "u1"), ("G", "u1"), ("B", "u1")]): dtype = np.uint8 - # TODO: Compare with the dtype of fields - # dict(self._nib_image.dataobj.dtype.fields) + self._logger.debug(f"Level {level} dtype: {dtype}") return dtype @@ -153,15 +186,17 @@ def level_shape(self, level: int = 0) -> Tuple[int, ...]: return () original_shape = self._nib_image.shape - fields = self._nib_image.dataobj.dtype.fields - if len(fields) == 3: - # RGB convert the shape from to stack 3 channels - l_shape = (*original_shape[:-1], 3) - elif len(fields) == 4: - # RGBA - l_shape = (*original_shape[:-1], 4) + if (fields := self._nib_image.dataobj.dtype.fields) is not None: + if len(fields) == 3: + # RGB convert the shape from to stack 3 channels + l_shape = (*original_shape, 3) + elif len(fields) == 4: + # RGBA + l_shape = (*original_shape, 4) + else: + # Grayscale + l_shape = original_shape else: - # Grayscale l_shape = original_shape self._logger.debug(f"Level {level} shape: {l_shape}") return l_shape @@ -221,6 +256,13 @@ def nifti1_hdr_2_dict(self) -> Dict[str, Any]: for field in structured_header_arr.dtype.names } + # Function to find and return the third value based on the first value + def get_dtype_from_code(self, dtype_code: int) -> np.dtype: + for item in _dtdefs: + if item[0] == dtype_code: # Check if the first value matches the input code + return item[2] # Return the third value (dtype) + return None # Return None if the code is not foun + @staticmethod def _serialize_header(header_dict: Mapping[str, Any]) -> Dict[str, Any]: serialized_header = { @@ -265,9 +307,13 @@ def compute_level_metadata( def write_group_metadata(self, metadata: Mapping[str, Any]) -> None: self._group_metadata = json.loads(metadata["json_write_kwargs"]) - def _structured_dtype(self) -> np.dtype: + def _structured_dtype(self) -> Optional[np.dtype]: if self._original_mode == "RGB": return np.dtype([("R", "u1"), ("G", "u1"), ("B", "u1")]) + elif self._original_mode == "RGBA": + return np.dtype([("R", "u1"), ("G", "u1"), ("B", "u1"), ("A", "u1")]) + else: + return None def write_level_image( self, @@ -278,9 +324,14 @@ def write_level_image( binaryblock=base64.b64decode(self._group_metadata["binaryblock"]) ) contiguous_image = np.ascontiguousarray(image) - structured_arr = contiguous_image.view(dtype=self._structured_dtype()).reshape( - *image.shape[:-1] + structured_arr = contiguous_image.view( + dtype=self._structured_dtype() if self._structured_dtype() else image.dtype ) + if len(image.shape) > 3: + # If temporal is 1 and extra dim for channels RGB/RGBA + if image.shape[3] == 1 and (image.shape[4] == 3 or 4): + structured_arr = structured_arr.reshape(*image.shape[:4]) + nib_image = self._writer( structured_arr, header=header, affine=header.get_best_affine() ) diff --git a/tiledb/bioimg/openslide.py b/tiledb/bioimg/openslide.py index b96d45b..c4edeba 100644 --- a/tiledb/bioimg/openslide.py +++ b/tiledb/bioimg/openslide.py @@ -14,7 +14,7 @@ import json import tiledb -from tiledb import Config, Ctx, TileDBError +from tiledb import Config, Ctx from tiledb.highlevel import _get_ctx from . import ATTR_NAME @@ -84,7 +84,7 @@ def levels(self) -> Sequence[int]: @property def dimensions(self) -> Tuple[int, ...]: - """A (width, height, depth - (if exists)) tuple for level 0 of the slide.""" + """A (width, height) tuple for level 0 of the slide.""" return self._levels[0].dimensions @property @@ -196,14 +196,7 @@ def dimensions(self) -> Tuple[int, ...]: dims = list(a.domain) width = a.shape[dims.index(a.dim("X"))] height = a.shape[dims.index(a.dim("Y"))] - try: - depth = a.shape[dims.index(a.dim("Z"))] - # The Z dim does not exist - except TileDBError: - depth = None - d1, d2 = width // self._pixel_depth, height - dimensions = (d1, d2) if depth is None else (d1, d2, depth) - return dimensions + return width // self._pixel_depth, height @property def properties(self) -> Mapping[str, Any]: