1
+ import logging
1
2
import os
3
+ import shutil
2
4
from collections .abc import Iterator
3
5
from dataclasses import dataclass , field
4
6
from pathlib import Path
11
13
from huggingface_hub import snapshot_download
12
14
from safetensors import safe_open
13
15
16
+ # Set up logging
17
+ logger = logging .getLogger (__name__ )
18
+
14
19
DEFAULT_PARAMS_FILE = "jaxgarden_state"
15
20
16
21
@@ -64,11 +69,16 @@ def __init__(
64
69
self .rngs = rngs
65
70
66
71
@property
67
- def state (self ) -> dict [str , jnp .ndarray ]:
68
- """Splits state from the graph and returns it.
72
+ def state (self ) -> nnx .State :
73
+ """Splits state from the graph and returns it"""
74
+ return nnx .split (self , nnx .Param , ...)[1 ]
75
+
76
+ @property
77
+ def state_dict (self ) -> dict [str , jnp .ndarray ]:
78
+ """Splits state from the graph and returns it as a dictionary.
69
79
70
80
It can be used for serialization with orbax."""
71
- state = nnx . split ( self , nnx . Param , ...)[ 1 ]
81
+ state = self . state
72
82
pure_dict_state = nnx .to_pure_dict (state )
73
83
return pure_dict_state
74
84
@@ -78,7 +88,7 @@ def save(self, path: str) -> None:
78
88
Args:
79
89
path: The directory path to save the model state to.
80
90
"""
81
- state = self .state
91
+ state = self .state_dict
82
92
checkpointer = ocp .StandardCheckpointer ()
83
93
checkpointer .save (os .path .join (path , DEFAULT_PARAMS_FILE ), state )
84
94
checkpointer .wait_until_finished ()
@@ -97,20 +107,30 @@ def load(self, path: str) -> nnx.Module:
97
107
return nnx .merge (graphdef , abstract_state )
98
108
99
109
@staticmethod
100
- def download_from_hf (repo_id : str , local_dir : str ) -> None :
110
+ def download_from_hf (
111
+ repo_id : str , local_dir : str , token : str | None = None , force_download : bool = False
112
+ ) -> None :
101
113
"""Downloads the model from the Hugging Face Hub.
102
114
103
115
Args:
104
116
repo_id: The repository ID of the model to download.
105
117
local_dir: The local directory to save the model to.
106
118
"""
107
- snapshot_download (repo_id , local_dir = local_dir )
119
+ logger .info (f"Attempting to download { repo_id } from Hugging Face Hub to { local_dir } ." )
120
+ try :
121
+ snapshot_download (
122
+ repo_id , local_dir = local_dir , token = token , force_download = force_download
123
+ )
124
+ logger .info (f"Successfully downloaded { repo_id } to { local_dir } ." )
125
+ except Exception as e :
126
+ logger .error (f"Failed to download { repo_id } : { e } " )
127
+ raise
108
128
109
129
@staticmethod
110
- def load_safetensors (path_to_model_weights : str ) -> Iterator [tuple [Any , Any ]]:
130
+ def iter_safetensors (path_to_model_weights : str ) -> Iterator [tuple [Any , Any ]]:
111
131
"""Helper function to lazily load params from safetensors file.
112
132
113
- Use this static method to load weights for conversion tasks.
133
+ Use this static method to iterate over weights for conversion tasks.
114
134
115
135
Args:
116
136
model_path_to_params: Path to directory containing .safetensors files."""
@@ -121,5 +141,72 @@ def load_safetensors(path_to_model_weights: str) -> Iterator[tuple[Any, Any]]:
121
141
122
142
for file in safetensors_files :
123
143
with safe_open (file , framework = "jax" , device = "cpu" ) as f :
124
- for key in f :
144
+ for key in f . keys (): # noqa: SIM118
125
145
yield (key , f .get_tensor (key ))
146
+
147
+ def from_hf (
148
+ self ,
149
+ model_repo_or_id : str ,
150
+ token : str | None = None ,
151
+ force_download : bool = False ,
152
+ save_in_orbax : bool = True ,
153
+ remove_hf_after_conversion : bool = True ,
154
+ ) -> None :
155
+ """Downloads the model from the Hugging Face Hub and returns a new instance of the model.
156
+
157
+ It can also save the converted weights in an Orbax checkpoint
158
+ and removes the original HF checkpoint after conversion.
159
+
160
+ Args:
161
+ model_repo_or_id: The repository ID or name of the model to download.
162
+ token: The token to use for authentication with the Hugging Face Hub.
163
+ save_in_orbax: Whether to save the converted weights in an Orbax checkpoint.
164
+ remove_hf_after_conversion: Whether to remove the downloaded HuggingFace checkpoint
165
+ after conversion.
166
+ """
167
+ logger .info (f"Starting from_hf process for model: { model_repo_or_id } " )
168
+ local_dir = os .path .join (
169
+ os .path .expanduser ("~" ), ".jaxgarden" , "hf_models" , * model_repo_or_id .split ("/" )
170
+ )
171
+ save_dir = local_dir .replace ("hf_models" , "models" )
172
+ if os .path .exists (save_dir ):
173
+ if force_download :
174
+ logger .warn (f"Removing { save_dir } because force_download is set to True" )
175
+ shutil .rmtree (save_dir )
176
+ else :
177
+ raise RuntimeError (
178
+ f"Path { save_dir } already exists."
179
+ + " Set force_download to Tru to run conversion again."
180
+ )
181
+
182
+ logger .debug (f"Local Hugging Face model directory set to: { local_dir } " )
183
+
184
+ BaseModel .download_from_hf (
185
+ model_repo_or_id , local_dir , token = token , force_download = force_download
186
+ )
187
+ logger .info (f"Initiating weight iteration from safetensors in { local_dir } " )
188
+ weights = BaseModel .iter_safetensors (local_dir )
189
+ state = self .state
190
+ logger .info ("Running weight conversion..." )
191
+ self .convert_weights_from_hf (state , weights )
192
+ logger .info ("Weight conversion finished. Updating model state..." )
193
+ nnx .update (self , state )
194
+ logger .warn ("Model state successfully updated with converted weights." )
195
+
196
+ if remove_hf_after_conversion :
197
+ logger .warn (f"Removing HuggingFace checkpoint from { local_dir } ..." )
198
+ shutil .rmtree (local_dir )
199
+
200
+ if save_in_orbax :
201
+ logger .warn (f")Saving Orbax checkpoint in { save_dir } ." )
202
+ self .save (save_dir )
203
+
204
+ logger .warn (f"from_hf process completed for { model_repo_or_id } ." )
205
+
206
+ def convert_weights_from_hf (self , state : nnx .State , weights : Iterator [tuple [Any , Any ]]) -> None :
207
+ """Convert weights from Hugging Face Hub to the model's state.
208
+
209
+ This method should be implemented in downstream classes
210
+ to support conversion from HuggingFace format.
211
+ """
212
+ raise NotImplementedError ("This model does not support conversion from HuggingFace yet." )
0 commit comments