from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
from mlc_llm.support.style import bold

@@ -57,6 +58,9 @@ def __post_init__(self):
                    "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
                    "provided in `config.json`."
                )
+        if self.head_dim == 0:
+            self.head_dim = self.n_embd // self.n_head
+        assert self.head_dim * self.n_head == self.n_embd
        if self.prefill_chunk_size == 0:
            logger.info(
                "%s defaults to %d",
@@ -72,7 +76,6 @@ def __post_init__(self):
                min(self.context_window_size, 8192),
            )
            self.prefill_chunk_size = min(self.context_window_size, 8192)
-        assert self.tensor_parallel_shards == 1, "GPTJ currently does not support sharding."


# pylint: disable=invalid-name,missing-docstring
@@ -82,7 +85,7 @@ class GPTJAttention(nn.Module): # pylint: disable=too-many-instance-attributes
    def __init__(self, config: GPTJConfig):
        self.embed_dim = config.n_embd
        self.num_heads = config.n_head // config.tensor_parallel_shards
-        self.head_dim = self.embed_dim // self.num_heads
+        self.head_dim = config.head_dim
        self.max_position_embeddings = config.context_window_size
        self.rope_theta = 10000
        self.rotary_dim = config.rotary_dim
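Note on why `head_dim` is now read from the config: `num_heads` above is already the per-shard head count, so the old `embed_dim // num_heads` derivation would inflate the head size once `tensor_parallel_shards > 1`. A minimal sketch with assumed GPT-J-6B-like sizes:

# Assumed sizes for illustration: n_embd=4096, n_head=16, two shards.
n_embd, n_head, shards = 4096, 16, 2
num_heads = n_head // shards            # per-shard heads: 8
old_head_dim = n_embd // num_heads      # 512 -- wrong once the model is sharded
cfg_head_dim = n_embd // n_head         # 256 -- shard-independent, what config.head_dim holds
assert (old_head_dim, cfg_head_dim) == (512, 256)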
@@ -140,14 +143,41 @@ def __init__(self, config: GPTJConfig):
        self.attn = GPTJAttention(config)
        self.mlp = GPTJMLP(config)

+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.attn.num_heads * hd
+            k = self.attn.num_heads * hd
+            v = self.attn.num_heads * hd
+            _set(
+                self.attn.c_attn.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(self.attn.out_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+            _set(
+                self.mlp.fc_in.weight,
+                tp.ShardSingleDim("_shard_c_fc_weight", dim=0),
+            )
+            _set(self.mlp.fc_out.weight, tp.ShardSingleDim("_shard_mlp_c_proj", dim=1))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(hidden_states, paged_kv_cache, layer_id)
        feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = attn_output + feed_forward_hidden_states + residual
+        hidden_states = self._apply_residual(attn_output + feed_forward_hidden_states, residual)
        return hidden_states

+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+

class GPTJModel(nn.Module):
    def __init__(self, config: GPTJConfig):
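Note on the sharding hints and `_apply_residual` in the hunk above: the fused QKV weight and `fc_in` are split along their output rows (dim 0), while `out_proj` and `fc_out` are split along their input columns (dim 1), so each shard produces only a partial block output; summing those partials across shards (the allreduce) recovers the unsharded result, and the residual is added only after the reduction. Because the attention and MLP branches are summed before `_apply_residual`, one allreduce per block covers both. A minimal NumPy sketch of that arithmetic (shapes and names are illustrative assumptions, not the mlc_llm API):

# Minimal NumPy sketch (illustration only, not the mlc_llm API) of why fc_in is
# sharded on dim 0, fc_out on dim 1, and the block output needs a sum-allreduce.
import numpy as np

rng = np.random.default_rng(0)
n_embd, n_inner, shards = 8, 32, 2

x = rng.standard_normal((1, n_embd))
w_in = rng.standard_normal((n_inner, n_embd))    # fc_in.weight, split along dim 0
w_out = rng.standard_normal((n_embd, n_inner))   # fc_out.weight, split along dim 1

ref = (x @ w_in.T) @ w_out.T                     # unsharded reference output

partials = []
for s in range(shards):
    rows = slice(s * n_inner // shards, (s + 1) * n_inner // shards)
    h_local = x @ w_in[rows, :].T                # this shard's slice of the hidden activation
    partials.append(h_local @ w_out[:, rows].T)  # this shard's partial block output
out = sum(partials)                              # the sum an allreduce performs across devices

assert np.allclose(out, ref)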