
Memory Overhead and TPS Degradation with Increased Expert Parallelism (EP) #1624

@AnWang-AI


Bug description

When scaling Expert Parallelism (EP) from 1 to 4 while training the Llama4 model on a system with 8×H20 GPUs (use_grouped_mm=True, num_experts=128, top_k=8), we observe:

  • Memory usage increases from 84.57 GiB (EP=1) to 91.51 GiB (EP=4).
  • Throughput (TPS) decreases from 2,125 (EP=1) to 1,121 (EP=4).
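A plausible contributor to the TPS drop is the dispatch/combine all-to-all that EP adds to every MoE layer: each token's hidden state is sent to the ranks hosting its top_k experts and then gathered back. A back-of-envelope estimate (an illustrative sketch, not torchtitan code; uniform routing and bf16 activations are assumed):

# Hypothetical estimate of per-rank traffic in one MoE dispatch all-to-all.
# Assumes uniform routing, so (ep - 1) / ep of the (token, expert) pairs
# land on remote ranks; the combine step sends roughly the same volume back.
def alltoall_bytes_per_layer(tokens: int, dim: int, top_k: int, ep: int,
                             bytes_per_elem: int = 2) -> int:
    routed = tokens * top_k               # (token, expert) pairs per layer
    remote = routed * (ep - 1) // ep      # pairs hosted on other EP ranks
    return remote * dim * bytes_per_elem  # bf16 hidden states

# seq_len=8192, local_batch_size=1, dim=2048, top_k=8 from the config below:
print(alltoall_bytes_per_layer(8192, 2048, 8, 4) / 2**20, "MiB")  # ~192 MiB

Across 48 MoE layers per forward pass, plus the combine step and the backward, this volume adds up quickly, which would be consistent with the roughly 2× TPS drop reported above.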

Versions

llama4 model config:

"3bx128e": TransformerModelArgs(
    dim=2048,
    n_layers=48,
    n_heads=32,
    n_kv_heads=4,
    multiple_of=8,
    ffn_dim_multiplier=1.2657,
    rope_theta=500000,
    max_seq_len=4096,
    moe_args=MoEArgs(
        num_experts=128,
        top_k=8,
        use_grouped_mm=True,
    ),
    interleave_moe_layer_step=1,
),
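Note that EP partitions the experts, so at EP=4 each EP rank hosts 128 / 4 = 32 experts (assuming the even split of standard expert parallelism, with FSDP further sharding the local experts). Expert weight memory per rank should therefore not grow with EP, which suggests the extra ~7 GiB comes from dispatch/combine buffers and per-expert activations rather than parameters; that is an assumption worth confirming with a memory snapshot. A trivial illustrative check:

# Illustrative only: expert placement under EP, assuming an even partition
# of the 128 experts across the EP group (not a measurement).
num_experts = 128
for ep in (1, 2, 4):
    print(f"EP={ep}: experts hosted per EP rank = {num_experts // ep}")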

toml file:


[job]
dump_folder = "./outputs"
description = "Llama 4 Scout training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 1
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama4"
flavor = "3bx128e"
hf_assets_path = "./assets/hf/Llama-4-Scout-17B-16E"
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 4e-3
eps = 1e-15

[lr_scheduler]
warmup_steps = 600
min_lr_factor = 0.1

[training]
local_batch_size = 1
seq_len = 8192
max_norm = 1.0  # grad norm clipping
steps = 3000
dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1
expert_parallel_degree = 4
expert_tensor_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
last_save_model_only = true
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "full" # ["none", "selective", "full"]

[compile]
enable = false
components = ["model", "loss"]

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output", "router.gate"]
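
For reference, a sketch of how the [parallelism] degrees above map onto the 8 GPUs. This assumes data_parallel_shard_degree = -1 means "fill the remaining world size" and that EP groups are carved out of the data-parallel shard dimension; both are assumptions about torchtitan's mesh construction, stated here for illustration:

# Sketch of device-mesh sizing for this config on 8 GPUs (illustrative).
world_size = 8
dp_replicate, tp, pp, cp = 1, 1, 1, 1
dp_shard = world_size // (dp_replicate * tp * pp * cp)  # -1 resolves to 8
ep = 4  # expert_parallel_degree
assert dp_shard % ep == 0, "EP degree must divide the data-parallel shard degree"
print(f"dp_shard={dp_shard}, ep={ep}: each MoE layer's experts split {ep} ways")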
