
Commit e0e6446

Pooja Agarwal authored and facebook-github-bot committed
Revert D69125073: Add row based sharding support for FeaturedProcessedEBC
Differential Revision: D69125073
Original commit changeset: 0cc5bd49f9bf
Original Phabricator Diff: D69125073
fbshipit-source-id: c6cf2407e40d4274dc0fffdcb6f394ff1cfb412e
1 parent af6f7bf commit e0e6446

File tree

9 files changed: +26 −330 lines changed


torchrec/distributed/fp_embeddingbag.py

Lines changed: 9 additions & 50 deletions
@@ -8,18 +8,7 @@
 # pyre-strict

 from functools import partial
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Mapping,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import Any, Dict, Iterator, List, Optional, Type, Union

 import torch
 from torch import nn
@@ -42,20 +31,14 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import (
-    append_prefix,
-    init_parameters,
-    modify_input_for_feature_processor,
-)
+from torchrec.distributed.utils import append_prefix, init_parameters
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import (
     apply_feature_processors_to_kjt,
     FeatureProcessedEmbeddingBagCollection,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

-_T = TypeVar("_T")
-

 def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
     kt._values.add_(no_op_tensor)
@@ -91,16 +74,6 @@ def __init__(
             )
         )

-        self._row_wise_sharded: bool = False
-        for param_sharding in table_name_to_parameter_sharding.values():
-            if param_sharding.sharding_type in [
-                ShardingType.ROW_WISE.value,
-                ShardingType.TABLE_ROW_WISE.value,
-                ShardingType.GRID_SHARD.value,
-            ]:
-                self._row_wise_sharded = True
-                break
-
         self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups

         self._is_collection: bool = False
@@ -123,11 +96,6 @@ def __init__(
     def input_dist(
         self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
-        if not self.is_pipelined and self._row_wise_sharded:
-            # transform input to support row based sharding when not pipelined
-            modify_input_for_feature_processor(
-                features, self._feature_processors, self._is_collection
-            )
         return self._embedding_bag_collection.input_dist(ctx, features)

     def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -137,7 +105,10 @@ def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
                 kjt_list.append(self._feature_processors(features))
             else:
                 kjt_list.append(
-                    apply_feature_processors_to_kjt(features, self._feature_processors)
+                    apply_feature_processors_to_kjt(
+                        features,
+                        self._feature_processors,
+                    )
                 )
         return KJTList(kjt_list)

@@ -146,6 +117,7 @@ def compute(
         ctx: EmbeddingBagCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
+
         fp_features = self.apply_feature_processors_to_kjt_list(dist_input)
         return self._embedding_bag_collection.compute(ctx, fp_features)

@@ -191,18 +163,6 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
             if "_embedding_bag_collection" in fqn:
                 yield append_prefix(prefix, fqn)

-    def preprocess_input(
-        self, args: List[_T], kwargs: Mapping[str, _T]
-    ) -> Tuple[List[_T], Mapping[str, _T]]:
-        for x in args + list(kwargs.values()):
-            if isinstance(x, KeyedJaggedTensor):
-                modify_input_for_feature_processor(
-                    features=x,
-                    feature_processors=self._feature_processors,
-                    is_collection=self._is_collection,
-                )
-        return args, kwargs
-

 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -228,6 +188,7 @@ def shard(
         device: Optional[torch.device] = None,
         module_fqn: Optional[str] = None,
     ) -> ShardedFeatureProcessedEmbeddingBagCollection:
+
         if device is None:
             device = torch.device("cuda")

@@ -264,14 +225,12 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
         if compute_device_type in {"mtia"}:
            return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value]

+        # No row wise because position weighted FP and RW don't play well together.
         types = [
             ShardingType.DATA_PARALLEL.value,
             ShardingType.TABLE_WISE.value,
             ShardingType.COLUMN_WISE.value,
             ShardingType.TABLE_COLUMN_WISE.value,
-            ShardingType.TABLE_ROW_WISE.value,
-            ShardingType.ROW_WISE.value,
-            ShardingType.GRID_SHARD.value,
         ]

         return types
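
After this revert, the sharder once again advertises only non-row-based sharding types for non-MTIA devices. A minimal illustrative sketch of what that means for callers (not part of this commit; assumes a "cuda" compute device string):

# Illustrative only; reflects the sharding_types list shown in the diff above.
from torchrec.distributed.fp_embeddingbag import (
    FeatureProcessedEmbeddingBagCollectionSharder,
)
from torchrec.distributed.types import ShardingType

sharder = FeatureProcessedEmbeddingBagCollectionSharder()
types = sharder.sharding_types(compute_device_type="cuda")

# Row-based types were removed again by this revert...
assert ShardingType.ROW_WISE.value not in types
assert ShardingType.TABLE_ROW_WISE.value not in types
assert ShardingType.GRID_SHARD.value not in types
# ...while table/column-wise variants remain available.
assert ShardingType.TABLE_WISE.value in types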

torchrec/distributed/tests/test_fp_embeddingbag.py

Lines changed: 1 addition & 0 deletions
@@ -231,6 +231,7 @@ class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase):
     def test_sharding_ebc(
         self, set_gradient_division: bool, use_dmp: bool, use_fp_collection: bool
     ) -> None:
+
         import hypothesis

         # don't need to test entire matrix

torchrec/distributed/tests/test_fp_embeddingbag_utils.py

Lines changed: 1 addition & 6 deletions
@@ -86,12 +86,7 @@ def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
         pred = torch.cat(
             [
                 fp_ebc_out[key]
-                for key in [
-                    "feature_0",
-                    "feature_1",
-                    "feature_2",
-                    "feature_3",
-                ]
+                for key in ["feature_0", "feature_1", "feature_2", "feature_3"]
             ],
             dim=1,
         )
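
The hunk above only reflows the key list; behavior is unchanged. A tiny self-contained illustration of the same concatenation pattern, with dummy tensors standing in for the test's fp_ebc_out lookup:

import torch

# Stand-in for fp_ebc_out: feature name -> pooled embedding (batch of 2, dim 4 each).
fp_ebc_out = {f"feature_{i}": torch.randn(2, 4) for i in range(4)}

pred = torch.cat(
    [fp_ebc_out[key] for key in ["feature_0", "feature_1", "feature_2", "feature_3"]],
    dim=1,
)
assert pred.shape == (2, 16)  # four 4-dim embeddings concatenated per sample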

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 1 addition & 164 deletions
@@ -23,10 +23,7 @@
 from torch._dynamo.utils import counters
 from torch.fx._symbolic_trace import is_fx_tracing
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import (
-    EmbeddingComputeKernel,
-    EmbeddingTableConfig,
-)
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.fp_embeddingbag import (
     FeatureProcessedEmbeddingBagCollectionSharder,
@@ -35,13 +32,8 @@
 from torchrec.distributed.model_parallel import DMPCollection
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
-    row_wise,
     table_wise,
 )
-from torchrec.distributed.test_utils.multi_process import (
-    MultiProcessContext,
-    MultiProcessTestBase,
-)
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestEBCSharder,
@@ -350,161 +342,6 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         torch.testing.assert_close(pred_gpu.cpu(), pred)


-def fp_ebc(
-    rank: int,
-    world_size: int,
-    tables: List[EmbeddingTableConfig],
-    weighted_tables: List[EmbeddingTableConfig],
-    data: List[Tuple[ModelInput, List[ModelInput]]],
-    backend: str = "nccl",
-    local_size: Optional[int] = None,
-) -> None:
-    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
-        assert ctx.pg is not None
-        sharder = cast(
-            ModuleSharder[nn.Module],
-            FeatureProcessedEmbeddingBagCollectionSharder(),
-        )
-
-        class DummyWrapper(nn.Module):
-            def __init__(self, sparse_arch):
-                super().__init__()
-                self.m = sparse_arch
-
-            def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
-                return self.m(model_input.idlist_features)
-
-        max_feature_lengths = [10, 10, 12, 12]
-        sparse_arch = DummyWrapper(
-            create_module_and_freeze(
-                tables=tables,  # pyre-ignore[6]
-                device=ctx.device,
-                use_fp_collection=False,
-                max_feature_lengths=max_feature_lengths,
-            )
-        )
-
-        # compute_kernel = EmbeddingComputeKernel.FUSED.value
-        module_sharding_plan = construct_module_sharding_plan(
-            sparse_arch.m._fp_ebc,
-            per_param_sharding={
-                "table_0": row_wise(),
-                "table_1": row_wise(),
-                "table_2": row_wise(),
-                "table_3": row_wise(),
-            },
-            world_size=2,
-            device_type=ctx.device.type,
-            sharder=sharder,
-        )
-        sharded_sparse_arch_pipeline = DistributedModelParallel(
-            module=copy.deepcopy(sparse_arch),
-            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
-            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
-            sharders=[sharder],
-            device=ctx.device,
-        )
-        sharded_sparse_arch_no_pipeline = DistributedModelParallel(
-            module=copy.deepcopy(sparse_arch),
-            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
-            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
-            sharders=[sharder],
-            device=ctx.device,
-        )
-
-        batches = []
-        for d in data:
-            batches.append(d[1][ctx.rank].to(ctx.device))
-        dataloader = iter(batches)
-
-        optimizer_no_pipeline = optim.SGD(
-            sharded_sparse_arch_no_pipeline.parameters(), lr=0.1
-        )
-        optimizer_pipeline = optim.SGD(
-            sharded_sparse_arch_pipeline.parameters(), lr=0.1
-        )
-
-        pipeline = TrainPipelineSparseDist(
-            sharded_sparse_arch_pipeline,
-            optimizer_pipeline,
-            ctx.device,
-        )
-
-        for batch in batches[:-2]:
-            batch = batch.to(ctx.device)
-            optimizer_no_pipeline.zero_grad()
-            loss, pred = sharded_sparse_arch_no_pipeline(batch)
-            loss.backward()
-            optimizer_no_pipeline.step()
-
-            pred_pipeline = pipeline.progress(dataloader)
-            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
-
-
-class TrainPipelineGPUTest(MultiProcessTestBase):
-    def setUp(self, backend: str = "nccl") -> None:
-        super().setUp()
-
-        self.pipeline_class = TrainPipelineSparseDist
-        num_features = 4
-        num_weighted_features = 4
-        self.tables = [
-            EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 100,
-                embedding_dim=(i + 1) * 4,
-                name="table_" + str(i),
-                feature_names=["feature_" + str(i)],
-            )
-            for i in range(num_features)
-        ]
-        self.weighted_tables = [
-            EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 100,
-                embedding_dim=(i + 1) * 4,
-                name="weighted_table_" + str(i),
-                feature_names=["weighted_feature_" + str(i)],
-            )
-            for i in range(num_weighted_features)
-        ]
-
-        self.backend = backend
-        if torch.cuda.is_available():
-            self.device = torch.device("cuda")
-        else:
-            self.device = torch.device("cpu")
-
-        if self.backend == "nccl" and self.device == torch.device("cpu"):
-            self.skipTest("NCCL not supported on CPUs.")
-
-    def _generate_data(
-        self,
-        num_batches: int = 5,
-        batch_size: int = 1,
-        max_feature_lengths: Optional[List[int]] = None,
-    ) -> List[Tuple[ModelInput, List[ModelInput]]]:
-        return [
-            ModelInput.generate(
-                tables=self.tables,
-                weighted_tables=self.weighted_tables,
-                batch_size=batch_size,
-                world_size=2,
-                num_float_features=10,
-                max_feature_lengths=max_feature_lengths,
-            )
-            for i in range(num_batches)
-        ]
-
-    def test_fp_ebc_rw(self) -> None:
-        data = self._generate_data(max_feature_lengths=[10, 10, 12, 12])
-        self._run_multi_process_test(
-            callable=fp_ebc,
-            world_size=2,
-            tables=self.tables,
-            weighted_tables=self.weighted_tables,
-            data=data,
-        )
-
-
 class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
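
The deleted test built its module sharding plan with row_wise(), which this revert no longer supports for FeatureProcessedEmbeddingBagCollection. For reference, a hedged sketch of how a comparable plan could be built with the still-supported table_wise() generator; the fp_ebc module, the table_0..table_3 names, and the 2-rank world size are assumptions carried over from the removed test, not code from this commit:

# Sketch only: fp_ebc is an assumed, already-constructed
# FeatureProcessedEmbeddingBagCollection with tables table_0..table_3.
from typing import cast

from torch import nn
from torchrec.distributed.fp_embeddingbag import (
    FeatureProcessedEmbeddingBagCollectionSharder,
)
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    table_wise,
)
from torchrec.distributed.types import ModuleSharder

sharder = cast(
    ModuleSharder[nn.Module],
    FeatureProcessedEmbeddingBagCollectionSharder(),
)

module_sharding_plan = construct_module_sharding_plan(
    fp_ebc,  # assumed FP-EBC instance
    per_param_sharding={
        "table_0": table_wise(rank=0),
        "table_1": table_wise(rank=1),
        "table_2": table_wise(rank=0),
        "table_3": table_wise(rank=1),
    },
    world_size=2,
    device_type="cuda",
    sharder=sharder,
)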

torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def setUp(self) -> None:
         self.pg = init_distributed_single_host(backend=backend, rank=0, world_size=1)

         num_features = 4
-        num_weighted_features = 4
+        num_weighted_features = 2
         self.tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 100,
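
With num_weighted_features back at 2, the base test's weighted-table setup shrinks to two configs. A small self-contained sketch of the resulting configs, assuming the base file follows the same comprehension pattern as the deleted GPU test shown above:

from torchrec.modules.embedding_configs import EmbeddingBagConfig

num_weighted_features = 2
weighted_tables = [
    EmbeddingBagConfig(
        num_embeddings=(i + 1) * 100,
        embedding_dim=(i + 1) * 4,
        name="weighted_table_" + str(i),
        feature_names=["weighted_feature_" + str(i)],
    )
    for i in range(num_weighted_features)
]
# -> weighted_table_0 (100 x 4) and weighted_table_1 (200 x 8)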

torchrec/distributed/train_pipeline/utils.py

Lines changed: 0 additions & 3 deletions
@@ -147,7 +147,6 @@ def _start_data_dist(
        # and this info was done in the _rewrite_model by tracing the
        # entire model to get the arg_info_list
        args, kwargs = forward.args.build_args_kwargs(batch)
-        args, kwargs = module.preprocess_input(args, kwargs)

        # Start input distribution.
        module_ctx = module.create_context()
@@ -380,8 +379,6 @@ def _rewrite_model(  # noqa C901
            logger.info(f"Module '{node.target}' will be pipelined")
            child = sharded_modules[node.target]
            original_forwards.append(child.forward)
-            # Set pipelining flag on the child module
-            child.is_pipelined = True
            # pyre-ignore[8] Incompatible attribute type
            child.forward = pipelined_forward(
                node.target,

0 commit comments