diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 9afcc4a4c..3833bc996 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -730,6 +730,7 @@ def _pipeline_model(
         batch: Optional[In],
         context: TrainPipelineContext,
         pipelined_forward: Type[PipelinedForward] = PipelinedForward,
+        custom_dist_stream: Optional[torch.Stream] = None,
     ) -> None:
         (
             self._pipelined_modules,
@@ -740,7 +741,11 @@ def _pipeline_model(
         ) = _rewrite_model(
             model=self._model,
             context=context,
-            dist_stream=self._data_dist_stream,
+            dist_stream=(
+                self._data_dist_stream
+                if custom_dist_stream is None
+                else custom_dist_stream
+            ),
             default_stream=torch.get_device_module(self._device).current_stream(),
             batch=batch,
             apply_jit=self._apply_jit,
@@ -768,6 +773,7 @@ def _init_pipelined_modules(
         batch: In,
         context: TrainPipelineContext,
         pipelined_forward: Type[PipelinedForward] = PipelinedForward,
+        custom_dist_stream: Optional[torch.Stream] = None,
     ) -> None:
         """
         Retrieves the pipelined modules after overriding their forwards, initializes the
@@ -779,7 +785,7 @@ def _init_pipelined_modules(
             self.start_sparse_data_dist(batch, context)
             return
 
-        self._pipeline_model(batch, context, pipelined_forward)
+        self._pipeline_model(batch, context, pipelined_forward, custom_dist_stream)
 
     def copy_batch_to_gpu(
         self,
@@ -1452,7 +1458,7 @@ def __init__(
             else None
         )
         self._default_stream: Optional[torch.Stream] = (
-            (torch.get_device_module(self._device).Stream())
+            (torch.get_device_module(self._device).current_stream())
             if self._device.type in ["cuda", "mtia"]
             else None
         )
@@ -1476,6 +1482,7 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             self._context,
             # pyre-ignore
             self._pipelined_forward_type,
+            self._prefetch_stream,
         )
         self._start_sparse_data_dist(self._batch_i)
         self._wait_sparse_data_dist()
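
Review note (not part of the diff): the two behavioral pieces here are (a) `_pipeline_model` now prefers a caller-supplied `custom_dist_stream` and falls back to `self._data_dist_stream` when none is given, and (b) `_default_stream` now aliases the device's *current* stream rather than allocating a fresh side stream. A minimal sketch of both outside torchrec, assuming a CUDA device; the `resolve_dist_stream` helper is hypothetical, for illustration only:

```python
from typing import Optional

import torch


def resolve_dist_stream(
    data_dist_stream: Optional[torch.Stream],
    custom_dist_stream: Optional[torch.Stream] = None,
) -> Optional[torch.Stream]:
    # Hypothetical helper mirroring the fallback logic this PR adds to
    # _pipeline_model: prefer the caller-supplied stream, otherwise use
    # the pipeline's own data-dist stream (the pre-change behavior).
    return data_dist_stream if custom_dist_stream is None else custom_dist_stream


if torch.cuda.is_available():
    device = torch.device("cuda")
    # Pre-change: a freshly allocated side stream, distinct from the
    # stream actually executing the main compute.
    side_stream = torch.get_device_module(device).Stream()
    # Post-change: aliases the stream that is already current, so later
    # synchronization happens against the compute stream.
    default_stream = torch.get_device_module(device).current_stream()
    assert side_stream != torch.cuda.current_stream()
    assert default_stream == torch.cuda.current_stream()
```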