
The default configuration value of DeepseekV3 causes a failure when applying LoRA to experts #12960


Open
ShareLer opened this issue Apr 10, 2025 · 2 comments


ShareLer commented Apr 10, 2025

When I try to apply LoRA to DeepSeek-V3's MoE layer, I encounter the following error:

  File "/home/sharele/Megatron-LM/megatron/core/transformer/moe/moe_layer.py", line 155, in forward
    output, mlp_bias = custom_forward(hidden_states)
  File "/home/sharele/Megatron-LM/megatron/core/transformer/moe/moe_layer.py", line 141, in custom_forward
    (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
  File "/home/sharele/Megatron-LM/megatron/core/transformer/moe/token_dispatcher.py", line 535, in token_permutation
    self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
  File "/home/sharele/Megatron-LM/megatron/core/transformer/moe/shared_experts.py", line 140, in linear_fc1_forward_and_act
    intermediate_parallel, bias_parallel = self.linear_fc1(self.cached_fc1_input)
  File "/home/sharele/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sharele/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sharele/NeMo/nemo/collections/llm/peft/lora.py", line 47, in forward
    return linear_output + adapter_output, bias
RuntimeError: The size of tensor a (8112) must match the size of tensor b (64896) at non-singleton dimension 0

Note that 64896 is exactly 8 × 8112, which suggests the two tensors differ by the tensor-parallel factor. I found that moe_shared_expert_overlap defaults to True in DeepSeekConfig, and that setting makes Megatron-LM disable the shared expert's TP-related communication, as shown in the relevant code below from megatron/core/transformer/moe/shared_experts.py:

        if self.config.moe_shared_expert_overlap:
            # disable TP related AG/RS communications in the linear module
            for linear in [self.linear_fc1, self.linear_fc2]:
                if hasattr(linear, 'parallel_mode'):
                    # TELinear
                    linear.parallel_mode = None
                else:
                    # MCore legacy Linear
                    linear.explicit_expert_comm = True

But in NeMo, the out_features passed to the LoRA adapter for xxxColumnParallelLinear is expanded along the TP dimension (multiplied by tp_size).
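
To make the mismatch concrete, here is a rough sketch of the sizing described above (the numbers are hypothetical, and tp_size = 8 is only an assumption based on 64896 = 8 × 8112 in the traceback; this is not NeMo or Megatron code):

tp_size = 8                   # assumed from the 8112 vs. 64896 factor in the traceback
local_out_features = 2048     # hypothetical m.out_features reported per rank by a TE column-parallel linear

# NeMo currently always recovers the full (gathered) width when sizing the LoRA adapter:
adapter_out_features = local_out_features * tp_size   # 16384

# With moe_shared_expert_overlap=True, Megatron-LM disables the TP-related AG/RS
# communication in the shared expert's linears (see the snippet above), so the base
# module keeps producing output at the local width:
base_out_features = local_out_features                 # 2048

# The adapter branch is therefore sized tp_size times larger than what the base
# branch produces, which is the mismatch the proposal below avoids.
assert adapter_out_features == tp_size * base_out_features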

Therefore, I think it is best to change the default value of this option to False.
Alternatively, solve the problem at its root by checking for this case when generating the adapter attributes, for example:

def is_share_expert_linear(fqn):
    """
    Return whether the current base module is a shared-expert linear module.
    """
    return re.match(r'.*mlp\.shared_experts\.linear_fc[1-2]$', fqn) is not None

def get_adapter_attributes_from_linear(m: nn.Module, fqn: str = ""):
    """
    Return input_is_parallel, in_features, out_features attributes based on the implementation of the base layer.
    `fqn` is the fully qualified name of the base module, passed in by the caller so that
    shared-expert linears can be detected.
    """
    disable_sequence_parallel_comm = not m.config.sequence_parallel
    tp_size = parallel_state.get_tensor_model_parallel_world_size()

    # check whether shared-expert overlap is enabled for this linear
    moe_shared_expert_overlap = is_share_expert_linear(fqn) and m.config.moe_shared_expert_overlap

    if HAVE_TE and (isinstance(m, TEColumnParallelLinear) or isinstance(m, TELayerNormColumnParallelLinear)):
        input_is_parallel = False
        # m.in_features and m.out_features are divided by tp_size already,
        # but in_features and out_features passed to ParallelLinearAdapter are not.
        in_features = m.in_features
        out_features = (m.out_features * tp_size) if not moe_shared_expert_overlap else m.out_features

        if isinstance(m, TELayerNormColumnParallelLinear):
            # LoRA is applied after layernorm, so layernorm output must be returned
            m.return_layernorm_output = True
            # perf optimization for LoRA + SP
            if hasattr(m, "ub_overlap_ag"):
                ub_overlap_ag = m.ub_overlap_ag
            elif hasattr(m, "ub_overlap_ag_fprop"):
                ub_overlap_ag = m.ub_overlap_ag_fprop
            else:
                ub_overlap_ag = False
            if m.config.sequence_parallel and not ub_overlap_ag:
                m.return_layernorm_output_gathered = True
                te_version = packaging.version.Version(version("transformer-engine"))
                if te_version >= packaging.version.Version("1.5.0dev") and (
                    not getattr(m.config, "tp_comm_overlap", False)
                    or getattr(m.config, "tp_comm_overlap_disable_qkv", False)
                ):
                    # TE 1.5 introduces the option `return_layernorm_output_gathered`, so the all gather
                    # in the forward method is not needed, so disable sp communications
                    # unless TP communication overlap is used
                    disable_sequence_parallel_comm = True
    elif HAVE_TE and isinstance(m, TERowParallelLinear):
        input_is_parallel = True
        in_features = (m.in_features * tp_size) if not moe_shared_expert_overlap else m.in_features
        out_features = m.out_features
    elif HAVE_TE and isinstance(m, TELinear):  # parallel_mode="duplicated"
        input_is_parallel = False
        in_features = m.in_features
        out_features = m.out_features
    elif isinstance(m, ColumnParallelLinear):
        input_is_parallel = False
        in_features = m.input_size
        out_features = m.output_size
    elif isinstance(m, RowParallelLinear):
        input_is_parallel = True
        in_features = m.input_size
        out_features = m.output_size
    else:
        raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}")

    return input_is_parallel, in_features, out_features, disable_sequence_parallel_comm
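
For completeness, a minimal usage sketch of the modified helper, assuming the definitions above are in scope and that the caller (for example a LoRA transform iterating over model.named_modules()) threads the module's fully qualified name through; the FQNs and the wants_lora predicate below are made-up illustrations, not real NeMo names:

# Hypothetical FQNs to illustrate what is_share_expert_linear() matches.
shared_fc1 = "module.decoder.layers.3.mlp.shared_experts.linear_fc1"
routed_fc1 = "module.decoder.layers.3.mlp.experts.linear_fc1"
attn_proj = "module.decoder.layers.3.self_attention.linear_proj"

assert is_share_expert_linear(shared_fc1)        # shared-expert linear: overlap check applies
assert not is_share_expert_linear(routed_fc1)    # routed experts are unaffected
assert not is_share_expert_linear(attn_proj)     # ordinary linears keep the existing sizing

# A caller that walks named_modules() can pass the name along, e.g.:
# for fqn, module in model.named_modules():
#     if wants_lora(fqn):  # hypothetical predicate selecting target linears
#         input_is_parallel, in_features, out_features, disable_sp_comm = \
#             get_adapter_attributes_from_linear(module, fqn=fqn)
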
akoumpa (Member) commented Apr 15, 2025

CC @cuichenx

cuichenx (Collaborator) commented:

Hi @ShareLer
Thank you for the issue. I think your solution with in_features = (m.in_features * tp_size) if not moe_shared_expert_overlap else m.in_features is great! Would you like to make a PR for this change?
