Skip to content

Commit

Permalink
Adding extra nifti test samples
Browse files Browse the repository at this point in the history
  • Loading branch information
ktsitsi committed Oct 11, 2024
1 parent 3745388 commit b7d54af
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 24 deletions.
Binary file added tests/data/nifti/anatomical.nii
Binary file not shown.
43 changes: 28 additions & 15 deletions tests/integration/converters/test_nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,43 @@
from tiledb.bioimg.converters.nifti import NiftiConverter


def compare_nifti_images(file1, file2):
def compare_nifti_images(file1, file2, scaled_test):
img1 = nib.load(file1)
img2 = nib.load(file2)

# Compare the headers (metadata)
if img1.header != img2.header:
return False

# Compare the affine matrices (spatial information)
if not np.array_equal(img1.affine, img2.affine):
return False
assert np.array_equal(img1.affine, img2.affine)

# Compare the image data (voxel data)
data1 = img1.get_fdata()
data2 = img2.get_fdata()
if not np.array_equal(data1, data2):
return False
return True
data1 = np.array(img1.dataobj, dtype=img1.get_data_dtype())
data2 = np.array(img2.dataobj, dtype=img2.get_data_dtype())

assert np.array_equal(data1, data2)

# Compare the image data scaled (voxel data)
if scaled_test:
data_sc = img1.get_fdata()
data_sc_2 = img2.get_fdata()

assert np.array_equal(data_sc, data_sc_2)


@pytest.mark.parametrize(
"filename", ["nifti/example4d.nii", "nifti/functional.nii", "nifti/standard.nii"]
"filename",
[
"nifti/example4d.nii",
"nifti/functional.nii",
"nifti/standard.nii",
"nifti/visiblehuman.nii",
"nifti/anatomical.nii",
],
)
@pytest.mark.parametrize("preserve_axes", [False, True])
@pytest.mark.parametrize("chunked", [False])
@pytest.mark.parametrize(
"compressor, lossless",
[
(tiledb.ZstdFilter(level=0), True),
(tiledb.ZstdFilter(level=0), False),
# WEBP is not supported for Grayscale images
],
)
Expand All @@ -57,4 +65,9 @@ def test_nifti_converter_roundtrip(
)
# Store it back to PNG
NiftiConverter.from_tiledb(tiledb_path, output_path)
compare_nifti_images(input_path, output_path)
# The dtype of this image is complex and nibabel breaks originally
compare_nifti_images(
input_path,
output_path,
scaled_test=False if filename == "nifti/visiblehuman.nii" else True,
)
34 changes: 25 additions & 9 deletions tiledb/bioimg/converters/nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@
from .base import ImageConverterMixin


# Function to find and return the third value based on the first value
def get_dtype_from_code(dtype_code: int) -> Optional[np.dtype]:
for item in _dtdefs:
if item[0] == dtype_code: # Check if the first value matches the input code
return item[2] # Return the third value (dtype)
return None # Return None if the code is not found


class NiftiReader:
_logger: logging.Logger

Expand Down Expand Up @@ -83,7 +91,12 @@ def logger(self) -> Optional[logging.Logger]:

@property
def group_metadata(self) -> Dict[str, Any]:
writer_kwargs = dict(metadata=self._metadata, binaryblock=self._binary_header)
writer_kwargs = dict(
metadata=self._metadata,
binaryblock=self._binary_header,
slope=self._nib_image.dataobj.slope,
inter=self._nib_image.dataobj.inter,
)
self._logger.debug(f"Group metadata: {writer_kwargs}")
return {"json_write_kwargs": json.dumps(writer_kwargs)}

Expand Down Expand Up @@ -173,7 +186,7 @@ def level_count(self) -> int:
def level_dtype(self, level: int = 0) -> np.dtype:
header_dict = self.nifti1_hdr_2_dict()

dtype = self.get_dtype_from_code(header_dict["datatype"])
dtype = get_dtype_from_code(header_dict["datatype"])
if dtype == np.dtype([("R", "u1"), ("G", "u1"), ("B", "u1")]):
dtype = np.uint8
# TODO: Compare with the dtype of fields
Expand Down Expand Up @@ -218,8 +231,14 @@ def level_image(
self._metadata["original_mode"] = self._mode
raw_data_contiguous = np.ascontiguousarray(unscaled_img)
numerical_data = np.frombuffer(raw_data_contiguous, dtype=self.level_dtype())
# Account endianness
numerical_data = numerical_data.view(
numerical_data.dtype.newbyteorder(self._nib_image.header.endianness)
)
numerical_data = numerical_data.reshape(self.level_shape())

# Bug! data might have slope and inter and header not contain them.

if tile is None:
return numerical_data
else:
Expand Down Expand Up @@ -256,13 +275,6 @@ def nifti1_hdr_2_dict(self) -> Dict[str, Any]:
for field in structured_header_arr.dtype.names
}

# Function to find and return the third value based on the first value
def get_dtype_from_code(self, dtype_code: int) -> np.dtype:
for item in _dtdefs:
if item[0] == dtype_code: # Check if the first value matches the input code
return item[2] # Return the third value (dtype)
return None # Return None if the code is not foun

@staticmethod
def _serialize_header(header_dict: Mapping[str, Any]) -> Dict[str, Any]:
serialized_header = {
Expand Down Expand Up @@ -335,6 +347,10 @@ def write_level_image(
nib_image = self._writer(
structured_arr, header=header, affine=header.get_best_affine()
)

nib_image.header.set_slope_inter(
self._group_metadata["slope"], self._group_metadata["inter"]
)
nib.save(nib_image, self._output_path)

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
Expand Down

0 comments on commit b7d54af

Please sign in to comment.