Skip to content

ENH: Add option to infer CIFTI-2 intent codes #932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
101 changes: 95 additions & 6 deletions nibabel/cifti2/cifti2.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@
from ..nifti1 import Nifti1Extensions
from ..nifti2 import Nifti2Image, Nifti2Header
from ..arrayproxy import reshape_dataobj
from ..volumeutils import Recoder
from warnings import warn


@@ -89,6 +90,53 @@ class Cifti2HeaderError(Exception):
'CIFTI_STRUCTURE_THALAMUS_LEFT',
'CIFTI_STRUCTURE_THALAMUS_RIGHT')

# "Standard CIFTI Mapping Combinations" within CIFTI-2 spec
# https://www.nitrc.org/forum/attachment.php?attachid=341&group_id=454&forum_id=1955
CIFTI_CODES = Recoder((
('.dconn.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE', (
'CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dtseries.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES', (
'CIFTI_INDEX_TYPE_SERIES', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.pconn.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.ptseries.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES', (
'CIFTI_INDEX_TYPE_SERIES', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.dscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SCALARS', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dlabel.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS', (
'CIFTI_INDEX_TYPE_LABELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.pscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.pdconn.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE', (
'CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.dpconn.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.pconnseries.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SERIES', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SERIES',
)),
('.pconnscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SCALAR', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SCALARS',
)),
('.dfan.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dfibersamp.nii', 'NIFTI_INTENT_CONNECTIVITY_UNKNOWN', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dfansamp.nii', 'NIFTI_INTENT_CONNECTIVITY_UNKNOWN', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
), fields=('extension', 'niistring', 'map_types'))


def _value_if_klass(val, klass):
if val is None or isinstance(val, klass):
@@ -1466,11 +1514,7 @@ def to_file_map(self, file_map=None):
raise ValueError(
f"Dataobj shape {self._dataobj.shape} does not match shape "
f"expected from CIFTI-2 header {self.header.matrix.get_data_shape()}")
# if intent code is not set, default to unknown CIFTI
if header.get_intent()[0] == 'none':
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')
data = reshape_dataobj(self.dataobj,
(1, 1, 1, 1) + self.dataobj.shape)
data = reshape_dataobj(self.dataobj, (1, 1, 1, 1) + self.dataobj.shape)
# If qform not set, reset pixdim values so Nifti2 does not complain
if header['qform_code'] == 0:
header['pixdim'][:4] = 1
@@ -1501,14 +1545,59 @@ def update_headers(self):
>>> img.shape == (2, 3, 4)
True
"""
self._nifti_header.set_data_shape((1, 1, 1, 1) + self._dataobj.shape)
header = self._nifti_header
header.set_data_shape((1, 1, 1, 1) + self._dataobj.shape)
# if intent code is not set, default to unknown
if header.get_intent()[0] == 'none':
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')

def get_data_dtype(self):
return self._nifti_header.get_data_dtype()

def set_data_dtype(self, dtype):
self._nifti_header.set_data_dtype(dtype)

def to_filename(self, filename, validate=True):
"""
Ensures NIfTI header intent code is set prior to saving.

Parameters
----------
validate : boolean, optional
If ``True``, infer and validate CIFTI type based on MatrixIndicesMap values.
This includes the setting of the relevant intent code within the NIfTI header.
If validation fails, a UserWarning is issued and saving continues.
"""
if validate:
# Determine CIFTI type via index maps
from .parse_cifti2 import intent_codes

matrix = self.header.matrix
map_types = tuple(
matrix.get_index_map(idx).indices_map_to_data_type for idx
in sorted(matrix.mapped_indices)
)
try:
expected_intent = CIFTI_CODES.niistring[map_types]
expected_ext = CIFTI_CODES.extension[map_types]
except KeyError: # unknown
expected_intent = "NIFTI_INTENT_CONNECTIVITY_UNKNOWN"
expected_ext = None
warn(
"No information found for matrix containing the following index maps:"
f"{map_types}, defaulting to unknown."
)

orig_intent = self._nifti_header.get_intent()[0]
if expected_intent != intent_codes.niistring[orig_intent]:
warn(
f"Expected NIfTI intent: {expected_intent} has been automatically set."
)
self._nifti_header.set_intent(expected_intent)
if expected_ext is not None and not filename.endswith(expected_ext):
warn(f"Filename does not end with expected extension: {expected_ext}")
super().to_filename(filename)


load = Cifti2Image.from_filename
save = Cifti2Image.instance_to_filename
15 changes: 15 additions & 0 deletions nibabel/cifti2/tests/test_cifti2.py
Original file line number Diff line number Diff line change
@@ -427,3 +427,18 @@ def make_imaker(self, arr, header=None, ni_header=None):
)
header.matrix.append(mim)
return lambda: self.image_maker(arr.copy(), header, ni_header)

def validate_filenames(self, imaker, params, validate=False):
super().validate_filenames(imaker, params, validate=validate)

def validate_mmap_parameter(self, imaker, params, validate=False):
super().validate_mmap_parameter(imaker, params, validate=validate)

def validate_to_bytes(self, imaker, params, validate=False):
super().validate_to_bytes(imaker, params, validate=validate)

def validate_from_bytes(self, imaker, params, validate=False):
super().validate_from_bytes(imaker, params, validate=validate)

def validate_to_from_bytes(self, imaker, params, validate=False):
super().validate_to_from_bytes(imaker, params, validate=validate)
2 changes: 1 addition & 1 deletion nibabel/cifti2/tests/test_cifti2io_axes.py
Original file line number Diff line number Diff line change
@@ -91,7 +91,7 @@ def check_rewrite(arr, axes, extension='.nii'):
custom extension to use
"""
(fd, name) = tempfile.mkstemp(extension)
cifti2.Cifti2Image(arr, header=axes).to_filename(name)
cifti2.Cifti2Image(arr, header=axes).to_filename(name, validate=False)
img = nib.load(name)
arr2 = img.get_fdata()
assert np.allclose(arr, arr2)
4 changes: 2 additions & 2 deletions nibabel/cifti2/tests/test_cifti2io_header.py
Original file line number Diff line number Diff line change
@@ -83,7 +83,7 @@ def test_readwritedata():
with InTemporaryDirectory():
for name in datafiles:
img = ci.load(name)
ci.save(img, 'test.nii')
ci.save(img, 'test.nii', validate=False)
img2 = ci.load('test.nii')
assert len(img.header.matrix) == len(img2.header.matrix)
# Order should be preserved in load/save
@@ -109,7 +109,7 @@ def test_nibabel_readwritedata():
with InTemporaryDirectory():
for name in datafiles:
img = nib.load(name)
nib.save(img, 'test.nii')
nib.save(img, 'test.nii', validate=False)
img2 = nib.load('test.nii')
assert len(img.header.matrix) == len(img2.header.matrix)
# Order should be preserved in load/save
57 changes: 41 additions & 16 deletions nibabel/cifti2/tests/test_new_cifti2.py
Original file line number Diff line number Diff line change
@@ -7,12 +7,11 @@
scratch.
"""
import numpy as np

import nibabel as nib
from nibabel import cifti2 as ci
from nibabel.tmpdirs import InTemporaryDirectory

import pytest

from ...testing import (
clear_and_catch_warnings, error_warnings, suppress_warnings, assert_array_equal)

@@ -237,7 +236,6 @@ def test_dtseries():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(13, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES')

with InTemporaryDirectory():
ci.save(img, 'test.dtseries.nii')
@@ -281,7 +279,6 @@ def test_dlabel():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS')

with InTemporaryDirectory():
ci.save(img, 'test.dlabel.nii')
@@ -301,7 +298,6 @@ def test_dconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(10, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE')

with InTemporaryDirectory():
ci.save(img, 'test.dconn.nii')
@@ -323,7 +319,6 @@ def test_ptseries():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(13, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES')

with InTemporaryDirectory():
ci.save(img, 'test.ptseries.nii')
@@ -345,7 +340,6 @@ def test_pscalar():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR')

with InTemporaryDirectory():
ci.save(img, 'test.pscalar.nii')
@@ -367,7 +361,6 @@ def test_pdconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(10, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE')

with InTemporaryDirectory():
ci.save(img, 'test.pdconn.nii')
@@ -389,7 +382,6 @@ def test_dpconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED')

with InTemporaryDirectory():
ci.save(img, 'test.dpconn.nii')
@@ -413,7 +405,7 @@ def test_plabel():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.plabel.nii')
ci.save(img, 'test.plabel.nii', validate=False)
img2 = ci.load('test.plabel.nii')
assert img.nifti_header.get_intent()[0] == 'ConnUnknown'
assert isinstance(img2, ci.Cifti2Image)
@@ -430,7 +422,6 @@ def test_pconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED')

with InTemporaryDirectory():
ci.save(img, 'test.pconn.nii')
@@ -453,8 +444,6 @@ def test_pconnseries():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 4, 13)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
'PARCELLATED_SERIES')

with InTemporaryDirectory():
ci.save(img, 'test.pconnseries.nii')
@@ -478,8 +467,6 @@ def test_pconnscalar():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 4, 2)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
'PARCELLATED_SCALAR')

with InTemporaryDirectory():
ci.save(img, 'test.pconnscalar.nii')
@@ -517,7 +504,45 @@ def test_wrong_shape():
ci.Cifti2Image(data, hdr)
with suppress_warnings():
img = ci.Cifti2Image(data, hdr)

with pytest.raises(ValueError):
img.to_file_map()


def test_cifti_validation():
# flip label / brain_model index maps
geometry_map = create_geometry_map((0, ))
label_map = create_label_map((1, ))
matrix = ci.Cifti2Matrix()
matrix.append(geometry_map)
matrix.append(label_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(10, 2)
img = ci.Cifti2Image(data, hdr)
# flipped index maps will warn
with InTemporaryDirectory(), pytest.warns(UserWarning):
ci.save(img, 'test.dlabel.nii')

label_map = create_label_map((0, ))
geometry_map = create_geometry_map((1, ))
matrix = ci.Cifti2Matrix()
matrix.append(label_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.validate.nii', validate=False)
ci.save(img, 'test.dlabel.nii')

img2 = nib.load('test.dlabel.nii')
img3 = nib.load('test.validate.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnDenseLabel'
assert img3.nifti_header.get_intent()[0] == 'ConnUnknown'
assert isinstance(img2, ci.Cifti2Image)
assert isinstance(img3, ci.Cifti2Image)
assert_array_equal(img2.get_fdata(), data)
check_label_map(img2.header.matrix.get_index_map(0))
check_geometry_map(img2.header.matrix.get_index_map(1))
del img2, img3
6 changes: 3 additions & 3 deletions nibabel/filebasedimages.py
Original file line number Diff line number Diff line change
@@ -315,7 +315,7 @@ def filespec_to_file_map(klass, filespec):
def filespec_to_files(klass, filespec):
return klass.filespec_to_file_map(filespec)

def to_filename(self, filename):
def to_filename(self, filename, **kwargs):
""" Write image to files implied by filename string

Parameters
@@ -381,7 +381,7 @@ def make_file_map(klass, mapping=None):
load = from_filename

@classmethod
def instance_to_filename(klass, img, filename):
def instance_to_filename(klass, img, filename, **kwargs):
""" Save `img` in our own format, to name implied by `filename`

This is a class method
@@ -394,7 +394,7 @@ def instance_to_filename(klass, img, filename):
Filename, implying name to which to save image.
"""
img = klass.from_image(img)
img.to_filename(filename)
img.to_filename(filename, **kwargs)

@classmethod
def from_image(klass, img):
6 changes: 3 additions & 3 deletions nibabel/loadsave.py
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ def guessed_image_type(filename):
raise ImageFileError(f'Cannot work out file type of "{filename}"')


def save(img, filename):
def save(img, filename, **kwargs):
""" Save an image to file adapting format to `filename`

Parameters
@@ -96,7 +96,7 @@ def save(img, filename):

# Save the type as expected
try:
img.to_filename(filename)
img.to_filename(filename, **kwargs)
except ImageFileError:
pass
else:
@@ -144,7 +144,7 @@ def save(img, filename):
# Here, we either have a klass or a converted image.
if converted is None:
converted = klass.from_image(img)
converted.to_filename(filename)
converted.to_filename(filename, **kwargs)


@deprecate_with_version('read_img_data deprecated. '
Loading