Commit 43f1473

optimisea authored and facebook-github-bot committed
Add MTIA into sharding plan and estimator (#3310)
Summary: Pull Request resolved: #3310. As title. Differential Revision: D80758637
1 parent e0e6446 commit 43f1473

File tree

3 files changed: +19 -7 lines changed

torchrec/distributed/planner/planners.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -67,7 +67,7 @@
     ShardingType,
     ShardMetadata,
 )
-from torchrec.distributed.utils import none_throws
+from torchrec.distributed.utils import get_device_type, none_throws
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -178,10 +178,11 @@ def __init__(
         heuristical_storage_reservation_percentage: float = 0.15,
     ) -> None:
         if topology is None:
+            compute_device = get_device_type()
             topology = Topology(
                 local_world_size=get_local_size(),
                 world_size=dist.get_world_size(),
-                compute_device="cuda" if torch.cuda.is_available() else "cpu",
+                compute_device=compute_device,
             )
         self._topology: Topology = topology
         self._batch_size: int = batch_size if batch_size else BATCH_SIZE
@@ -624,7 +625,8 @@ def __init__(
             List[Callable[[List[ShardingOption]], List[ShardingOption]]]
         ] = None,
    ) -> None:
-        default_device = "cuda" if torch.cuda.is_available() else "cpu"
+        default_device = get_device_type()
+
         if topology_groups is None:
             topology_groups = {
                 default_device: Topology(
```
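For context, a minimal sketch of what the new default resolves to when no `Topology` is passed. `EmbeddingShardingPlanner`, `Topology`, and `get_device_type` come from the patched files; the `local_world_size` value is a placeholder, and an initialized process group is assumed:

```python
import torch.distributed as dist

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.utils import get_device_type

# With topology=None, the planner now derives compute_device from
# get_device_type(), so an MTIA host resolves to "mtia" instead of
# silently falling back to "cpu".
planner = EmbeddingShardingPlanner(topology=None)

# Roughly equivalent explicit construction:
topology = Topology(
    local_world_size=8,  # placeholder: accelerators per host
    world_size=dist.get_world_size(),
    compute_device=get_device_type(),  # "cuda", "mtia", or "cpu"
)
planner = EmbeddingShardingPlanner(topology=topology)
```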

torchrec/distributed/sharding_plan.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -42,7 +42,7 @@
     ShardingType,
     ShardMetadata,
 )
-from torchrec.distributed.utils import none_throws
+from torchrec.distributed.utils import get_device_type, none_throws
 
 
 def get_default_sharders() -> List[ModuleSharder[nn.Module]]:
@@ -620,7 +620,7 @@ def column_wise(
         ranks (Optional[List[int]]): Ranks to place columns. Required if size_per_rank is None.
         size_per_rank (Optional[List[int]]): List specifying the number of columns per rank.
             If provided, the columns will be distributed according to these sizes.
-        device_types (Optional[List[str]]): List of device types (e.g., "cpu", "cuda") for each shard.
+        device_types (Optional[List[str]]): List of device types (e.g., "cpu", "cuda", "mtia") for each shard.
             Used to specify different device placements for different shards.
 
     Returns:
@@ -651,7 +651,7 @@ def _parameter_sharding_generator(
         param: The parameter tensor to be sharded.
         local_size: Number of devices in the local process group.
         world_size: Total number of devices across all process groups.
-        device_type: Type of device (e.g., "cuda", "cpu").
+        device_type: Type of device (e.g., "cuda", "cpu", "mtia").
         sharder: The module sharder instance.
 
     Returns:
@@ -895,7 +895,7 @@ def construct_module_sharding_plan(
         )
     """
     if device_type is None:
-        device_type = "cuda" if torch.cuda.is_available() else "cpu"
+        device_type = get_device_type()
     if sharder is None:
         sharder = get_module_to_default_sharders().get(type(module), None)
     assert (
```
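A hedged usage sketch of the updated entry points; the table configuration, world sizes, and names here are hypothetical, and only the device-type behavior comes from this patch:

```python
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.sharding_plan import (
    column_wise,
    construct_module_sharding_plan,
)

# Hypothetical single-table module.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            embedding_dim=64,
            num_embeddings=1_000,
            feature_names=["f0"],
        )
    ]
)

module_plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={
        # Column-wise over ranks 0 and 1; per the updated docstring, an
        # optional device_types list may now mix "cpu", "cuda", and "mtia".
        "table_0": column_wise(ranks=[0, 1]),
    },
    local_size=2,
    world_size=2,
    # device_type omitted: it now defaults via get_device_type() to
    # "cuda" first, then "mtia", then "cpu".
)
```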

torchrec/distributed/utils.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -50,6 +50,16 @@
 """
 
 
+def get_device_type() -> str:
+    if torch.cuda.is_available():
+        device_type = "cuda"
+    elif torch.mtia.is_available():
+        device_type = "mtia"
+    else:
+        device_type = "cpu"
+    return device_type
+
+
 def get_class_name(obj: object) -> str:
     if obj is None:
         return "None"
```
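And a quick sketch of the new helper's selection order on the current host (assumes a PyTorch build recent enough to expose `torch.mtia`):

```python
import torch

from torchrec.distributed.utils import get_device_type

# Selection order: CUDA first, then MTIA, then CPU fallback.
device_type = get_device_type()
assert device_type in ("cuda", "mtia", "cpu")

# The returned string plugs directly into torch.device / Topology.
device = torch.device(device_type)
print(device)
```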
