Commit 1825fed

[SLM] GPTJ Multi-GPU support (#3070)
This PR adds tensor parallelism (TP) support to the GPTJ model and fixes a minor typo in the OLMo model.
1 parent 9a33772 commit 1825fed
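
For orientation (not part of the diff below): with the new defaults, head_dim falls back to n_embd // n_head, and GPTJAttention splits the heads evenly across shards. A minimal arithmetic sketch of those two rules, using GPT-J-6B's commonly cited sizes (4096 hidden units, 16 heads) and an assumed two-GPU setup:

# Arithmetic sketch only; GPT-J-6B sizes are hard-coded here, not read from config.json.
n_embd = 4096                # GPT-J-6B hidden size
n_head = 16                  # GPT-J-6B attention heads
tensor_parallel_shards = 2   # assumed example: two GPUs

head_dim = n_embd // n_head  # new default when head_dim == 0  -> 256
assert head_dim * n_head == n_embd

num_heads_per_shard = n_head // tensor_parallel_shards  # GPTJAttention.num_heads -> 8
print(head_dim, num_heads_per_shard)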

File tree

2 files changed: +35 -5 lines changed


python/mlc_llm/model/gpt_j/gpt_j_model.py  (+33 -3)
@@ -13,6 +13,7 @@
 from mlc_llm import op as op_ext
 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
 from mlc_llm.support.style import bold

@@ -57,6 +58,9 @@ def __post_init__(self):
                     "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
                     "provided in `config.json`."
                 )
+        if self.head_dim == 0:
+            self.head_dim = self.n_embd // self.n_head
+        assert self.head_dim * self.n_head == self.n_embd
         if self.prefill_chunk_size == 0:
             logger.info(
                 "%s defaults to %d",
@@ -72,7 +76,6 @@ def __post_init__(self):
                 min(self.context_window_size, 8192),
             )
             self.prefill_chunk_size = min(self.context_window_size, 8192)
-        assert self.tensor_parallel_shards == 1, "GPTJ currently does not support sharding."


 # pylint: disable=invalid-name,missing-docstring
@@ -82,7 +85,7 @@ class GPTJAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: GPTJConfig):
         self.embed_dim = config.n_embd
         self.num_heads = config.n_head // config.tensor_parallel_shards
-        self.head_dim = self.embed_dim // self.num_heads
+        self.head_dim = config.head_dim
         self.max_position_embeddings = config.context_window_size
         self.rope_theta = 10000
         self.rotary_dim = config.rotary_dim
@@ -140,14 +143,41 @@ def __init__(self, config: GPTJConfig):
         self.attn = GPTJAttention(config)
         self.mlp = GPTJMLP(config)

+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.attn.num_heads * hd
+            k = self.attn.num_heads * hd
+            v = self.attn.num_heads * hd
+            _set(
+                self.attn.c_attn.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(self.attn.out_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+            _set(
+                self.mlp.fc_in.weight,
+                tp.ShardSingleDim("_shard_c_fc_weight", dim=0),
+            )
+            _set(self.mlp.fc_out.weight, tp.ShardSingleDim("_shard_mlp_c_proj", dim=1))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_output = self.attn(hidden_states, paged_kv_cache, layer_id)
         feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = attn_output + feed_forward_hidden_states + residual
+        hidden_states = self._apply_residual(attn_output + feed_forward_hidden_states, residual)
         return hidden_states

+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+

 class GPTJModel(nn.Module):
     def __init__(self, config: GPTJConfig):
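
The sharding hints above follow the usual column-/row-parallel pattern: the fused QKV weight and fc_in are split along their output dimension (dim=0), while out_proj and fc_out are split along their input dimension (dim=1), so each shard produces only a partial result and the block combines them with a single allreduce in _apply_residual. A minimal NumPy sketch of why a dim=1 (input-dimension) shard needs that sum; shapes and shard count are illustrative assumptions, not values from this PR:

import numpy as np

rng = np.random.default_rng(0)
shards = 2                        # assumed number of tensor-parallel workers
in_features, out_features = 8, 4  # illustrative sizes only

x = rng.standard_normal((1, in_features))             # activation
w = rng.standard_normal((out_features, in_features))  # weight laid out [out, in], as in nn.Linear

full = x @ w.T  # unsharded reference output

# Row-parallel layer: split the weight along its input dim (dim=1); each worker
# holds the matching activation slice and computes only a partial product.
w_shards = np.split(w, shards, axis=1)
x_shards = np.split(x, shards, axis=1)
partials = [xs @ ws.T for xs, ws in zip(x_shards, w_shards)]

# Summing the partials (what the "sum" allreduce does) recovers the full output,
# after which the residual can be added once per block.
np.testing.assert_allclose(sum(partials), full, rtol=1e-6)

Column-parallel weights (dim=0) need no reduction because each shard owns complete output features, which is why the allreduce appears only once per block, after out_proj and fc_out.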

python/mlc_llm/model/olmo/olmo_model.py  (+2 -2)
@@ -102,7 +102,7 @@ def __post_init__(self):  # pylint: disable=too-many-branches
             raise ValueError(f"'clip_qkv'({self.clip_qkv}) should be non-negative")


-class OLMoEebedding(nn.Embedding):
+class OLMoEmbedding(nn.Embedding):
     """The embedding module that can be shared with the final lm_head. From Qwen2Embedding."""

     def lm_head_forward(self, x: nn.Tensor):
@@ -248,7 +248,7 @@ def forward(  # pylint: disable=missing-function-docstring
 class OLMoModel(nn.Module):  # pylint: disable=missing-class-docstring
     def __init__(self, config: OLMoConfig):
         assert config.hidden_size % config.num_attention_heads == 0
-        self.embed_tokens = OLMoEebedding(config.vocab_size, config.hidden_size)
+        self.embed_tokens = OLMoEmbedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList(
             [OLMoDecoderLayer(config) for _ in range(config.num_hidden_layers)]
         )
