Skip to content

Commit

Permalink
[Model] Use static hidden size in mixtral scatter_output
Browse files (browse the repository at this point in the history)
  • Loading branch information
vinx13 committed Mar 14, 2024
1 parent 2872f70 commit 8132bbb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion python/mlc_llm/op/moe_misc.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -385,11 +385,11 @@ def scatter_output(x: Tensor, indices: Tensor) -> Tensor:
The output of MoE experts with shape [batch_size * num_experts_per_tok, hidden_size].
"""
dtype = x.dtype
+ _, hidden_size = x.shape

@T.prim_func(private=True)
def _func(var_x: T.handle, var_indices: T.handle, var_out: T.handle):
T.func_attr({"tir.noalias": True})
- hidden_size = T.int64()
indices_len = T.int64()
x = T.match_buffer(var_x, [indices_len, hidden_size], dtype)
indices = T.match_buffer(var_indices, [indices_len], "int32")
Expand Down

0 comments on commit 8132bbb

Please sign in to comment.