diff --git a/python/mlc_llm/op/moe_misc.py b/python/mlc_llm/op/moe_misc.py index e97ef94fff..19bf10381f 100644 --- a/python/mlc_llm/op/moe_misc.py +++ b/python/mlc_llm/op/moe_misc.py @@ -385,11 +385,11 @@ def scatter_output(x: Tensor, indices: Tensor) -> Tensor: The output of MoE experts with shape [batch_size * num_experts_per_tok, hidden_size]. """ dtype = x.dtype + _, hidden_size = x.shape @T.prim_func(private=True) def _func(var_x: T.handle, var_indices: T.handle, var_out: T.handle): T.func_attr({"tir.noalias": True}) - hidden_size = T.int64() indices_len = T.int64() x = T.match_buffer(var_x, [indices_len, hidden_size], dtype) indices = T.match_buffer(var_indices, [indices_len], "int32")