When I try to apply LoRA to DeepSeek-V3's MoE layer, I encounter the following error:
```
【2025-04-10 16:16:52】  File "/home/sharele/Megatron-LM/megatron/core/transformer/moe/moe_layer.py", line 155, in forward
【2025-04-10 16:16:52】    output, mlp_bias = custom_forward(hidden_states)
【2025-04-10 16:16:52】  File "/home/sharele/Megatron-LM/megatron/core/transformer/moe/moe_layer.py", line 141, in custom_forward
【2025-04-10 16:16:52】    (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
【2025-04-10 16:16:52】  File "/home/sharele/Megatron-LM/megatron/core/transformer/moe/token_dispatcher.py", line 535, in token_permutation
【2025-04-10 16:16:52】    self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
【2025-04-10 16:16:52】  File "/home/sharele/Megatron-LM/megatron/core/transformer/moe/shared_experts.py", line 140, in linear_fc1_forward_and_act
【2025-04-10 16:16:52】    intermediate_parallel, bias_parallel = self.linear_fc1(self.cached_fc1_input)
【2025-04-10 16:16:52】  File "/home/sharele/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
【2025-04-10 16:16:52】    return self._call_impl(*args, **kwargs)
【2025-04-10 16:16:52】  File "/home/sharele/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
【2025-04-10 16:16:52】    return forward_call(*args, **kwargs)
【2025-04-10 16:16:52】  File "/home/sharele/NeMo/nemo/collections/llm/peft/lora.py", line 47, in forward
【2025-04-10 16:16:52】    return linear_output + adapter_output, bias
【2025-04-10 16:16:52】 RuntimeError: The size of tensor a (8112) must match the size of tensor b (64896) at non-singleton dimension 0
```
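Note that the two mismatched sizes differ by exactly one tensor-parallel factor, which already hints at a TP-related sizing problem rather than a plain shape bug (quick arithmetic check below; the TP size of 8 is inferred from the numbers, not stated in the logs):

```python
# Quick arithmetic check on the sizes reported in the traceback above.
# Assumption (not stated in the logs): the run uses tensor parallel size 8.
linear_output_size = 8112     # size of tensor a in the error
adapter_output_size = 64896   # size of tensor b in the error

tp_factor = adapter_output_size // linear_output_size
assert tp_factor == 8
assert linear_output_size * tp_factor == adapter_output_size
```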
I found that moe_shared_expert_overlap defaults to True in DeepSeekConfig, which disables the shared expert's TP-related communication in Megatron-LM, as shown in the relevant code in megatron/core/transformer/moe/shared_experts.py:
```python
if self.config.moe_shared_expert_overlap:
    # disable TP related AG/RS communications in the linear module
    for linear in [self.linear_fc1, self.linear_fc2]:
        if hasattr(linear, 'parallel_mode'):
            # TELinear
            linear.parallel_mode = None
        else:
            # MCore legacy Linear
            linear.explicit_expert_comm = True
```
But in NeMo, the out_features of xxxColumnParallelLinear are expanded by the TP size when the LoRA adapter attributes are computed.
Therefore, I think it is best to change the default value of this configuration option to False.
Alternatively, the problem could be solved at its root by checking for this case when generating the adapter attributes:
```python
def is_share_expert_linear(fqn):
    """Return whether the current base module is a shared expert linear module."""
    return re.match(r'.*mlp\.shared_experts\.linear_fc[1-2]$', fqn) is not None


def get_adapter_attributes_from_linear(m: nn.Module):
    """Return input_is_parallel, in_features, out_features attributes based on the implementation of the base layer."""
    disable_sequence_parallel_comm = not m.config.sequence_parallel
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    # check whether overlap is enabled for the shared expert
    moe_shared_expert_overlap = (
        True if is_share_expert_linear(m) and m.config.moe_shared_expert_overlap else False
    )

    if HAVE_TE and isinstance(m, TEColumnParallelLinear) or isinstance(m, TELayerNormColumnParallelLinear):
        input_is_parallel = False
        # m.in_features and m.out_features are divided by tp_size already,
        # but in_features and out_features passed to ParallelLinearAdapter are not.
        in_features = m.in_features
        out_features = (m.out_features * tp_size) if not moe_shared_expert_overlap else m.out_features
        if isinstance(m, TELayerNormColumnParallelLinear):
            # LoRA is applied after layernorm, so layernorm output must be returned
            m.return_layernorm_output = True
            # perf optimization for LoRA + SP
            if hasattr(m, "ub_overlap_ag"):
                ub_overlap_ag = m.ub_overlap_ag
            elif hasattr(m, "ub_overlap_ag_fprop"):
                ub_overlap_ag = m.ub_overlap_ag_fprop
            else:
                ub_overlap_ag = False
            if m.config.sequence_parallel and not ub_overlap_ag:
                m.return_layernorm_output_gathered = True
                te_version = packaging.version.Version(version("transformer-engine"))
                if te_version >= packaging.version.Version("1.5.0dev") and (
                    not getattr(m.config, "tp_comm_overlap", False)
                    or getattr(m.config, "tp_comm_overlap_disable_qkv", False)
                ):
                    # TE 1.5 introduces the option `return_layernorm_output_gathered`, so the all gather
                    # in the forward method is not needed, so disable sp communications
                    # unless TP communication overlap is used
                    disable_sequence_parallel_comm = True
    elif HAVE_TE and isinstance(m, TERowParallelLinear):
        input_is_parallel = True
        in_features = (m.out_features * tp_size) if not moe_shared_expert_overlap else m.out_features
        out_features = m.out_features
    elif HAVE_TE and isinstance(m, TELinear):  # parallel_mode="duplicated"
        input_is_parallel = False
        in_features = m.in_features
        out_features = m.out_features
    elif isinstance(m, ColumnParallelLinear):
        input_is_parallel = False
        in_features = m.input_size
        out_features = m.output_size
    elif isinstance(m, RowParallelLinear):
        input_is_parallel = True
        in_features = m.input_size
        out_features = m.output_size
    else:
        raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}")

    return input_is_parallel, in_features, out_features, disable_sequence_parallel_comm
```
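As a quick check that the FQN matcher above only picks out the shared expert projections, here is a small standalone test; the module paths are hypothetical Megatron-style names made up for illustration, not taken from a real checkpoint:

```python
import re


def is_share_expert_linear(fqn):
    """Return whether the given module FQN points at a shared expert linear layer."""
    return re.match(r'.*mlp\.shared_experts\.linear_fc[1-2]$', fqn) is not None


# Hypothetical Megatron-style module names, used only to exercise the regex.
assert is_share_expert_linear("decoder.layers.3.mlp.shared_experts.linear_fc1")
assert is_share_expert_linear("decoder.layers.3.mlp.shared_experts.linear_fc2")
assert not is_share_expert_linear("decoder.layers.3.mlp.experts.linear_fc1")       # routed expert
assert not is_share_expert_linear("decoder.layers.3.self_attention.linear_proj")   # attention projection
```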
Hi @ShareLer
Thank you for the issue. I think your solution with in_features = (m.out_features * tp_size) if not moe_shared_expert_overlap else m.out_features is great! Would you like to make a PR for this change?