
Commit 08cce8c

update doc and fix type
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent b6880b9 commit 08cce8c

File tree

3 files changed: +65 additions, -47 deletions

docs/design-docs/fsdp2-parallel-plan.md

Lines changed: 13 additions & 32 deletions
@@ -1,49 +1,30 @@
 # FSDP2 Parallel Plan
 
-This guide outlines the parallelization strategy for FSDP2 training in NeMo-RL.
+This guide outlines the parallelization strategy for Fully Sharded Data Parallel version 2 (FSDP2) training in NeMo RL.
 
 ## Fallback Priority
 
-Three parallelization approaches are supported, with the following fallback priority.
+NeMo RL supports three parallelization strategies, applied in the following order of fallback priority:
 
-**Custom Parallel Plan**
+### 1. Custom Parallel Plan
 
-User-defined custom parallel plans take precedence when available.
+Your user-defined custom parallel plans always take precedence when available. For detailed implementation and usage, refer to the [Custom Parallel Plan Example](#custom-parallel-plan-example).
 
-For implementation details and usage guidelines, please refer to [Custom Parallel Plan Example](#custom-parallel-plan-example).
+### 2. Optimized Parallel Plan
 
-**Optimized Parallel Plan**
+Optimized parallel plans are available for specific model architectures. They may offer superior performance compared to Hugging Face's tensor parallel implementation. This approach is used if no custom parallel plan is specified and the model class supports optimized parallelization.
 
-Optimized parallel plans are available for specific model architectures and may offer superior performance compared to the Hugging Face tensor parallel implementation.
+### 3. Hugging Face Tensor Parallel Plan
 
-This approach is used when no custom parallel plan is specified and the model class supports optimized parallelization.
-
-**Hugging Face Tensor Parallel Plan**
-
-Hugging Face provides tensor parallelism for most models through `._tp_plan`.
-
-It serves as the default when neither custom nor optimized parallel plans are available.
+The Hugging Face tensor parallel plan is the default. It's available for most models via `._tp_plan` and is used when neither a custom nor an optimized parallel plan is available.
 
 ## Custom Parallel Plan Example
 
-Custom parallel plan should be defined in a file, exemplified by `examples/custom_parallel.py`.
-
-To implement the custom parallel plan, configure `policy.dtensor_cfg.custom_parallel_plan=examples.custom_parallel.custom_parallel_plan`.
-
-```python
-from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
-from torch.distributed.tensor.placement_types import Replicate, Shard
+A custom parallel plan should be defined in a separate file, such as the example provided in `examples/custom_parallel.py`.
 
+To implement the custom parallel plan, either update the value of `custom_parallel_plan` in the `yaml` file directly, or pass the override via the command line. For example:
 
-custom_parallel_plan = {
-    "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
-    "model.layers.*.self_attn.q_proj": ColwiseParallel(),
-    "model.layers.*.self_attn.k_proj": ColwiseParallel(),
-    "model.layers.*.self_attn.v_proj": ColwiseParallel(),
-    "model.layers.*.self_attn.o_proj": RowwiseParallel(),
-    "model.layers.*.mlp.up_proj": ColwiseParallel(),
-    "model.layers.*.mlp.gate_proj": ColwiseParallel(),
-    "model.layers.*.mlp.down_proj": RowwiseParallel(),
-    "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
-}
+```bash
+uv run examples/run_grpo_math.py \
+    policy.dtensor_cfg.custom_parallel_plan=examples.custom_parallel.custom_parallel_plan
 ```
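
The fallback order described in this doc can be pictured with a short sketch. The snippet below is illustrative only: `OPTIMIZED_PLANS` and `resolve_parallel_plan` are hypothetical names, not NeMo RL's API; the only real attribute it relies on is Hugging Face's `._tp_plan`.

```python
# Hypothetical sketch of the custom > optimized > Hugging Face fallback.
# `OPTIMIZED_PLANS` and `resolve_parallel_plan` are illustrative names only.
from typing import Callable, Optional

# Optimized plans registered per model class (in spirit, like PARALLIZE_FUNCTIONS below).
OPTIMIZED_PLANS: dict[type, Callable[..., dict]] = {}


def resolve_parallel_plan(model, custom_plan: Optional[dict] = None) -> dict:
    """Pick a tensor parallel plan: custom > optimized > Hugging Face `._tp_plan`."""
    # 1. A user-supplied custom parallel plan always takes precedence.
    if custom_plan is not None:
        return custom_plan
    # 2. Otherwise use an optimized plan registered for this model class, if any.
    plan_fn = OPTIMIZED_PLANS.get(type(model))
    if plan_fn is not None:
        return plan_fn(model)
    # 3. Fall back to the Hugging Face tensor parallel plan.
    hf_plan = getattr(model, "_tp_plan", None)
    assert hf_plan, "No TP plan found; set tensor_parallel_size to 1 or provide a custom plan."
    return dict(hf_plan)
```

The corresponding logic in NeMo RL lives in `nemo_rl/models/dtensor/parallelize.py`, shown in the diff below.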

examples/custom_parallel.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+from torch.distributed.tensor.placement_types import Replicate, Shard
+
+custom_parallel_plan = {
+    "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
+    "model.layers.*.self_attn.q_proj": ColwiseParallel(),
+    "model.layers.*.self_attn.k_proj": ColwiseParallel(),
+    "model.layers.*.self_attn.v_proj": ColwiseParallel(),
+    "model.layers.*.self_attn.o_proj": RowwiseParallel(),
+    "model.layers.*.mlp.up_proj": ColwiseParallel(),
+    "model.layers.*.mlp.gate_proj": ColwiseParallel(),
+    "model.layers.*.mlp.down_proj": RowwiseParallel(),
+    "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
+}
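
The `custom_parallel_plan` value may also be given as a dotted string path such as `examples.custom_parallel.custom_parallel_plan`; per the `_parallelize_model` docstring below, the path must point to a dict or to a function that returns a dict. As a rough illustration (not necessarily NeMo RL's actual loader), such a path could be resolved like this:

```python
# Illustrative only: one plausible way to resolve a dotted path such as
# "examples.custom_parallel.custom_parallel_plan" into a plan dict.
import importlib


def load_parallel_plan(path: str) -> dict:
    """Import `package.module.attribute` and return the plan dict."""
    module_name, attr_name = path.rsplit(".", 1)
    plan = getattr(importlib.import_module(module_name), attr_name)
    # The attribute may be the dict itself, or a function that returns a dict.
    if callable(plan):
        plan = plan()
    assert isinstance(plan, dict), f"{path} did not resolve to a dict"
    return plan
```

With the file above on the Python path, `load_parallel_plan("examples.custom_parallel.custom_parallel_plan")` returns the dict shown in this diff.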

nemo_rl/models/dtensor/parallelize.py

Lines changed: 38 additions & 15 deletions
@@ -14,7 +14,7 @@
 
 from functools import lru_cache
 from types import FunctionType
-from typing import Callable, Union
+from typing import Callable, Optional, Union
 
 import torch
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -25,6 +25,7 @@
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
+    ParallelStyle,
     PrepareModuleInput,
     PrepareModuleOutput,
     RowwiseParallel,
@@ -254,7 +255,9 @@ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
     return base_model_tp_plan
 
 
-PARALLIZE_FUNCTIONS: dict[type[torch.nn.Module], Callable[..., torch.nn.Module]] = {
+PARALLIZE_FUNCTIONS: dict[
+    type[torch.nn.Module], Callable[..., dict[str, ParallelStyle]]
+] = {
     Qwen2ForCausalLM: _parallelize_qwen,
     Qwen3ForCausalLM: _parallelize_qwen,
     LlamaForCausalLM: _parallelize_llama,
@@ -292,7 +295,21 @@ def translate_parallel_style(style: str):
 def get_hf_tp_plan(model):
     """Get the Hugging Face tensor parallel plan from the model.
 
+    This function:
+    - Retrieves TP strategies from the model class, instance, and inner model levels.
+    - Handles special cases for `embed_tokens` and `lm_head` for a speedup.
+    - Converts string-based parallel styles to DTensor parallelization strategies.
+
     Taken and modified from: https://github.com/NVIDIA/NeMo/blob/6c6169db01bcca73ae8ad3ac35242fadbb9a78ba/nemo/lightning/pytorch/strategies/utils.py#L532
+
+    Args:
+        model: A Hugging Face model instance.
+
+    Returns:
+        dict: A dictionary mapping model component paths to their parallelization strategies.
+
+    Raises:
+        AssertionError: If no TP plan is found.
     """
     model_cls = type(model)
     if model_cls == Gemma3ForConditionalGeneration:
@@ -317,7 +334,8 @@ def get_hf_tp_plan(model):
     )
 
     assert len(hf_tp_plan) > 0, (
-        f"Hugging Face tp plan is not supported for {model_cls}, please set dtensor_cfg.tensor_parallel_size to 1 or provide a custom parallel plan."
+        f"Hugging Face tp plan is not supported for {model_cls}. Please set dtensor_cfg.tensor_parallel_size to 1 or provide a custom_parallel_plan. "
+        "For a usage example of custom_parallel_plan, refer to `docs/design-docs/fsdp2-parallel-plan.md`."
     )
 
     # The hf tp plan does not contain embed_tokens, so we add it and set it to rowwise_rep
@@ -344,26 +362,31 @@
 
 
 def _parallelize_model(
-    model: Union[Qwen2ForCausalLM, LlamaForCausalLM],
+    model: Union[
+        Qwen2ForCausalLM,
+        LlamaForCausalLM,
+        Gemma3ForCausalLM,
+        Gemma3ForConditionalGeneration,
+    ],
     dp_mesh: DeviceMesh,
     tp_mesh: DeviceMesh,
     param_dtype: torch.dtype,
     sequence_parallel: bool = False,
     activation_checkpointing: bool = False,
     cpu_offload: bool = False,
-    custom_parallel_plan: Union[dict, str] = None,
+    custom_parallel_plan: Optional[Union[dict, str]] = None,
 ):
     """Parallelize a model using DTensor.
 
     Args:
-        model (Union[Qwen2ForCausalLM, LlamaForCausalLM]): The model to parallelize.
-        dp_mesh (DeviceMesh): Device mesh for data parallelism.
-        tp_mesh (DeviceMesh): Device mesh for tensor parallelism.
-        param_dtype (torch.dtype): Data type for model parameters.
-        sequence_parallel (bool, optional): Whether to use sequence parallelism. Defaults to False.
-        activation_checkpointing (bool, optional): Whether to use activation checkpointing. Defaults to False.
-        cpu_offload (bool, optional): Whether to enable cpu offloading for FSDP. Defaults to False.
-        custom_parallel_plan (Union[dict, str], optional): Custom parallel plan for the model. Defaults to None.
+        model: The model to parallelize.
+        dp_mesh: Device mesh for data parallelism.
+        tp_mesh: Device mesh for tensor parallelism.
+        param_dtype: Data type for model parameters.
+        sequence_parallel: Whether to use sequence parallelism. Defaults to False.
+        activation_checkpointing: Whether to use activation checkpointing. Defaults to False.
+        cpu_offload: Whether to enable CPU offloading for FSDP. Defaults to False.
+        custom_parallel_plan: Custom parallel plan for the model. Defaults to None.
             If it's a dict, it will be used as the parallel plan directly.
             If it's a string, it must be a path that points to a dict or a function that returns a dict.
             For a usage example, refer to `docs/design-docs/fsdp2-parallel-plan.md`.
@@ -376,11 +399,11 @@ def _parallelize_model(
     """
     model_cls = type(model)
    if model_cls == Gemma3ForConditionalGeneration:
-        layers = model.language_model.model.layers
+        layers: torch.nn.ModuleList = model.language_model.model.layers  # type: ignore
         num_attention_heads = model.config.text_config.num_attention_heads
         num_key_value_heads = model.config.text_config.num_key_value_heads
     else:
-        layers = model.model.layers
+        layers: torch.nn.ModuleList = model.model.layers  # type: ignore
         num_attention_heads = model.config.num_attention_heads
         num_key_value_heads = model.config.num_key_value_heads
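
For context, a plan dict like the ones in this commit maps submodule paths to `ParallelStyle` objects that PyTorch's DTensor API consumes. The sketch below is a simplified stand-in, not a copy of `_parallelize_model`: it uses only public PyTorch APIs (`init_device_mesh`, `parallelize_module`) on a toy model, and it must be launched with `torchrun` so that a process group exists.

```python
# Simplified sketch (not NeMo RL's _parallelize_model): applying a TP plan dict
# with public PyTorch DTensor APIs. Launch with torchrun so a process group exists.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(torch.nn.Module):
    def __init__(self, dim: int = 128):
        super().__init__()
        self.up_proj = torch.nn.Linear(dim, 4 * dim)
        self.down_proj = torch.nn.Linear(4 * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(torch.relu(self.up_proj(x)))


def apply_toy_tp_plan(tp_size: int = 2) -> torch.nn.Module:
    # One tensor-parallel mesh dimension; real code also builds a DP mesh for FSDP2.
    tp_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
    model = ToyMLP().cuda()
    # Keys name submodules; values are ParallelStyle objects, as in the plans above.
    plan = {"up_proj": ColwiseParallel(), "down_proj": RowwiseParallel()}
    parallelize_module(model, tp_mesh, plan)
    return model
```

In NeMo RL itself, `_parallelize_model` performs this step together with FSDP2 sharding, activation checkpointing, and optional CPU offload, as its signature above shows.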

0 commit comments
