Skip to content

Commit 6ecc71f

Browse files
authored
Merge pull request #24 from grok-ai/develop
Version 0.1.1
2 parents f97aa83 + 4a8ecb2 commit 6ecc71f

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

src/nn_core/callbacks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,6 @@ def on_save_checkpoint(
5353
) -> None:
5454
if self._is_nnlogger(trainer):
5555
trainer.logger.on_save_checkpoint(trainer=trainer, pl_module=pl_module, checkpoint=checkpoint)
56-
checkpoint[METADATA_KEY] = trainer.datamodule.metadata
56+
metadata = getattr(trainer.datamodule, "metadata", None)
57+
if metadata is not None:
58+
checkpoint[METADATA_KEY] = metadata

src/nn_core/common/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def load_envs(env_file: Optional[str] = None) -> None:
5151
:param env_file: the file that defines the environment variables to use. If None
5252
it searches for a `.env` file in the project.
5353
"""
54+
if env_file is None:
55+
env_file = dotenv.find_dotenv(usecwd=True)
5456
dotenv.load_dotenv(dotenv_path=env_file, override=True)
5557

5658

src/nn_core/serialization.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import importlib
23
import inspect
34
import logging
@@ -17,6 +18,7 @@
1718

1819
pylogger = logging.getLogger(__name__)
1920

21+
from typing import Mapping
2022

2123
_METADATA_MODULE_KEY = f"{METADATA_KEY}_module"
2224
_METADATA_CLASS_KEY = f"{METADATA_KEY}_class"
@@ -124,14 +126,41 @@ def extract_checkpoint(ckpt_file: Path) -> Path:
124126
yield Path(tmp_dir)
125127

126128

129+
def _substistute(dictionary, substitute_values: Dict[str, str], substitute_keys: Dict[str, str] = {}):
130+
if not isinstance(dictionary, Mapping):
131+
if isinstance(dictionary, collections.Hashable):
132+
if substitute_values is not None and dictionary in substitute_values:
133+
return substitute_values[dictionary]
134+
elif substitute_keys is not None and dictionary in substitute_keys:
135+
return substitute_keys[dictionary]
136+
else:
137+
return dictionary
138+
return dictionary
139+
140+
return {
141+
_substistute(key, substitute_values=substitute_values, substitute_keys=substitute_keys,): _substistute(
142+
value,
143+
substitute_values=substitute_values,
144+
substitute_keys=substitute_keys,
145+
)
146+
for key, value in dictionary.items()
147+
}
148+
149+
127150
def load_model(
128151
module_class: Type[pl.LightningModule],
129152
checkpoint_path: Path,
130153
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
154+
substitute_keys: Optional[Dict[str, str]] = None,
155+
substitute_values: Optional[Dict[str, str]] = None,
131156
):
132157
# Lightning checkpoints end with .ckpt, ours with .ckpt.zip
133158
if checkpoint_path.name.endswith(".ckpt.zip"):
134159
checkpoint = NNCheckpointIO.load(path=checkpoint_path, map_location=map_location)
160+
161+
if substitute_values is not None:
162+
checkpoint = _substistute(checkpoint, substitute_values=substitute_values, substitute_keys=substitute_keys)
163+
135164
return module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
136165
else:
137166
pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'")

0 commit comments

Comments
 (0)