feat: general fsdp2 on non-MoE models + HF TP plan #352

Merged: 9 commits merged into main from yukih/fsdp2-general on May 23, 2025

Conversation

@yuki-666 (Collaborator) commented May 12, 2025

What does this PR do ?

  1. Support FSDP2 on non-MoE models.
  2. Support Hugging Face TP plan.
  3. Parallel-plan selection priority: custom-parallel-plan > opt-parallel-plan (the optimized plans we implement for certain models in FSDP2) > hf-tp-plan (HF's _tp_plan); see the sketch below.
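
A minimal sketch of this selection order, assuming hypothetical helper and registry names (the actual nemo_rl code may structure this differently):

OPTIMIZED_PARALLEL_PLANS: dict[str, dict] = {}  # hypothetical registry of per-model optimized plans

def choose_parallel_plan(model, custom_parallel_plan=None):
    # 1. A user-supplied custom plan always wins.
    if custom_parallel_plan is not None:
        return custom_parallel_plan
    # 2. Otherwise fall back to an optimized plan shipped for known model classes.
    optimized = OPTIMIZED_PARALLEL_PLANS.get(type(model).__name__)
    if optimized is not None:
        return optimized
    # 3. Finally, use the TP plan Hugging Face attaches to the model (_tp_plan), if any.
    return getattr(model, "_tp_plan", None)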

Convergence tests on LlamaForCausalLM, Qwen2ForCausalLM, Qwen3ForCausalLM, Gemma2ForCausalLM, Gemma3ForCausalLM, and Phi3ForCausalLM all run well.

Convergence Test Details

Llama-3.1-8B-Instruct (LlamaForCausalLM): FSDP2-tp8-opt_plan vs FSDP2-tp8-hf_tp_plan
(convergence plot)

Qwen2ForCausalLM / Qwen3ForCausalLM

Qwen2.5-7B-Instruct (Qwen2ForCausalLM): FSDP2-tp4-opt_plan vs FSDP2-tp4-hf_tp_plan
Qwen3-0.6B (Qwen3ForCausalLM): FSDP1 vs FSDP2-tp1
(convergence plots)

Gemma2ForCausalLM / Gemma3ForCausalLM

gemma-2-9b-it (Gemma2ForCausalLM): FSDP1 vs FSDP2-tp1 vs FSDP2-tp4-hf_tp_plan
gemma-3-1b-it (Gemma3ForCausalLM): FSDP1 vs FSDP2-tp1
(convergence plots)

Issues

Closes #156 (support FSDP2 generally).

Usage

  • Example launch command below (a hedged sketch; config key names may differ).
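
A sketch of how a run might be launched, based on the command used later in this thread. The policy.dtensor_cfg.* overrides shown here are assumptions for illustration and may not match the final config key names:

# Run the GRPO math example on a non-MoE model with FSDP2/DTensor.
# The +policy.dtensor_cfg.* keys below are illustrative assumptions.
uv run examples/run_grpo_math.py \
    policy.model_name=meta-llama/Llama-3.1-8B-Instruct \
    cluster.gpus_per_node=8 \
    logger.wandb_enabled=True \
    +policy.dtensor_cfg.enabled=True \
    +policy.dtensor_cfg.tensor_parallel_size=8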

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

@yuki-666 yuki-666 force-pushed the yukih/fsdp2-general branch 14 times, most recently from d5ad00d to 8e2e6f4 Compare May 20, 2025 09:59
@github-actions github-actions bot added the documentation Improvements or additions to documentation label May 20, 2025
@yuki-666 yuki-666 added the CI:L1 Run doctests, unit tests, and functional tests label May 20, 2025
@yuki-666 yuki-666 added the CI:docs Run doctest label May 20, 2025
@yuki-666 (Collaborator, Author) commented:

Filed a separate issue, #413, to track FSDP2 for MoE models:

  1. Qwen3-30B-A3B is noticeably slower than Qwen3-32B, especially during the refit process or when using hf-tp-plan with dtensor tp > 1.
  2. DeepseekV2ForCausalLM under FSDP2 fails with a shape mismatch in self.use_reference_model() on model.layers.0.self_attn.rotary_emb.cos_cached: v.shape=torch.Size([2048, 64]) while self.reference_model_buffers[k].shape=torch.Size([163840, 64]); see the illustrative sketch below.
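
A hypothetical illustration of the kind of buffer swap where such a mismatch would surface; the function and variable names are illustrative, not the actual nemo_rl implementation:

# Illustrative only: copying saved reference-model buffers back onto the live model.
# A cached rotary-embedding buffer captured at a different sequence length would fail here.
def restore_reference_buffers(model, reference_model_buffers):
    for k, v in model.named_buffers():
        ref = reference_model_buffers[k]
        if v.shape != ref.shape:
            raise ValueError(f"shape mismatch for {k}: {v.shape} vs {ref.shape}")
        v.copy_(ref)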

@yuki-666 yuki-666 changed the title feat: general fsdp2 feat: general fsdp2 on non-MoE models + HF TP plan May 20, 2025
@yuki-666 yuki-666 marked this pull request as ready for review May 20, 2025 13:59
@yuki-666 yuki-666 force-pushed the yukih/fsdp2-general branch from fc6cc49 to 05d8cfe Compare May 21, 2025 02:53
@yuki-666 yuki-666 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels May 21, 2025
@yuki-666 yuki-666 force-pushed the yukih/fsdp2-general branch 3 times, most recently from 08cce8c to 0dc55cc Compare May 22, 2025 06:59
@yuki-666 yuki-666 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels May 22, 2025
@yuki-666 (Collaborator, Author) commented:

(quoting @terrykong) @yuki-666 what were the parameters of your gemma2 run? I can't seem to get it to run correctly:

uv run examples/run_grpo_math.py policy.model_name=google/gemma-2-2b-it logger.wandb_enabled=True cluster.gpus_per_node=8 +policy.generation.vllm_cfg.load_format=auto

(error screenshot)

@terrykong Thanks very much for pointing this out!

I tested with almost the same script as yours before commit fdb565c. Since that commit, vLLM's load_format during training defaults to dummy, and only specific models change it through nemo_rl/models/huggingface/common.py. The policy.generation.vllm_cfg.load_format parameter was removed from the yaml and has no effect even if passed.
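
For context, a hypothetical sketch of the kind of per-model load_format override described above; the names are illustrative and do not reflect the actual contents of nemo_rl/models/huggingface/common.py:

# Illustrative only: models that cannot use vLLM's dummy-weight loading during
# training fall back to "auto"; everything else keeps the default "dummy".
MODELS_NEEDING_AUTO_LOAD_FORMAT = {"Gemma2ForCausalLM"}  # hypothetical registry

def resolve_load_format(model_class_name: str) -> str:
    if model_class_name in MODELS_NEEDING_AUTO_LOAD_FORMAT:
        return "auto"
    return "dummy"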

It is fixed now, and other models won't be affected since they don't need special handling of load_format.
(screenshot of the fixed run)

@yuki-666 (Collaborator, Author) commented:

Thanks @jgerh, I have updated based on your suggestions.

yuki-666 added 9 commits May 23, 2025 17:46
Signed-off-by: Yuki Huang <yukih@nvidia.com> (all 9 commits)
@yuki-666 yuki-666 force-pushed the yukih/fsdp2-general branch from 0dc55cc to 72a8f35 Compare May 23, 2025 09:46
@yuki-666 yuki-666 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels May 23, 2025
@terrykong (Collaborator) commented:

Thanks for the quick fix @yuki-666. Gemma2 seems to be okay now from a quick run:

(screenshot of the run)

@parthchadha parthchadha added this pull request to the merge queue May 23, 2025
Merged via the queue into main with commit 3db05c1 May 23, 2025
21 of 23 checks passed
@parthchadha parthchadha deleted the yukih/fsdp2-general branch May 23, 2025 22:10