
Commit 9a33772

[Model] Add support for Nemotron architecture (#3069)
This PR adds support for the Nemotron architecture, in reference to #2901 [Request for Nemotron-Mini-4B-Instruct].

Based on my analysis of the Nemotron architecture in the Hugging Face repository, it appears to share similarities with the Llama architecture, but with the following key distinctions (sketched in code below):

- The activation function used in the MLP is `relu2` (squared ReLU).
- The MLP includes `up_proj` and `down_proj`, but does not have a `gate_proj` as seen in Llama.
- It uses `layernorm1p`, and the normalization layer incorporates a bias term.
- The architecture employs a `partial_rotary_factor`, similar to the approach used in the Phi architecture.
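To make those distinctions concrete, here is a minimal numpy sketch of the feed-forward path they describe: a `layernorm1p`-style norm (learned scale applied as `weight + 1`, plus a bias term), followed by an `up_proj` → `relu2` → `down_proj` MLP with no `gate_proj`. The function names and toy sizes are illustrative assumptions; the actual implementation added by this commit lives in `python/mlc_llm/model/nemotron/nemotron_model.py`, which is not shown in this excerpt, and the `weight + 1` convention for `layernorm1p` is inferred from the name rather than from this diff.

```python
import numpy as np


def relu2(x: np.ndarray) -> np.ndarray:
    """Squared ReLU: relu(x) ** 2 (the MLP activation noted above)."""
    return np.square(np.maximum(x, 0.0))


def layernorm1p(x: np.ndarray, weight: np.ndarray, bias: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm variant where the learned scale is applied as (weight + 1), with a bias term."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * (weight + 1.0) + bias


def nemotron_mlp(x: np.ndarray, up_proj: np.ndarray, down_proj: np.ndarray) -> np.ndarray:
    """up_proj -> relu2 -> down_proj; no gate_proj, unlike Llama's gated MLP."""
    return relu2(x @ up_proj.T) @ down_proj.T


# Toy usage with hypothetical sizes (hidden=8, intermediate=32).
x = np.random.randn(2, 8).astype("float32")
w_up = np.random.randn(32, 8).astype("float32")
w_down = np.random.randn(8, 32).astype("float32")
gamma, beta = np.zeros(8, dtype="float32"), np.zeros(8, dtype="float32")
out = nemotron_mlp(layernorm1p(x, gamma, beta), w_up, w_down)
print(out.shape)  # (2, 8)
```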
1 parent 8a1bfd6 commit 9a33772

8 files changed, +734 -0 lines changed


python/mlc_llm/conversation_template/__init__.py

+1
@@ -19,6 +19,7 @@
     llama,
     llava,
     mistral,
+    nemotron,
     oasst,
     olmo,
     orion,
python/mlc_llm/conversation_template/nemotron.py

+27

@@ -0,0 +1,27 @@
+"""nemotron default templates"""
+
+from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders
+
+from .registry import ConvTemplateRegistry
+
+# Nemotron template
+# https://huggingface.co/nvidia/Nemotron-Mini-4B-Instruct/blob/6a417790c444fd65a3da6a5c8821de6afc9654a6/tokenizer_config.json#L8030
+ConvTemplateRegistry.register_conv_template(
+    Conversation(
+        name="nemotron",
+        system_template=("<extra_id_0>System\n" f"{MessagePlaceholders.SYSTEM.value}\n\n"),
+        system_message="",
+        roles={
+            "user": "<extra_id_1>User",
+            "assistant": "<extra_id_1>Assistant",
+            "tool": "<extra_id_1>Tool",
+        },
+        seps=["\n"],
+        role_content_sep="\n",
+        role_empty_sep="\n",
+        stop_str=["</s>"],
+        stop_token_ids=[3],
+        system_prefix_token_ids=[2],
+        add_role_after_system_message=True,
+    )
+)
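For reference, the fields above imply a prompt layout roughly like the following. This is a hand-assembled approximation for illustration only; the authoritative rendering is produced by MLC's `Conversation` protocol from the registered template, so details such as exact whitespace may differ.

```python
# Approximate single-turn prompt implied by the nemotron template fields above
# (illustration only; actual rendering is done by MLC's Conversation machinery).
system_message = "You are a helpful assistant."  # hypothetical system prompt
user_message = "Write a haiku about GPUs."       # hypothetical user turn

prompt = (
    "<extra_id_0>System\n"      # system_template prefix
    f"{system_message}\n\n"     # system message placeholder + separators
    "<extra_id_1>User\n"        # user role tag + role_content_sep
    f"{user_message}\n"         # user content + sep
    "<extra_id_1>Assistant\n"   # assistant role tag; the model generates from here
)
print(prompt)
```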

python/mlc_llm/interface/gen_config.py

+1
@@ -310,4 +310,5 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
     "deepseek_v2",
     "deepseek",
     "olmo",
+    "nemotron",
 }

python/mlc_llm/model/model.py

+17
@@ -29,6 +29,7 @@
 from .minicpm import minicpm_loader, minicpm_model, minicpm_quantization
 from .mistral import mistral_loader, mistral_model, mistral_quantization
 from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization
+from .nemotron import nemotron_loader, nemotron_model, nemotron_quantization
 from .olmo import olmo_loader, olmo_model, olmo_quantization
 from .orion import orion_loader, orion_model, orion_quantization
 from .phi import phi_loader, phi_model, phi_quantization
@@ -565,4 +566,20 @@ class Model:
             "per-tensor-quant": olmo_quantization.per_tensor_quant,
         },
     ),
+    "nemotron": Model(
+        name="nemotron",
+        model=nemotron_model.NemotronForCausalLM,
+        config=nemotron_model.NemotronConfig,
+        source={
+            "huggingface-torch": nemotron_loader.huggingface,
+            "huggingface-safetensor": nemotron_loader.huggingface,
+        },
+        quantize={
+            "no-quant": nemotron_quantization.no_quant,
+            "group-quant": nemotron_quantization.group_quant,
+            "ft-quant": nemotron_quantization.ft_quant,
+            "awq": nemotron_quantization.awq_quant,
+            "per-tensor-quant": nemotron_quantization.per_tensor_quant,
+        },
+    ),
 }
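A quick way to sanity-check the registration from Python is sketched below. It assumes the registry dict in `python/mlc_llm/model/model.py` is named `MODELS` and that `Model` exposes the fields shown in the diff (`name`, `source`, `quantize`); both are assumptions about code not included in this excerpt.

```python
# Inspect the newly registered "nemotron" entry (assumes the registry dict is named MODELS).
from mlc_llm.model.model import MODELS

nemotron = MODELS["nemotron"]
print(nemotron.name)              # "nemotron"
print(sorted(nemotron.source))    # ['huggingface-safetensor', 'huggingface-torch']
print(sorted(nemotron.quantize))  # ['awq', 'ft-quant', 'group-quant', 'no-quant', 'per-tensor-quant']
```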

python/mlc_llm/model/nemotron/__init__.py

Whitespace-only changes.
python/mlc_llm/model/nemotron/nemotron_loader.py

+76

@@ -0,0 +1,76 @@
+"""
+This file specifies how MLC's Nemotron parameter maps from other formats, for example HuggingFace
+PyTorch, HuggingFace safetensors.
+"""
+
+import functools
+
+import numpy as np
+
+from mlc_llm.loader import ExternMapping
+from mlc_llm.quantization import Quantization
+
+from .nemotron_model import NemotronConfig, NemotronForCausalLM
+
+
+def huggingface(model_config: NemotronConfig, quantization: Quantization) -> ExternMapping:
+    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
+    the names of HuggingFace PyTorch parameters.
+
+    Parameters
+    ----------
+    model_config : NemotronConfig
+        The configuration of the Nemotron model.
+
+    quantization : Quantization
+        The quantization configuration.
+
+    Returns
+    -------
+    param_map : ExternMapping
+        The parameter mapping from MLC to HuggingFace PyTorch.
+    """
+    model = NemotronForCausalLM(model_config)
+    if quantization is not None:
+        model.to(quantization.model_dtype)
+    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
+        spec=model.get_default_spec(),
+        allow_extern=True,
+    )
+    named_parameters = dict(_named_params)
+
+    mapping = ExternMapping()
+
+    for i in range(model_config.num_hidden_layers):
+        # Add QKV in self attention
+        attn = f"model.layers.{i}.self_attn"
+        mlc_name = f"{attn}.qkv_proj.weight"
+        mlc_param = named_parameters[mlc_name]
+        mapping.add_mapping(
+            mlc_name,
+            [
+                f"{attn}.q_proj.weight",
+                f"{attn}.k_proj.weight",
+                f"{attn}.v_proj.weight",
+            ],
+            functools.partial(
+                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
+                dtype=mlc_param.dtype,
+            ),
+        )
+
+        # inv_freq is not used in the model
+        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")
+
+    for mlc_name, mlc_param in named_parameters.items():
+        if mlc_name not in mapping.param_map:
+            mapping.add_mapping(
+                mlc_name,
+                [mlc_name],
+                functools.partial(
+                    lambda x, dtype: x.astype(dtype),
+                    dtype=mlc_param.dtype,
+                ),
+            )
+
+    return mapping
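The per-layer mapping above fuses the three HuggingFace attention projections into MLC's single `qkv_proj` weight. The standalone numpy sketch below shows that transform on toy tensors; the sizes are hypothetical and unrelated to the real Nemotron-Mini-4B dimensions.

```python
# Standalone illustration of the QKV fusion performed by the mapping above (toy sizes).
import numpy as np

head_dim, num_q_heads, num_kv_heads, hidden_size = 4, 6, 2, 24  # hypothetical shapes

q = np.random.randn(num_q_heads * head_dim, hidden_size).astype("float32")   # q_proj.weight
k = np.random.randn(num_kv_heads * head_dim, hidden_size).astype("float32")  # k_proj.weight
v = np.random.randn(num_kv_heads * head_dim, hidden_size).astype("float32")  # v_proj.weight

# Concatenate along the output dimension and cast to the dtype of MLC's fused
# qkv_proj parameter (e.g. "float16"), which is what the lambda in the loader does.
qkv = np.concatenate([q, k, v], axis=0).astype("float16")
print(qkv.shape)  # ((num_q_heads + 2 * num_kv_heads) * head_dim, hidden_size) == (40, 24)
```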
