diff --git a/src/nifreeze/cli/run.py b/src/nifreeze/cli/run.py index df493cb1..5b8366b3 100644 --- a/src/nifreeze/cli/run.py +++ b/src/nifreeze/cli/run.py @@ -69,9 +69,11 @@ def main(argv=None) -> None: prev_model: Estimator | None = None for _model in args.models: + single_fit = _model.lower().startswith("single") estimator: Estimator = Estimator( - _model, + _model.lower().replace("single", ""), prev=prev_model, + single_fit=single_fit, ) prev_model = estimator diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index a27a69bf..59edf206 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -69,7 +69,7 @@ def run(self, dataset: DatasetT, **kwargs) -> DatasetT: class Estimator: """Estimates rigid-body head-motion and distortions derived from eddy-currents.""" - __slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs") + __slots__ = ("_model", "_single_fit", "_strategy", "_prev", "_model_kwargs", "_align_kwargs") def __init__( self, @@ -77,11 +77,13 @@ def __init__( strategy: str = "random", prev: Estimator | Filter | None = None, model_kwargs: dict | None = None, + single_fit: bool = False, **kwargs, ): self._model = model self._prev = prev self._strategy = strategy + self._single_fit = single_fit self._model_kwargs = model_kwargs or {} self._align_kwargs = kwargs or {} @@ -120,6 +122,9 @@ def run(self, dataset: DatasetT, **kwargs) -> Self: **self._model_kwargs, ) + if self._single_fit: + self._model.fit_predict(None, njobs=n_jobs) + kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None) kwargs = self._align_kwargs | kwargs diff --git a/src/nifreeze/model/base.py b/src/nifreeze/model/base.py index d127468b..2bfae049 100644 --- a/src/nifreeze/model/base.py +++ b/src/nifreeze/model/base.py @@ -87,46 +87,59 @@ class BaseModel: """ - __slots__ = ("_dataset",) + __slots__ = ("_dataset", "_locked_fit") def __init__(self, dataset, **kwargs): """Base initialization.""" + self._locked_fit = None self._dataset = dataset # Warn if mask not present if dataset.brainmask is None: warn(mask_absence_warn_msg, stacklevel=2) @abstractmethod - def fit_predict(self, index, **kwargs) -> np.ndarray: - """Fit and predict the indicate index of the dataset (abstract signature).""" + def fit_predict(self, index: int | None = None, **kwargs) -> np.ndarray: + """ + Fit and predict the indicated index of the dataset (abstract signature). + + If ``index`` is ``None``, then the model is executed in *single-fit mode* meaning + that it will be run only once in all the data available. + Please note that all the predictions of this model will suffer from data leakage + from the original volume. + + Parameters + ---------- + index : :obj:`int` or ``None`` + The index to predict. + If ``None``, no prediction will be executed. + + """ raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.") class TrivialModel(BaseModel): """A trivial model that returns a given map always.""" - __slots__ = ("_predicted",) - def __init__(self, dataset, predicted=None, **kwargs): """Implement object initialization.""" super().__init__(dataset, **kwargs) - self._predicted = ( + self._locked_fit = ( predicted if predicted is not None # Infer from dataset if not provided at initialization else getattr(dataset, "reference", getattr(dataset, "bzero", None)) ) - if self._predicted is None: + if self._locked_fit is None: raise TypeError("This model requires the predicted map at initialization") def fit_predict(self, *_, **kwargs): """Return the reference map.""" # No need to check fit (if not fitted, has raised already) - return self._predicted + return self._locked_fit class ExpectationModel(BaseModel): @@ -139,7 +152,7 @@ def __init__(self, dataset, stat="median", **kwargs): super().__init__(dataset, **kwargs) self._stat = stat - def fit_predict(self, index: int, **kwargs): + def fit_predict(self, index: int | None = None, **kwargs): """ Return the expectation map. @@ -149,12 +162,20 @@ def fit_predict(self, index: int, **kwargs): The volume index that is left-out in fitting, and then predicted. """ + + if self._locked_fit is not None: + return self._locked_fit + # Select the summary statistic avg_func = getattr(np, kwargs.pop("stat", self._stat)) # Create index mask index_mask = np.ones(len(self._dataset), dtype=bool) - index_mask[index] = False - # Calculate the average - return avg_func(self._dataset[index_mask][0], axis=-1) + if index is not None: + index_mask[index] = False + # Calculate the average + return avg_func(self._dataset[index_mask][0], axis=-1) + + self._locked_fit = avg_func(self._dataset[index_mask][0], axis=-1) + return self._locked_fit diff --git a/src/nifreeze/model/dmri.py b/src/nifreeze/model/dmri.py index f8041de0..4f203fa1 100644 --- a/src/nifreeze/model/dmri.py +++ b/src/nifreeze/model/dmri.py @@ -77,13 +77,21 @@ def __init__(self, dataset: DWI, **kwargs): super().__init__(dataset, **kwargs) - def _fit(self, index, n_jobs=None, **kwargs): + def _fit(self, index: int | None = None, n_jobs=None, **kwargs): """Fit the model chunk-by-chunk asynchronously""" + n_jobs = n_jobs or 1 + if self._locked_fit is not None: + return n_jobs + brainmask = self._dataset.brainmask idxmask = np.ones(len(self._dataset), dtype=bool) - idxmask[index] = False + + if index is not None: + idxmask[index] = False + else: + self._locked_fit = True data, _, gtab = self._dataset[idxmask] # Select voxels within mask or just unravel 3D if no mask @@ -122,7 +130,7 @@ def _fit(self, index, n_jobs=None, **kwargs): self._model = None # Preempt further actions on the model return n_jobs - def fit_predict(self, index: int, **kwargs): + def fit_predict(self, index: int | None = None, **kwargs): """ Predict asynchronously chunk-by-chunk the diffusion signal. @@ -133,8 +141,14 @@ def fit_predict(self, index: int, **kwargs): """ - n_models = self._fit(index, **kwargs) - kwargs.pop("n_jobs") + n_models = self._fit( + index, + n_jobs=max(kwargs.pop("n_jobs", None) or 1, kwargs.pop("njobs", None) or 1), + **kwargs, + ) + + if index is None: + return None brainmask = self._dataset.brainmask gradient = self._dataset.gradients[:, index] @@ -151,7 +165,6 @@ def fit_predict(self, index: int, **kwargs): if n_models == 1: predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0})) else: - print(n_models, S0) S0 = np.array_split(S0, n_models) if S0 is not None else np.full(n_models, None) predicted = [None] * n_models @@ -221,9 +234,12 @@ def __init__( self._th_high = th_high self._detrend = detrend - def fit_predict(self, index, *_, **kwargs): + def fit_predict(self, index: int | None = None, *_, **kwargs): """Return the average map.""" + if index is None: + raise RuntimeError(f"Model {__class__.__name__} does not allow locking.") + bvalues = self._dataset.gradients[:, -1] bcenter = bvalues[index] diff --git a/src/nifreeze/model/pet.py b/src/nifreeze/model/pet.py index 2a911315..1a857c27 100644 --- a/src/nifreeze/model/pet.py +++ b/src/nifreeze/model/pet.py @@ -112,7 +112,7 @@ def fit(self, data, **kwargs): self._coeff = np.array([r[0] for r in results]) - def predict(self, index=None, **kwargs): + def predict(self, index: int | None = None, **kwargs): """Return the corrected volume using B-spline interpolation.""" from scipy.interpolate import BSpline