Skip to content

Commit d3e39e1

Browse files
authored
new: Implement Llama3 (#24)
* Add LLamaForCusalLM * Add docstrings * Add tests * Expose llama config and model * fix: linting with ruff * Add tokenization and working generation * fix linting and tests * type fixes * Update docstrings and readme * ruff is happy
1 parent c9a740a commit d3e39e1

File tree

12 files changed

+1219
-13
lines changed

12 files changed

+1219
-13
lines changed

README.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,39 @@ pip install -e .
3737

3838
## Usage
3939

40+
### LLaMA inference
41+
42+
```python
43+
from jaxgarden import LlamaConfig, LlamaForCausalLM, Tokenizer
44+
from flax import nnx
45+
46+
47+
# HF repo id of the LLaMA variant that you want to use
48+
model_id = "meta-llama/Llama-3.2-1B"
49+
50+
# initialize the LLaMA architecture
51+
config = LlamaConfig()
52+
model = LlamaForCausalLM(config, rngs=nnx.Rngs(0))
53+
54+
# This is a one-liner to download HF checkpoint from HuggingFace Hub,
55+
# convert it to jaxgarden format,
56+
# save it in an Orbax checkpoint,
57+
# and then remove the HF checkpoint.
58+
model.from_hf(model_id)
59+
60+
# this works just like `transformers.AutoTokenizer`,
61+
# but without the dependency of the whole `transformers` library.
62+
# Instead, we simply extend `tokenizers` package and add some cnvenience code for JAX.
63+
tokenizer = Tokenizer.from_pretrained(model_id)
64+
65+
text = "The meaning of life is"
66+
model_inputs = tokenizer.encode(text)
67+
output = model.generate(**model_inputs, max_length=20, do_sample=True)
68+
output_text = tokenizer.decode(output)
69+
print(output_text)
70+
```
71+
72+
4073
### MultiHeadAttention Module (Flax NNX)
4174

4275
```python

examples/llama_inference_example.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from flax import nnx
2+
3+
from jaxgarden import LlamaConfig, LlamaForCausalLM, Tokenizer
4+
5+
if __name__ == "__main__":
6+
# initialize a config object (with defaults for 1B varient)
7+
# other varients to be added.
8+
config = LlamaConfig()
9+
model = LlamaForCausalLM(config, rngs=nnx.Rngs(0))
10+
model_id = "meta-llama/Llama-3.2-1B"
11+
12+
# this will download HF checkpoint from HuggingFace Hub,
13+
# convert it to jaxgarden format,
14+
# save it in an Orbax checkpoint,
15+
# and then remove the HF checkpoint.
16+
# If you didn't set your HF token globally,
17+
# you may need to pass your token as an argument to this method.
18+
model.from_hf(model_id, force_download=True)
19+
20+
# this works just like `transformers.AutoTokenizer`,
21+
# but without the dependency of the whole `transformers` library.
22+
# Instead, we simply extend `tokenizers` package and add some cnvenience code for JAX.
23+
tokenizer = Tokenizer.from_pretrained(model_id)
24+
25+
text = "The meaning of life is"
26+
model_inputs = tokenizer.encode(text)
27+
output = model.generate(**model_inputs, max_length=20, do_sample=True)
28+
output_text = tokenizer.decode(output)
29+
print(output, output.shape)
30+
print(output_text)

jaxgarden/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44
from jaxgarden.functional.attention import dot_product_attention
55
from jaxgarden.models.base import BaseConfig, BaseModel
66
from jaxgarden.models.generation_utils import GenerationMixin
7+
from jaxgarden.models.llama import (
8+
LlamaAttention,
9+
LlamaConfig,
10+
LlamaForCausalLM,
11+
LlamaMLP,
12+
LlamaRMSNorm,
13+
LlamaRotaryEmbedding,
14+
LlamaTransformerBlock,
15+
)
716
from jaxgarden.models.modernbert import (
817
ModernBertAttention,
918
ModernBertEmbeddings,
@@ -12,6 +21,7 @@
1221
ModernBertLayer,
1322
ModernBertMLP,
1423
)
24+
from jaxgarden.tokenization import Tokenizer
1525

1626
__all__ = [
1727
# Base classes
@@ -20,6 +30,13 @@
2030
# Mixins
2131
"GenerationMixin",
2232
# Models
33+
"LlamaAttention",
34+
"LlamaConfig",
35+
"LlamaForCausalLM",
36+
"LlamaMLP",
37+
"LlamaRMSNorm",
38+
"LlamaRotaryEmbedding",
39+
"LlamaTransformerBlock",
2340
"ModernBERTEncoder",
2441
"ModernBERTForMaskedLM",
2542
"ModernBertAttention",
@@ -28,6 +45,8 @@
2845
"ModernBertMLP",
2946
# Attention modules
3047
"MultiHeadAttention",
48+
# tokenization
49+
"Tokenizer",
3150
# Functional interfaces
3251
"dot_product_attention",
3352
]

jaxgarden/models/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from jaxgarden.models.base import BaseConfig, BaseModel
22
from jaxgarden.models.generation_utils import GenerationMixin
3+
from jaxgarden.models.llama import (
4+
LlamaAttention,
5+
LlamaConfig,
6+
LlamaForCausalLM,
7+
LlamaMLP,
8+
LlamaRMSNorm,
9+
LlamaRotaryEmbedding,
10+
LlamaTransformerBlock,
11+
)
312
from jaxgarden.models.modernbert import (
413
ModernBertAttention,
514
ModernBertEmbeddings,
@@ -13,6 +22,13 @@
1322
"BaseConfig",
1423
"BaseModel",
1524
"GenerationMixin",
25+
"LlamaAttention",
26+
"LlamaConfig",
27+
"LlamaForCausalLM",
28+
"LlamaMLP",
29+
"LlamaRMSNorm",
30+
"LlamaRotaryEmbedding",
31+
"LlamaTransformerBlock",
1632
"ModernBERTEncoder",
1733
"ModernBERTForMaskedLM",
1834
"ModernBertAttention",

jaxgarden/models/base.py

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import logging
12
import os
3+
import shutil
24
from collections.abc import Iterator
35
from dataclasses import dataclass, field
46
from pathlib import Path
@@ -11,6 +13,9 @@
1113
from huggingface_hub import snapshot_download
1214
from safetensors import safe_open
1315

16+
# Set up logging
17+
logger = logging.getLogger(__name__)
18+
1419
DEFAULT_PARAMS_FILE = "jaxgarden_state"
1520

1621

@@ -64,11 +69,16 @@ def __init__(
6469
self.rngs = rngs
6570

6671
@property
67-
def state(self) -> dict[str, jnp.ndarray]:
68-
"""Splits state from the graph and returns it.
72+
def state(self) -> nnx.State:
73+
"""Splits state from the graph and returns it"""
74+
return nnx.split(self, nnx.Param, ...)[1]
75+
76+
@property
77+
def state_dict(self) -> dict[str, jnp.ndarray]:
78+
"""Splits state from the graph and returns it as a dictionary.
6979
7080
It can be used for serialization with orbax."""
71-
state = nnx.split(self, nnx.Param, ...)[1]
81+
state = self.state
7282
pure_dict_state = nnx.to_pure_dict(state)
7383
return pure_dict_state
7484

@@ -78,7 +88,7 @@ def save(self, path: str) -> None:
7888
Args:
7989
path: The directory path to save the model state to.
8090
"""
81-
state = self.state
91+
state = self.state_dict
8292
checkpointer = ocp.StandardCheckpointer()
8393
checkpointer.save(os.path.join(path, DEFAULT_PARAMS_FILE), state)
8494
checkpointer.wait_until_finished()
@@ -97,20 +107,30 @@ def load(self, path: str) -> nnx.Module:
97107
return nnx.merge(graphdef, abstract_state)
98108

99109
@staticmethod
100-
def download_from_hf(repo_id: str, local_dir: str) -> None:
110+
def download_from_hf(
111+
repo_id: str, local_dir: str, token: str | None = None, force_download: bool = False
112+
) -> None:
101113
"""Downloads the model from the Hugging Face Hub.
102114
103115
Args:
104116
repo_id: The repository ID of the model to download.
105117
local_dir: The local directory to save the model to.
106118
"""
107-
snapshot_download(repo_id, local_dir=local_dir)
119+
logger.info(f"Attempting to download {repo_id} from Hugging Face Hub to {local_dir}.")
120+
try:
121+
snapshot_download(
122+
repo_id, local_dir=local_dir, token=token, force_download=force_download
123+
)
124+
logger.info(f"Successfully downloaded {repo_id} to {local_dir}.")
125+
except Exception as e:
126+
logger.error(f"Failed to download {repo_id}: {e}")
127+
raise
108128

109129
@staticmethod
110-
def load_safetensors(path_to_model_weights: str) -> Iterator[tuple[Any, Any]]:
130+
def iter_safetensors(path_to_model_weights: str) -> Iterator[tuple[Any, Any]]:
111131
"""Helper function to lazily load params from safetensors file.
112132
113-
Use this static method to load weights for conversion tasks.
133+
Use this static method to iterate over weights for conversion tasks.
114134
115135
Args:
116136
model_path_to_params: Path to directory containing .safetensors files."""
@@ -121,5 +141,72 @@ def load_safetensors(path_to_model_weights: str) -> Iterator[tuple[Any, Any]]:
121141

122142
for file in safetensors_files:
123143
with safe_open(file, framework="jax", device="cpu") as f:
124-
for key in f:
144+
for key in f.keys(): # noqa: SIM118
125145
yield (key, f.get_tensor(key))
146+
147+
def from_hf(
148+
self,
149+
model_repo_or_id: str,
150+
token: str | None = None,
151+
force_download: bool = False,
152+
save_in_orbax: bool = True,
153+
remove_hf_after_conversion: bool = True,
154+
) -> None:
155+
"""Downloads the model from the Hugging Face Hub and returns a new instance of the model.
156+
157+
It can also save the converted weights in an Orbax checkpoint
158+
and removes the original HF checkpoint after conversion.
159+
160+
Args:
161+
model_repo_or_id: The repository ID or name of the model to download.
162+
token: The token to use for authentication with the Hugging Face Hub.
163+
save_in_orbax: Whether to save the converted weights in an Orbax checkpoint.
164+
remove_hf_after_conversion: Whether to remove the downloaded HuggingFace checkpoint
165+
after conversion.
166+
"""
167+
logger.info(f"Starting from_hf process for model: {model_repo_or_id}")
168+
local_dir = os.path.join(
169+
os.path.expanduser("~"), ".jaxgarden", "hf_models", *model_repo_or_id.split("/")
170+
)
171+
save_dir = local_dir.replace("hf_models", "models")
172+
if os.path.exists(save_dir):
173+
if force_download:
174+
logger.warn(f"Removing {save_dir} because force_download is set to True")
175+
shutil.rmtree(save_dir)
176+
else:
177+
raise RuntimeError(
178+
f"Path {save_dir} already exists."
179+
+ " Set force_download to Tru to run conversion again."
180+
)
181+
182+
logger.debug(f"Local Hugging Face model directory set to: {local_dir}")
183+
184+
BaseModel.download_from_hf(
185+
model_repo_or_id, local_dir, token=token, force_download=force_download
186+
)
187+
logger.info(f"Initiating weight iteration from safetensors in {local_dir}")
188+
weights = BaseModel.iter_safetensors(local_dir)
189+
state = self.state
190+
logger.info("Running weight conversion...")
191+
self.convert_weights_from_hf(state, weights)
192+
logger.info("Weight conversion finished. Updating model state...")
193+
nnx.update(self, state)
194+
logger.warn("Model state successfully updated with converted weights.")
195+
196+
if remove_hf_after_conversion:
197+
logger.warn(f"Removing HuggingFace checkpoint from {local_dir}...")
198+
shutil.rmtree(local_dir)
199+
200+
if save_in_orbax:
201+
logger.warn(f")Saving Orbax checkpoint in {save_dir}.")
202+
self.save(save_dir)
203+
204+
logger.warn(f"from_hf process completed for {model_repo_or_id}.")
205+
206+
def convert_weights_from_hf(self, state: nnx.State, weights: Iterator[tuple[Any, Any]]) -> None:
207+
"""Convert weights from Hugging Face Hub to the model's state.
208+
209+
This method should be implemented in downstream classes
210+
to support conversion from HuggingFace format.
211+
"""
212+
raise NotImplementedError("This model does not support conversion from HuggingFace yet.")

jaxgarden/models/generation_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def scan_step(carry: dict, _: Any) -> tuple[dict, None]:
363363
def generate(
364364
self: "GenerationMixin",
365365
input_ids: jnp.ndarray,
366+
attention_mask: jnp.ndarray | None = None,
366367
max_length: int = 20,
367368
temperature: float = 1.0,
368369
top_k: int | None = None,

0 commit comments

Comments
 (0)