Skip to content

Commit 8d17716

Browse files
peterfu0 and facebook-github-bot
authored and committed
Fix prefetch train pipeline (#3306)
Summary: There are 2 bugs: (1) a random stream is used as the default stream; (2) the prefetch stream is not waited on by pipelined_forward. Differential Revision: D80268824
1 parent e23007b commit 8d17716

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,7 @@ def _pipeline_model(
730730
batch: Optional[In],
731731
context: TrainPipelineContext,
732732
pipelined_forward: Type[PipelinedForward] = PipelinedForward,
733+
custom_dist_stream: Optional[torch.Stream] = None,
733734
) -> None:
734735
(
735736
self._pipelined_modules,
@@ -740,7 +741,11 @@ def _pipeline_model(
740741
) = _rewrite_model(
741742
model=self._model,
742743
context=context,
743-
dist_stream=self._data_dist_stream,
744+
dist_stream=(
745+
self._data_dist_stream
746+
if custom_dist_stream is None
747+
else custom_dist_stream
748+
),
744749
default_stream=torch.get_device_module(self._device).current_stream(),
745750
batch=batch,
746751
apply_jit=self._apply_jit,
@@ -768,6 +773,7 @@ def _init_pipelined_modules(
768773
batch: In,
769774
context: TrainPipelineContext,
770775
pipelined_forward: Type[PipelinedForward] = PipelinedForward,
776+
custom_dist_stream: Optional[torch.Stream] = None,
771777
) -> None:
772778
"""
773779
Retrieves the pipelined modules after overriding their forwards, initializes the
@@ -779,7 +785,7 @@ def _init_pipelined_modules(
779785
self.start_sparse_data_dist(batch, context)
780786
return
781787

782-
self._pipeline_model(batch, context, pipelined_forward)
788+
self._pipeline_model(batch, context, pipelined_forward, custom_dist_stream)
783789

784790
def copy_batch_to_gpu(
785791
self,
@@ -1452,7 +1458,7 @@ def __init__(
14521458
else None
14531459
)
14541460
self._default_stream: Optional[torch.Stream] = (
1455-
(torch.get_device_module(self._device).Stream())
1461+
(torch.get_device_module(self._device).current_stream())
14561462
if self._device.type in ["cuda", "mtia"]
14571463
else None
14581464
)
@@ -1476,6 +1482,7 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
14761482
self._context,
14771483
# pyre-ignore
14781484
self._pipelined_forward_type,
1485+
self._prefetch_stream,
14791486
)
14801487
self._start_sparse_data_dist(self._batch_i)
14811488
self._wait_sparse_data_dist()

0 commit comments

Comments (0)