Skip to content

Commit 4293d32

Browse files
committed
fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
1 parent ac93bb4 commit 4293d32

File tree

5 files changed

+10
-5
lines changed

tests/pytorch/test_cpu_offloading.py

-1
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,6 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
119119
the difference being the size of the FP8 cache that is not offloaded to the CPU.
120120
We also expect this memory consumption to be smaller than in scenario (1).
121121
"""
122-
123122
model_cls = model_types[model_key]
124123
models_list = [model_cls() for _ in range(NUM_LAYERS)]
125124

transformer_engine/pytorch/cpu_offload.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ def mark_activation_offload(*tensors):
2222
if isinstance(tensor, torch.Tensor):
2323
tensor.activation_offloading = True
2424
else:
25-
data_tensors = tensor.get_data_tensors()
25+
data_tensors = tensor.get_data_tensors(scaling_factors=True)
2626
for tensor in data_tensors:
2727
if tensor is not None:
2828
tensor.activation_offloading = True

transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -112,8 +112,10 @@ def restore_from_saved(
112112
self._columnwise_scale_inv = tensors[3]
113113
return tensors[4:]
114114

115-
def get_data_tensors(self):
115+
def get_data_tensors(self, scaling_factors=False):
116116
"""Get this Tensor's data."""
117+
if scaling_factors:
118+
return self._rowwise_data, self._columnwise_data, self._rowwise_scale_inv, self._columnwise_scale_inv
117119
return self._rowwise_data, self._columnwise_data
118120

119121
def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:

transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -128,8 +128,10 @@ def restore_from_saved(
128128
self._scale_inv = tensors[2]
129129
return tensors[3:]
130130

131-
def get_data_tensors(self):
131+
def get_data_tensors(self, scaling_factors=False):
132132
"""Get this Tensor's data."""
133+
if scaling_factors:
134+
return self._data, self._transpose, self._scale_inv
133135
return self._data, self._transpose
134136

135137
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:

transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -131,8 +131,10 @@ def restore_from_saved(
131131
self._columnwise_scale_inv = tensors[3]
132132
return tensors[4:]
133133

134-
def get_data_tensors(self):
134+
def get_data_tensors(self, scaling_factors=False):
135135
"""Get this Tensor's data."""
136+
if scaling_factors:
137+
return self._rowwise_data, self._columnwise_data, self._rowwise_scale_inv, self._columnwise_scale_inv
136138
return self._rowwise_data, self._columnwise_data
137139

138140
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:

0 commit comments

Comments (0)