Skip to content

ENH: Re-allow "locking" of models with first fit #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/nifreeze/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ def main(argv=None) -> None:

prev_model: Estimator | None = None
for _model in args.models:
single_fit = _model.lower().startswith("single")
estimator: Estimator = Estimator(
_model,
_model.lower().replace("single", ""),
prev=prev_model,
single_fit=single_fit,
)
prev_model = estimator

Expand Down
7 changes: 6 additions & 1 deletion src/nifreeze/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,21 @@ def run(self, dataset: DatasetT, **kwargs) -> DatasetT:
class Estimator:
"""Estimates rigid-body head-motion and distortions derived from eddy-currents."""

__slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")
__slots__ = ("_model", "_single_fit", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")

def __init__(
self,
model: BaseModel | str,
strategy: str = "random",
prev: Estimator | Filter | None = None,
model_kwargs: dict | None = None,
single_fit: bool = False,
**kwargs,
):
self._model = model
self._prev = prev
self._strategy = strategy
self._single_fit = single_fit
self._model_kwargs = model_kwargs or {}
self._align_kwargs = kwargs or {}

Expand Down Expand Up @@ -120,6 +122,9 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
**self._model_kwargs,
)

if self._single_fit:
self._model.fit_predict(None, njobs=n_jobs)

kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
kwargs = self._align_kwargs | kwargs

Expand Down
45 changes: 33 additions & 12 deletions src/nifreeze/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,46 +87,59 @@ class BaseModel:

"""

__slots__ = ("_dataset",)
__slots__ = ("_dataset", "_locked_fit")

def __init__(self, dataset, **kwargs):
"""Base initialization."""

self._locked_fit = None
self._dataset = dataset
# Warn if mask not present
if dataset.brainmask is None:
warn(mask_absence_warn_msg, stacklevel=2)

@abstractmethod
def fit_predict(self, index, **kwargs) -> np.ndarray:
"""Fit and predict the indicate index of the dataset (abstract signature)."""
def fit_predict(self, index: int | None = None, **kwargs) -> np.ndarray:
"""
Fit and predict the indicated index of the dataset (abstract signature).

If ``index`` is ``None``, then the model is executed in *single-fit mode*, meaning
that it will be fitted only once on all the data available.
Please note that all the predictions of this model will suffer from data leakage
from the original volume.

Parameters
----------
index : :obj:`int` or ``None``
The index to predict.
If ``None``, no prediction will be executed.

"""
raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.")


class TrivialModel(BaseModel):
"""A trivial model that returns a given map always."""

__slots__ = ("_predicted",)

def __init__(self, dataset, predicted=None, **kwargs):
"""Implement object initialization."""

super().__init__(dataset, **kwargs)
self._predicted = (
self._locked_fit = (
predicted
if predicted is not None
# Infer from dataset if not provided at initialization
else getattr(dataset, "reference", getattr(dataset, "bzero", None))
)

if self._predicted is None:
if self._locked_fit is None:
raise TypeError("This model requires the predicted map at initialization")

def fit_predict(self, *_, **kwargs):
"""Return the reference map."""

# No need to check fit (if not fitted, has raised already)
return self._predicted
return self._locked_fit


class ExpectationModel(BaseModel):
Expand All @@ -139,7 +152,7 @@ def __init__(self, dataset, stat="median", **kwargs):
super().__init__(dataset, **kwargs)
self._stat = stat

def fit_predict(self, index: int, **kwargs):
def fit_predict(self, index: int | None = None, **kwargs):
"""
Return the expectation map.

Expand All @@ -149,12 +162,20 @@ def fit_predict(self, index: int, **kwargs):
The volume index that is left out in fitting, and then predicted.

"""

if self._locked_fit is not None:
return self._locked_fit

# Select the summary statistic
avg_func = getattr(np, kwargs.pop("stat", self._stat))

# Create index mask
index_mask = np.ones(len(self._dataset), dtype=bool)
index_mask[index] = False

# Calculate the average
return avg_func(self._dataset[index_mask][0], axis=-1)
if index is not None:
index_mask[index] = False
# Calculate the average
return avg_func(self._dataset[index_mask][0], axis=-1)

self._locked_fit = avg_func(self._dataset[index_mask][0], axis=-1)
return self._locked_fit
30 changes: 23 additions & 7 deletions src/nifreeze/model/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,21 @@ def __init__(self, dataset: DWI, **kwargs):

super().__init__(dataset, **kwargs)

def _fit(self, index, n_jobs=None, **kwargs):
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
"""Fit the model chunk-by-chunk asynchronously"""

n_jobs = n_jobs or 1

if self._locked_fit is not None:
return n_jobs

brainmask = self._dataset.brainmask
idxmask = np.ones(len(self._dataset), dtype=bool)
idxmask[index] = False

if index is not None:
idxmask[index] = False
else:
self._locked_fit = True

data, _, gtab = self._dataset[idxmask]
# Select voxels within mask or just unravel 3D if no mask
Expand Down Expand Up @@ -122,7 +130,7 @@ def _fit(self, index, n_jobs=None, **kwargs):
self._model = None # Preempt further actions on the model
return n_jobs

def fit_predict(self, index: int, **kwargs):
def fit_predict(self, index: int | None = None, **kwargs):
"""
Predict asynchronously chunk-by-chunk the diffusion signal.

Expand All @@ -133,8 +141,14 @@ def fit_predict(self, index: int, **kwargs):

"""

n_models = self._fit(index, **kwargs)
kwargs.pop("n_jobs")
n_models = self._fit(
index,
n_jobs=max(kwargs.pop("n_jobs", None) or 1, kwargs.pop("njobs", None) or 1),
**kwargs,
)

if index is None:
return None

brainmask = self._dataset.brainmask
gradient = self._dataset.gradients[:, index]
Expand All @@ -151,7 +165,6 @@ def fit_predict(self, index: int, **kwargs):
if n_models == 1:
predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0}))
else:
print(n_models, S0)
S0 = np.array_split(S0, n_models) if S0 is not None else np.full(n_models, None)

predicted = [None] * n_models
Expand Down Expand Up @@ -221,9 +234,12 @@ def __init__(
self._th_high = th_high
self._detrend = detrend

def fit_predict(self, index, *_, **kwargs):
def fit_predict(self, index: int | None = None, *_, **kwargs):
"""Return the average map."""

if index is None:
raise RuntimeError(f"Model {__class__.__name__} does not allow locking.")

bvalues = self._dataset.gradients[:, -1]
bcenter = bvalues[index]

Expand Down
2 changes: 1 addition & 1 deletion src/nifreeze/model/pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def fit(self, data, **kwargs):

self._coeff = np.array([r[0] for r in results])

def predict(self, index=None, **kwargs):
def predict(self, index: int | None = None, **kwargs):
"""Return the corrected volume using B-spline interpolation."""
from scipy.interpolate import BSpline

Expand Down
Loading