@@ -730,6 +730,7 @@ def _pipeline_model(
730
730
batch : Optional [In ],
731
731
context : TrainPipelineContext ,
732
732
pipelined_forward : Type [PipelinedForward ] = PipelinedForward ,
733
+ custom_dist_stream : Optional [torch .Stream ] = None ,
733
734
) -> None :
734
735
(
735
736
self ._pipelined_modules ,
@@ -740,7 +741,11 @@ def _pipeline_model(
740
741
) = _rewrite_model (
741
742
model = self ._model ,
742
743
context = context ,
743
- dist_stream = self ._data_dist_stream ,
744
+ dist_stream = (
745
+ self ._data_dist_stream
746
+ if custom_dist_stream is None
747
+ else custom_dist_stream
748
+ ),
744
749
default_stream = torch .get_device_module (self ._device ).current_stream (),
745
750
batch = batch ,
746
751
apply_jit = self ._apply_jit ,
@@ -768,6 +773,7 @@ def _init_pipelined_modules(
768
773
batch : In ,
769
774
context : TrainPipelineContext ,
770
775
pipelined_forward : Type [PipelinedForward ] = PipelinedForward ,
776
+ custom_dist_stream : Optional [torch .Stream ] = None ,
771
777
) -> None :
772
778
"""
773
779
Retrieves the pipelined modules after overriding their forwards, initializes the
@@ -779,7 +785,7 @@ def _init_pipelined_modules(
779
785
self .start_sparse_data_dist (batch , context )
780
786
return
781
787
782
- self ._pipeline_model (batch , context , pipelined_forward )
788
+ self ._pipeline_model (batch , context , pipelined_forward , custom_dist_stream )
783
789
784
790
def copy_batch_to_gpu (
785
791
self ,
@@ -1452,7 +1458,7 @@ def __init__(
1452
1458
else None
1453
1459
)
1454
1460
self ._default_stream : Optional [torch .Stream ] = (
1455
- (torch .get_device_module (self ._device ).Stream ())
1461
+ (torch .get_device_module (self ._device ).current_stream ())
1456
1462
if self ._device .type in ["cuda" , "mtia" ]
1457
1463
else None
1458
1464
)
@@ -1476,6 +1482,7 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
1476
1482
self ._context ,
1477
1483
# pyre-ignore
1478
1484
self ._pipelined_forward_type ,
1485
+ self ._prefetch_stream ,
1479
1486
)
1480
1487
self ._start_sparse_data_dist (self ._batch_i )
1481
1488
self ._wait_sparse_data_dist ()
0 commit comments