Skip to content

Commit 7a6b62d

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4293d32 commit 7a6b62d

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,12 @@ def restore_from_saved(
115115
def get_data_tensors(self, scaling_factors=False):
116116
"""Get this Tensor's data."""
117117
if scaling_factors:
118-
return self._rowwise_data, self._columnwise_data, self._rowwise_scale_inv, self._columnwise_scale_inv
118+
return (
119+
self._rowwise_data,
120+
self._columnwise_data,
121+
self._rowwise_scale_inv,
122+
self._columnwise_scale_inv,
123+
)
119124
return self._rowwise_data, self._columnwise_data
120125

121126
def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:

transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,12 @@ def restore_from_saved(
134134
def get_data_tensors(self, scaling_factors=False):
135135
"""Get this Tensor's data."""
136136
if scaling_factors:
137-
return self._rowwise_data, self._columnwise_data, self._rowwise_scale_inv, self._columnwise_scale_inv
137+
return (
138+
self._rowwise_data,
139+
self._columnwise_data,
140+
self._rowwise_scale_inv,
141+
self._columnwise_scale_inv,
142+
)
138143
return self._rowwise_data, self._columnwise_data
139144

140145
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:

0 commit comments

Comments
 (0)