Skip to content

Commit

Permalink
Refactor executive summary workflow (#721)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo authored Dec 19, 2022
1 parent cb2563a commit 786533d
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 452 deletions.
15 changes: 12 additions & 3 deletions xcp_d/interfaces/surfplotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
File,
SimpleInterface,
TraitedSpec,
isdefined,
traits,
)

Expand Down Expand Up @@ -84,10 +85,12 @@ class _PlotSVGDataInputSpec(BaseInterfaceInputSpec):
mandatory=True,
desc="TSV file with filtered motion parameters.",
)
TR = traits.Float(default_value=1, desc="Repetition time")

# Optional inputs
mask = File(exists=True, mandatory=False, desc="Bold mask")
tmask = File(exists=True, mandatory=False, desc="Temporal mask")
seg_data = File(exists=True, mandatory=False, desc="Segmentation file")
TR = traits.Float(default_value=1, desc="Repetition time")
dummy_scans = traits.Int(
0,
usedefault=True,
Expand Down Expand Up @@ -128,16 +131,22 @@ def _run_interface(self, runtime):
use_ext=False,
)

mask_file = self.inputs.mask
mask_file = mask_file if isdefined(mask_file) else None

segmentation_file = self.inputs.seg_data
segmentation_file = segmentation_file if isdefined(segmentation_file) else None

self._results["before_process"], self._results["after_process"] = plot_svgx(
preprocessed_file=self.inputs.rawdata,
residuals_file=self.inputs.regressed_data,
denoised_file=self.inputs.residual_data,
tmask=self.inputs.tmask,
dummy_scans=self.inputs.dummy_scans,
TR=self.inputs.TR,
mask=self.inputs.mask,
mask=mask_file,
filtered_motion=self.inputs.filtered_motion,
seg_data=self.inputs.seg_data,
seg_data=segmentation_file,
processed_filename=after_process_fn,
unprocessed_filename=before_process_fn,
)
Expand Down
157 changes: 39 additions & 118 deletions xcp_d/utils/bids.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,26 @@

LOGGER = logging.getLogger("nipype.utils")

# TODO: Add and test fsaverage.
DEFAULT_ALLOWED_SPACES = {
"cifti": ["fsLR"],
"nifti": [
"MNI152NLin6Asym",
"MNI152NLin2009cAsym",
"MNIInfant",
],
}
INPUT_TYPE_ALLOWED_SPACES = {
"nibabies": {
"cifti": ["fsLR"],
"nifti": [
"MNIInfant",
"MNI152NLin6Asym",
"MNI152NLin2009cAsym",
],
},
}


class BIDSError(ValueError):
"""A generic error related to BIDS datasets.
Expand Down Expand Up @@ -158,25 +178,6 @@ def collect_data(
derivatives=True,
config=["bids", "derivatives"],
)
# TODO: Add and test fsaverage.
default_allowed_spaces = {
"cifti": ["fsLR"],
"nifti": [
"MNI152NLin6Asym",
"MNI152NLin2009cAsym",
"MNIInfant",
],
}
input_type_allowed_spaces = {
"nibabies": {
"cifti": ["fsLR"],
"nifti": [
"MNIInfant",
"MNI152NLin6Asym",
"MNI152NLin2009cAsym",
],
},
}

queries = {
# all preprocessed BOLD files in the right space/resolution/density
Expand Down Expand Up @@ -233,9 +234,10 @@ def collect_data(
# but we'll grab the first one with available data if they did.
allowed_spaces = ensure_list(queries["bold"]["space"])
else:
allowed_spaces = input_type_allowed_spaces.get(input_type, default_allowed_spaces)[
"cifti" if cifti else "nifti"
]
allowed_spaces = INPUT_TYPE_ALLOWED_SPACES.get(
input_type,
DEFAULT_ALLOWED_SPACES,
)["cifti" if cifti else "nifti"]

for space in allowed_spaces:
queries["bold"]["space"] = space
Expand All @@ -255,35 +257,24 @@ def collect_data(
else:
# Select the best *volumetric* space, based on available nifti BOLD files.
# This space will be used in the executive summary and T1w/T2w workflows.
temp_bold_query = queries["bold"].copy()
temp_bold_query["extension"] = ".nii.gz"
temp_allowed_spaces = input_type_allowed_spaces.get(
temp_query = queries["t1w_to_template_xform"].copy()
temp_allowed_spaces = INPUT_TYPE_ALLOWED_SPACES.get(
input_type,
default_allowed_spaces,
DEFAULT_ALLOWED_SPACES,
)["nifti"]

for space in temp_allowed_spaces:
temp_bold_query["space"] = space
nifti_bold_data = layout.get(**temp_bold_query)
if nifti_bold_data:
temp_query["to"] = space
transform_files = layout.get(**temp_query)
if transform_files:
queries["t1w_to_template_xform"]["to"] = space
queries["template_to_t1w_xform"]["from"] = space
break

if input_type in ("hcp", "dcan"):
temp_allowed_spaces = ["MNI152NLin6Asym"]
# HCP and DCAN files don't have nifti BOLD data, we will use the boldref
temp_bold_query["desc"] = None
temp_bold_query["suffix"] = "boldref"
temp_bold_query["space"] = "MNI152NLin6Asym"
queries["t1w_to_template_xform"]["to"] = "MNI152NLin2009cAsym"
queries["template_to_t1w_xform"]["from"] = "MNI152NLin2009cAsym"
nifti_bold_data = layout.get(**temp_bold_query)

if not nifti_bold_data:
if not transform_files:
allowed_space_str = ", ".join(temp_allowed_spaces)
raise FileNotFoundError(
f"No nifti BOLD data found in allowed spaces ({allowed_space_str})"
f"No nifti transforms found to allowed spaces ({allowed_space_str})"
)

# Grab the first (and presumably best) density and resolution if there are multiple.
Expand Down Expand Up @@ -542,7 +533,7 @@ def collect_run_data(layout, input_type, bold_file, cifti=False):
if "RepetitionTime" not in metadata["bold_metadata"].keys():
metadata["bold_metadata"]["RepetitionTime"] = _get_tr(bold_file)

if not cifti and input_type not in ("hcp", "dcan"):
if not cifti:
run_data["boldref"] = layout.get_nearest(
bids_file.path,
strict=False,
Expand All @@ -561,21 +552,17 @@ def collect_run_data(layout, input_type, bold_file, cifti=False):
to="scanner",
suffix="xfm",
)

elif not cifti:
else:
allowed_nifti_spaces = INPUT_TYPE_ALLOWED_SPACES.get(
input_type,
DEFAULT_ALLOWED_SPACES,
)["nifti"]
run_data["boldref"] = layout.get_nearest(
bids_file.path,
strict=False,
space=allowed_nifti_spaces,
suffix="boldref",
)
run_data["boldmask"] = layout.get(
return_type="file",
suffix="mask",
desc="brain",
)
run_data["t1w_to_native_xform"] = layout.get(
return_type="file", datatype="anat", suffix="xfm", to="MNI152NLin2009cAsym"
)

LOGGER.debug(
f"Collected run data for {bold_file}:\n"
Expand Down Expand Up @@ -745,72 +732,6 @@ def _get_tr(img):
raise RuntimeError("Could not extract TR - unknown data structure type")


def find_nifti_bold_files(bold_file, template_to_t1w):
"""Find nifti bold and boldref files associated with a given input file.
Parameters
----------
bold_file : str
Path to the preprocessed BOLD file that XCPD will denoise elsewhere.
If this is a cifti file, then the appropriate nifti file will be determined based on
entities in this file, as well as the space and, potentially, cohort in the
template_to_t1w file.
When this is a nifti file, it is returned without modification.
template_to_t1w : str
The transform from standard space to T1w space.
This is used to determine the volumetric template when bold_file is a cifti file.
When bold_file is a nifti file, this is not used.
Returns
-------
nifti_bold_file : str
Path to the volumetric (nifti) preprocessed BOLD file.
nifti_boldref_file : str
Path to the volumetric (nifti) BOLD reference file associated with nifti_bold_file.
"""
import glob
import os
import re

# Get the nifti reference file
if bold_file.endswith(".nii.gz"):
nifti_bold_file = bold_file
nifti_boldref_file = bold_file.split("desc-preproc_bold.nii.gz")[0] + "boldref.nii.gz"
if not os.path.isfile(nifti_boldref_file):
raise FileNotFoundError(f"boldref file not found: {nifti_boldref_file}")

else: # Get the cifti reference file
# Infer the volumetric space from the transform
nifti_template = re.findall("from-([a-zA-Z0-9+]+)", os.path.basename(template_to_t1w))[0]
if "+" in nifti_template:
nifti_template, cohort = nifti_template.split("+")
search_substring = f"space-{nifti_template}_cohort-{cohort}"
else:
search_substring = f"space-{nifti_template}"

bb_file_prefix = bold_file.split("space-fsLR_den-91k_bold.dtseries.nii")[0]

# Find the appropriate _bold file.
bold_search_str = bb_file_prefix + search_substring + "*preproc_bold.nii.gz"
nifti_bold_file = sorted(glob.glob(bold_search_str))
if len(nifti_bold_file) > 1:
LOGGER.warn(f"More than one nifti bold file found: {', '.join(nifti_bold_file)}")
elif len(nifti_bold_file) == 0:
raise FileNotFoundError(f"bold file not found: {bold_search_str}")
nifti_bold_file = nifti_bold_file[0]

# Find the associated _boldref file.
boldref_search_str = bb_file_prefix + search_substring + "*boldref.nii.gz"
nifti_boldref_file = sorted(glob.glob(boldref_search_str))
if len(nifti_boldref_file) > 1:
LOGGER.warn(f"More than one nifti boldref found: {', '.join(nifti_boldref_file)}")
elif len(nifti_boldref_file) == 0:
raise FileNotFoundError(f"boldref file not found: {boldref_search_str}")
nifti_boldref_file = nifti_boldref_file[0]

return nifti_bold_file, nifti_boldref_file


def get_freesurfer_dir(fmri_dir):
"""Find FreeSurfer derivatives associated with preprocessing pipeline.
Expand Down
10 changes: 7 additions & 3 deletions xcp_d/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,11 @@ def plot_svgx(
unprocessed_figure = plt.figure(constrained_layout=True, figsize=(22.5, 30))
grid = mgs.GridSpec(5, 1, wspace=0.0, hspace=0.05, height_ratios=[1, 1, 0.2, 2.5, 1])
confoundplotx(
time_series=DVARS_timeseries, grid_spec_ts=grid[0], TR=TR, ylabel="DVARS", hide_x=True
time_series=DVARS_timeseries,
grid_spec_ts=grid[0],
TR=TR,
ylabel="DVARS",
hide_x=True,
)
confoundplotx(
time_series=unprocessed_data_timeseries,
Expand Down Expand Up @@ -794,6 +798,7 @@ def plot_carpet(
atlaslabels : numpy.ndarray, optional
A 3D array of integer labels from an atlas, resampled into ``img`` space.
Required if ``func`` is a NIfTI image.
Unused if ``func`` is a CIFTI.
detrend : bool, optional
Detrend and standardize the data prior to plotting.
size : tuple, optional
Expand All @@ -807,8 +812,7 @@ def plot_carpet(
are .png, .pdf, .svg. If output_file is not None, the plot
is saved to a file, and the display is closed.
legend : bool
Whether to render the average functional series with ``atlaslabels`` as
overlay.
Whether to render the average functional series with ``atlaslabels`` as overlay.
TR : float, optional
Specify the TR, if specified it uses this value. If left as None,
# of frames is plotted instead of time.
Expand Down
9 changes: 5 additions & 4 deletions xcp_d/workflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def init_subject_wf(
"t1w",
"t1w_mask", # not used by cifti workflow
"t1w_seg",
"template_to_t1w_xform",
"template_to_t1w_xform", # not used by cifti workflow
"t1w_to_template_xform",
# surface files
"lh_pial_surf",
Expand Down Expand Up @@ -537,14 +537,15 @@ def init_subject_wf(
(inputnode, bold_postproc_wf, [
('t1w', 'inputnode.t1w'),
('t1w_seg', 'inputnode.t1seg'),
('template_to_t1w_xform', 'inputnode.template_to_t1w'),
]),
])
if not cifti:
workflow.connect([
(inputnode, bold_postproc_wf, [('t1w_mask', 'inputnode.t1w_mask')]),
(inputnode, bold_postproc_wf, [
('t1w_mask', 'inputnode.t1w_mask'),
('template_to_t1w_xform', 'inputnode.template_to_t1w'),
]),
])

# fmt:on

# fmt:off
Expand Down
Loading

0 comments on commit 786533d

Please sign in to comment.