|
| 1 | +import collections |
1 | 2 | import importlib
|
2 | 3 | import inspect
|
3 | 4 | import logging
|
|
17 | 18 |
|
18 | 19 | pylogger = logging.getLogger(__name__)
|
19 | 20 |
|
| 21 | +from typing import Mapping |
20 | 22 |
|
21 | 23 | _METADATA_MODULE_KEY = f"{METADATA_KEY}_module"
|
22 | 24 | _METADATA_CLASS_KEY = f"{METADATA_KEY}_class"
|
@@ -124,14 +126,41 @@ def extract_checkpoint(ckpt_file: Path) -> Path:
|
124 | 126 | yield Path(tmp_dir)
|
125 | 127 |
|
126 | 128 |
|
| 129 | +def _substistute(dictionary, substitute_values: Dict[str, str], substitute_keys: Dict[str, str] = {}): |
| 130 | + if not isinstance(dictionary, Mapping): |
| 131 | + if isinstance(dictionary, collections.Hashable): |
| 132 | + if substitute_values is not None and dictionary in substitute_values: |
| 133 | + return substitute_values[dictionary] |
| 134 | + elif substitute_keys is not None and dictionary in substitute_keys: |
| 135 | + return substitute_keys[dictionary] |
| 136 | + else: |
| 137 | + return dictionary |
| 138 | + return dictionary |
| 139 | + |
| 140 | + return { |
| 141 | + _substistute(key, substitute_values=substitute_values, substitute_keys=substitute_keys,): _substistute( |
| 142 | + value, |
| 143 | + substitute_values=substitute_values, |
| 144 | + substitute_keys=substitute_keys, |
| 145 | + ) |
| 146 | + for key, value in dictionary.items() |
| 147 | + } |
| 148 | + |
| 149 | + |
127 | 150 | def load_model(
|
128 | 151 | module_class: Type[pl.LightningModule],
|
129 | 152 | checkpoint_path: Path,
|
130 | 153 | map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
|
| 154 | + substitute_keys: Optional[Dict[str, str]] = None, |
| 155 | + substitute_values: Optional[Dict[str, str]] = None, |
131 | 156 | ):
|
132 | 157 | # Lightning checkpoints end with .ckpt, ours with .ckpt.zip
|
133 | 158 | if checkpoint_path.name.endswith(".ckpt.zip"):
|
134 | 159 | checkpoint = NNCheckpointIO.load(path=checkpoint_path, map_location=map_location)
|
| 160 | + |
| 161 | + if substitute_values is not None: |
| 162 | + checkpoint = _substistute(checkpoint, substitute_values=substitute_values, substitute_keys=substitute_keys) |
| 163 | + |
135 | 164 | return module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
|
136 | 165 | else:
|
137 | 166 | pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'")
|
|
0 commit comments