Skip to content

Commit bbf24cd

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1851c3a commit bbf24cd

File tree

4 files changed

+27
-12
lines changed

4 files changed

+27
-12
lines changed

transformer_engine/pytorch/cpu_offload.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,8 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
340340

341341
self.tensor_tag_to_state[tensor_tag] = tensor
342342

343-
if (
344-
self.current_group < self.num_offload_group
345-
and self.tensor_need_offloading_checker(tensor)
343+
if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
344+
tensor
346345
):
347346
self.tensor_tag_to_buf[tensor_tag] = tensor
348347
else:

transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,10 @@ def make_empty(
203203
columnwise_scale_inv = None
204204
if self.columnwise_usage:
205205
columnwise_data = torch.empty(
206-
self.get_columnwise_shape(shape), dtype=torch.uint8, device=device, pin_memory=pin_memory
206+
self.get_columnwise_shape(shape),
207+
dtype=torch.uint8,
208+
device=device,
209+
pin_memory=pin_memory,
207210
)
208211
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
209212
columnwise_scale_inv = torch.empty(

transformer_engine/pytorch/tensor/float8_tensor.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def make_empty(
9797
device: Optional[torch.device] = None,
9898
requires_grad: bool = False,
9999
pin_memory: bool = False,
100-
rowwise: bool = None,
100+
rowwise: bool = None,
101101
columnwise: bool = None,
102102
) -> Float8Tensor:
103103
rowwise = rowwise if rowwise is not None else self.rowwise_usage
@@ -624,6 +624,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
624624
dst, src = args[0], args[1]
625625
# Just copy FP8 attrs if copying between Float8Tensors
626626
if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor):
627+
627628
def copy_tensor(src, dst, tensor_name):
628629
src_is_none = getattr(src, tensor_name) is None
629630
dst_is_none = getattr(dst, tensor_name) is None
@@ -633,13 +634,14 @@ def copy_tensor(src, dst, tensor_name):
633634
setattr(dst, tensor_name, None)
634635
elif not src_is_none and not dst_is_none:
635636
getattr(src, tensor_name).copy_(getattr(dst, tensor_name))
637+
636638
copy_tensor(src, dst, "_data")
637639
copy_tensor(src, dst, "_scale_inv")
638640
if src._transpose is not None and dst._data is not None:
639641
dst._create_transpose()
640642
copy_tensor(src, dst, "_transpose")
641643
return
642-
644+
643645
elif func in _ops_to_preserve_subclass_in_fsdp2:
644646
# Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommended to return the same class instance to work fine with the torch fsdp2
645647
warnings.warn(

transformer_engine/pytorch/tensor/quantized_tensor.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -366,19 +366,30 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
366366
# View op
367367
if func == torch.ops.aten.view.default:
368368
raise NotImplementedError("{cls.__name__} class does not support tensor views")
369-
369+
370370
# Empty like op
371371
if func == torch.ops.aten.empty_like.default:
372372
tensor = args[0]
373-
quantizer = tensor._quantizer # TODO - pgadzinski look if this makes sense
373+
quantizer = tensor._quantizer # TODO - pgadzinski look if this makes sense
374374
dtype = kwargs.get("dtype", tensor.dtype)
375375
device = kwargs.get("device", tensor.device)
376376
shape = kwargs.get("shape", tensor.shape)
377377
pin_memory = kwargs.get("pin_memory", False)
378-
rowwise = (getattr(tensor, "_data", None) is not None) or (getattr(tensor, "_rowwise_data", None) is not None)
379-
columnwise = (getattr(tensor, "_transpose", None) is not None) or (getattr(tensor, "_columnwise_data", None) is not None)
380-
381-
return quantizer.make_empty(shape, dtype=dtype, device=device, pin_memory=pin_memory, rowwise=rowwise, columnwise=columnwise)
378+
rowwise = (getattr(tensor, "_data", None) is not None) or (
379+
getattr(tensor, "_rowwise_data", None) is not None
380+
)
381+
columnwise = (getattr(tensor, "_transpose", None) is not None) or (
382+
getattr(tensor, "_columnwise_data", None) is not None
383+
)
384+
385+
return quantizer.make_empty(
386+
shape,
387+
dtype=dtype,
388+
device=device,
389+
pin_memory=pin_memory,
390+
rowwise=rowwise,
391+
columnwise=columnwise,
392+
)
382393

383394
def maybe_unwrap(arg):
384395
if isinstance(arg, QuantizedTensor):

0 commit comments

Comments
 (0)