From 1851c3a41da630b44c62cb5152bd3afff9d0ed7c Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 30 Apr 2025 13:47:46 +0000 Subject: [PATCH 01/14] init Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/cpu_offload.py | 121 +++--------------- .../pytorch/tensor/float8_blockwise_tensor.py | 7 +- .../pytorch/tensor/float8_tensor.py | 69 +++++++--- .../pytorch/tensor/mxfp8_tensor.py | 5 +- .../pytorch/tensor/quantized_tensor.py | 16 +++ 5 files changed, 91 insertions(+), 127 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 814e699557..dfc274134d 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -9,8 +9,6 @@ import torch -from .tensor.float8_tensor import Float8Tensor - __all__ = ["get_cpu_offload_context"] CPUOffloadEnabled = False @@ -21,17 +19,13 @@ def mark_activation_offload(*tensors): for tensor in tensors: if tensor is None: continue - if type(tensor) in [torch.Tensor, torch.nn.Parameter]: + if isinstance(tensor, torch.Tensor): tensor.activation_offloading = True else: data_tensors = tensor.get_data_tensors() for tensor in data_tensors: if tensor is not None: tensor.activation_offloading = True - # This is a hack to force clear the tensor after it is offloaded. - # It is needed, because .*TensorBase classes are saved in the ctx, - # and they contain the reference to their data tensors. - tensor.needs_force_clear = True def is_cpu_offload_enabled() -> bool: @@ -238,14 +232,11 @@ def on_group_commit_backward(self): def offload(src_tensor, pin_memory=True): """Offload.""" - cpu_backup = torch.empty( - src_tensor.size(), - dtype=src_tensor.dtype, - layout=src_tensor.layout, + cpu_backup = torch.empty_like( + src_tensor, device="cpu", pin_memory=pin_memory, ) - cpu_backup.copy_(src_tensor, non_blocking=pin_memory) state = (src_tensor.device, cpu_backup) return state @@ -309,9 +300,6 @@ def __init__( self.num_layers = num_model_group # Data Structure to maintain reference to activation tensors self.tensor_tag_to_buf = {} - # Data structure to hold the FP8/MXFP8 tensor objects - self.fp8_tensor_object_map = {} - self.float8_transpose_cache_valid = {} # Tracking the number of layers offloaded self.offloaded_group_count = 0 # Core data structure that decides the window for offloading @@ -342,8 +330,6 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: ), ) - is_quantized_tensor = callable(getattr(tensor, "prepare_for_saving", None)) - if not torch_stray_tensor: # obtain a unique tensor tag @@ -352,36 +338,13 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: assert tensor_tag not in self.tensor_tag_to_state - if is_quantized_tensor: - tensor_list, _ = tensor.prepare_for_saving() - - self.tensor_tag_to_state[tensor_tag] = [] - self.tensor_tag_to_buf[tensor_tag] = [] + self.tensor_tag_to_state[tensor_tag] = tensor - self.fp8_tensor_object_map[tensor_tag] = tensor - if isinstance(tensor, Float8Tensor): - self.float8_transpose_cache_valid[tensor_tag] = getattr( - tensor, "_transpose_invalid" - ) - else: - tensor_list = [tensor] - - for t in tensor_list: - if is_quantized_tensor: - self.tensor_tag_to_state[tensor_tag].append(t) - else: - self.tensor_tag_to_state[tensor_tag] = t - - if ( - self.current_group < self.num_offload_group - and self.tensor_need_offloading_checker(t) - ): - if is_quantized_tensor: - self.tensor_tag_to_buf[tensor_tag].append(t) - # Need to clear the internal data reference for the quantized tensors - 
tensor.clear() - else: - self.tensor_tag_to_buf[tensor_tag] = t + if ( + self.current_group < self.num_offload_group + and self.tensor_need_offloading_checker(tensor) + ): + self.tensor_tag_to_buf[tensor_tag] = tensor else: tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 @@ -393,12 +356,6 @@ def tensor_pop(self, tensor_tag, **kwargs): """Tensor pop.""" assert tensor_tag in self.tensor_tag_to_state tensor = self.tensor_tag_to_state.pop(tensor_tag) - - # Handling the quantized tensor case specially here - if isinstance(tensor, list): - self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) - tensor = self.fp8_tensor_object_map.pop(tensor_tag) - self.tensor_tag_to_buf.pop(tensor_tag, None) # the tensor should have been copied back in on_group_commit_backward() @@ -414,36 +371,11 @@ def bulk_offload_group(self, group_to_offload): if group_id == group_to_offload: assert not isinstance(state, tuple) - is_quantized_tensor = isinstance(state, list) - - if is_quantized_tensor: - tensor_list = state - self.tensor_tag_to_state[tensor_tag] = [] - else: - tensor_list = [state] - - for tensor_on_device in tensor_list: - # `tensor_offloaded` is a hacky way of dealing with columnwise-only - # quantized tensors for CPU offloading. The complication is due to - # the `rowwise_data` being `None`. The offloading checker incorrectly - # returns `False` and the entire `state` ([None, columnwise_tensor]) - # is added to the tensor tag state dict. A better design would change - # how quantized tensors are kept track of in the offload handler. - # Currently at every stage it is ensured that a quantized tensor is a - # list whereas a non-quantized tensor is standalone object, which is - # not good! TODO(@sanandaraj5597) - tensor_offloaded = False - # if offload, return the reference to cpu copy - if self.tensor_need_offloading_checker(tensor_on_device): - tensor_offloaded = True - state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) - if is_quantized_tensor: - if tensor_offloaded: - self.tensor_tag_to_state[tensor_tag].append(state) - else: - self.tensor_tag_to_state[tensor_tag].append(tensor_on_device) - else: - self.tensor_tag_to_state[tensor_tag] = state + tensor_on_device = state + # if offload, return the reference to cpu copy + if self.tensor_need_offloading_checker(tensor_on_device): + state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) + self.tensor_tag_to_state[tensor_tag] = state def synchronize_on_group_commit_forward(self, current_group): """Synchronize on group commit forward.""" @@ -463,14 +395,8 @@ def synchronize_on_group_commit_forward(self, current_group): torch.cuda.current_stream().wait_stream(self.d2h_stream) # Time to free the activation memory after usage - for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items(): + for tensor_tag, _ in self.tensor_tag_to_buf.items(): if tensor_tag[0] == self.offloaded_group_count: - if hasattr(tensor_buf, "needs_force_clear"): - # Need to clear activation tensor - sometimes references persist in the code. - # This is the case for example with the Float8TensorBase class, - # which is saved directly inside the ctx while its internal tensors are - # saved inside save_for_backward. 
- tensor_buf.data = torch.Tensor() # Release the pointer to the tensor self.tensor_tag_to_buf[tensor_tag] = None @@ -500,23 +426,6 @@ def bulk_reload_group(self, group_to_reload): if isinstance(state, tuple): recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) self.tensor_tag_to_state[tensor_label] = recovered_tensor - elif isinstance(state, list): - tensor_list = [] - for state_tuple in state: - if isinstance(state_tuple, tuple): - tensor_list.append( - SynchronizedGroupOffloadHandler.reload(state_tuple) - ) - else: - tensor_list.append(state_tuple) - _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(tensor_list) - if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): - self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( - self.float8_transpose_cache_valid.pop(tensor_label) - ) - self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( - tensor_label - ) def on_group_commit_backward(self): # first decrement the current group. diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 7e101b2612..fcc0f51bc6 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -179,6 +179,7 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> Float8BlockwiseQTensor: """Construct quantized tensor with uninitialized data""" if device is None: @@ -188,12 +189,13 @@ def make_empty( data = None scale_inv = None if self.rowwise_usage: - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) scale_shape = self.get_scale_shape(shape, columnwise=False) scale_inv = torch.empty( scale_shape, dtype=torch.float32, device=device, + pin_memory=pin_memory, ) # Allocate FP8 data transpose if needed @@ -201,13 +203,14 @@ def make_empty( columnwise_scale_inv = None if self.columnwise_usage: columnwise_data = torch.empty( - self.get_columnwise_shape(shape), dtype=torch.uint8, device=device + self.get_columnwise_shape(shape), dtype=torch.uint8, device=device, pin_memory=pin_memory ) columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) columnwise_scale_inv = torch.empty( columnwise_scale_shape, dtype=torch.float32, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index a37eb4f632..63463e2a86 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -6,6 +6,7 @@ from __future__ import annotations from typing import Optional, Tuple, Iterable import warnings +import math import torch import transformer_engine_torch as tex @@ -95,24 +96,32 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, + rowwise: bool = None, + columnwise: bool = None, ) -> Float8Tensor: + rowwise = rowwise if rowwise is not None else self.rowwise_usage + columnwise = columnwise if columnwise is not None else self.columnwise_usage # Canonicalize tensor attributes if device is None: device = torch.device("cuda") # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = None + if rowwise: + 
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) + + transpose_shape = shape[-1:] + shape[:-1] # Allocate FP8 data transpose if needed data_transpose = None - if self.columnwise_usage: - inner_dim = data.size(-1) + if columnwise: data_transpose = torch.empty( - inner_dim, - data.numel() // inner_dim, + transpose_shape, dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor @@ -120,7 +129,7 @@ def make_empty( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_scale_inv=torch.empty((), dtype=torch.float32, device=device), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -250,24 +259,37 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, + rowwise: bool = None, + columnwise: bool = None, ) -> Float8Tensor: + rowwise = rowwise if rowwise is not None else self.rowwise_usage + columnwise = columnwise if columnwise is not None else self.columnwise_usage # Canonicalize tensor attributes if device is None: device = torch.device("cuda") # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = None + if rowwise: + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) + + transpose_shape = None + if columnwise: + if len(shape) >= 2: + transpose_shape = shape[-1:] + shape[:-1] + else: + transpose_shape = shape # Allocate FP8 data transpose if needed data_transpose = None - if self.columnwise_usage: - inner_dim = data.size(-1) + if columnwise: data_transpose = torch.empty( - inner_dim, - data.numel() // inner_dim, + transpose_shape, dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor @@ -275,7 +297,7 @@ def make_empty( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_scale_inv=torch.empty((), dtype=torch.float32, device=device), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -307,7 +329,7 @@ def create_tensor_from_data( if internal: return Float8TensorBase( data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), + fp8_scale_inv=torch.empty((), dtype=torch.float32, device=data.device), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=None, @@ -317,7 +339,7 @@ def create_tensor_from_data( shape=data.shape, dtype=fake_dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), + fp8_scale_inv=torch.empty((), dtype=torch.float32, device=data.device), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=None, @@ -602,11 +624,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst, src = args[0], args[1] # Just copy FP8 attrs if copying between Float8Tensors if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor): - dst._data.copy_(src._data.detach()) - dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size())) - if src._transpose is not None or dst._transpose is not None: + def copy_tensor(src, dst, tensor_name): + src_is_none = getattr(src, tensor_name) is None + dst_is_none = getattr(dst, tensor_name) is None + if src_is_none and dst_is_none: + return + elif src_is_none and not dst_is_none: + setattr(dst, tensor_name, None) + elif not src_is_none and not dst_is_none: + getattr(src, tensor_name).copy_(getattr(dst, tensor_name)) + 
copy_tensor(src, dst, "_data") + copy_tensor(src, dst, "_scale_inv") + if src._transpose is not None and dst._data is not None: dst._create_transpose() - return dst + copy_tensor(src, dst, "_transpose") + return + elif func in _ops_to_preserve_subclass_in_fsdp2: # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 warnings.warn( diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d2124f8e1e..6f37110f80 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -83,6 +83,7 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> MXFP8Tensor: # Canonicalize tensor attributes @@ -98,12 +99,13 @@ def make_empty( ) # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) # Allocate FP8 data transpose if needed @@ -116,6 +118,7 @@ def make_empty( round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index aa433e58bc..4f50b1a50f 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -148,6 +148,9 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, + pin_memory: bool = False, + rowwise: bool = True, + columnwise: bool = True, ) -> QuantizedTensor: """Construct quantized tensor with uninitialized data""" @@ -363,6 +366,19 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # View op if func == torch.ops.aten.view.default: raise NotImplementedError("{cls.__name__} class does not support tensor views") + + # Empty like op + if func == torch.ops.aten.empty_like.default: + tensor = args[0] + quantizer = tensor._quantizer # TODO - pgadzinski look if this makes sense + dtype = kwargs.get("dtype", tensor.dtype) + device = kwargs.get("device", tensor.device) + shape = kwargs.get("shape", tensor.shape) + pin_memory = kwargs.get("pin_memory", False) + rowwise = (getattr(tensor, "_data", None) is not None) or (getattr(tensor, "_rowwise_data", None) is not None) + columnwise = (getattr(tensor, "_transpose", None) is not None) or (getattr(tensor, "_columnwise_data", None) is not None) + + return quantizer.make_empty(shape, dtype=dtype, device=device, pin_memory=pin_memory, rowwise=rowwise, columnwise=columnwise) def maybe_unwrap(arg): if isinstance(arg, QuantizedTensor): From 657cbbe3b387379856e560dc7ba6431f070da387 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 30 Apr 2025 14:57:43 +0000 Subject: [PATCH 02/14] offloading Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_cpu_offloading.py | 2 +- transformer_engine/pytorch/cpu_offload.py | 1 - .../pytorch/tensor/float8_blockwise_tensor.py | 33 ++++++++++--- .../pytorch/tensor/float8_tensor.py | 48 +++++++++++-------- .../pytorch/tensor/mxfp8_tensor.py | 30 ++++++++++-- 
.../pytorch/tensor/quantized_tensor.py | 19 ++++---- 6 files changed, 90 insertions(+), 43 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index ab4b7634b8..f8d6a11826 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -17,7 +17,7 @@ fp8_recipes = [ None, # non-fp8 - # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet + recipe.MXFP8BlockScaling(), recipe.Float8CurrentScaling(), recipe.DelayedScaling(), ] diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index dfc274134d..956c848d91 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -331,7 +331,6 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: ) if not torch_stray_tensor: - # obtain a unique tensor tag tensor_tag = (self.current_group, self.tensor_count_current_group) self.tensor_count_current_group += 1 diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index fcc0f51bc6..db07dc6f85 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -178,8 +178,7 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - requires_grad: bool = False, - pin_memory: bool = False, + requires_grad: bool = False ) -> Float8BlockwiseQTensor: """Construct quantized tensor with uninitialized data""" if device is None: @@ -189,13 +188,12 @@ def make_empty( data = None scale_inv = None if self.rowwise_usage: - data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) + data = torch.empty(shape, dtype=torch.uint8, device=device) scale_shape = self.get_scale_shape(shape, columnwise=False) scale_inv = torch.empty( scale_shape, dtype=torch.float32, device=device, - pin_memory=pin_memory, ) # Allocate FP8 data transpose if needed @@ -203,14 +201,13 @@ def make_empty( columnwise_scale_inv = None if self.columnwise_usage: columnwise_data = torch.empty( - self.get_columnwise_shape(shape), dtype=torch.uint8, device=device, pin_memory=pin_memory + self.get_columnwise_shape(shape), dtype=torch.uint8, device=device ) columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) columnwise_scale_inv = torch.empty( columnwise_scale_shape, dtype=torch.float32, device=device, - pin_memory=pin_memory, ) # Construct FP8 tensor @@ -380,6 +377,30 @@ def clone(self) -> Float8BlockwiseQTensor: }, ) + def empty_like(self, *args, **kwargs): + """Create a new empty tensor with the same shape and type as this tensor""" + new_rowwise_data = torch.empty_like(self._rowwise_data, *args, **kwargs) \ + if self._rowwise_data is not None else None + new_columnwise_data = torch.empty_like(self._columnwise_data, *args, **kwargs) \ + if self._columnwise_data is not None else None + rowwise_scale_inv = torch.empty_like(self._rowwise_scale_inv, *args, **kwargs) \ + if self._rowwise_scale_inv is not None else None + new_columnwise_scale_inv = torch.empty_like(self._columnwise_scale_inv, *args, **kwargs) \ + if self._columnwise_scale_inv is not None else None + + return Float8BlockwiseQTensor( + shape=self.shape, + dtype=self.dtype, + fp8_dtype=self._fp8_dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=new_columnwise_data, + 
columnwise_scale_inv=new_columnwise_scale_inv, + quantizer=self._quantizer, + is_2D_scaled=self._is_2D_scaled, + requires_grad=self.requires_grad, + ) + def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring return _ViewFunc.apply(self, shape) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 63463e2a86..a0ab3f5d62 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -95,13 +95,8 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - requires_grad: bool = False, - pin_memory: bool = False, - rowwise: bool = None, - columnwise: bool = None, + requires_grad: bool = False ) -> Float8Tensor: - rowwise = rowwise if rowwise is not None else self.rowwise_usage - columnwise = columnwise if columnwise is not None else self.columnwise_usage # Canonicalize tensor attributes if device is None: @@ -109,19 +104,18 @@ def make_empty( # Allocate FP8 data data = None - if rowwise: - data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device) transpose_shape = shape[-1:] + shape[:-1] # Allocate FP8 data transpose if needed data_transpose = None - if columnwise: + if self.columnwise_usage: data_transpose = torch.empty( transpose_shape, dtype=torch.uint8, device=device, - pin_memory=pin_memory, ) # Construct FP8 tensor @@ -258,13 +252,9 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - requires_grad: bool = False, - pin_memory: bool = False, - rowwise: bool = None, - columnwise: bool = None, + requires_grad: bool = False ) -> Float8Tensor: - rowwise = rowwise if rowwise is not None else self.rowwise_usage - columnwise = columnwise if columnwise is not None else self.columnwise_usage + # Canonicalize tensor attributes if device is None: @@ -272,11 +262,11 @@ def make_empty( # Allocate FP8 data data = None - if rowwise: - data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device) transpose_shape = None - if columnwise: + if self.columnwise_usage: if len(shape) >= 2: transpose_shape = shape[-1:] + shape[:-1] else: @@ -284,12 +274,11 @@ def make_empty( # Allocate FP8 data transpose if needed data_transpose = None - if columnwise: + if self.columnwise_usage: data_transpose = torch.empty( transpose_shape, dtype=torch.uint8, device=device, - pin_memory=pin_memory, ) # Construct FP8 tensor @@ -496,6 +485,23 @@ def clone(self) -> Float8Tensor: }, ) + def empty_like(self, *args, **kwargs): + """Create a new empty tensor with the same shape and type as this tensor""" + new_data = torch.empty_like(self._data, *args, **kwargs) \ + if self._data is not None else None + new_transpose = torch.empty_like(self._transpose, *args, **kwargs) \ + if self._transpose is not None else None + new_scale_inv = torch.empty_like(self._scale_inv, *args, **kwargs) + return Float8Tensor( + shape=self.shape, + dtype=self.dtype, + data=new_data, + fp8_scale_inv=new_scale_inv, + fp8_dtype=self._fp8_dtype, + data_transpose=new_transpose, + quantizer=self._quantizer, + ) + def view(self, *shape: Tuple[int]) -> Float8Tensor: # pylint: disable=missing-function-docstring return _ViewFunc.apply(self, shape) diff --git 
a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 6f37110f80..e991b589ef 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -82,8 +82,7 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - requires_grad: bool = False, - pin_memory: bool = False, + requires_grad: bool = False ) -> MXFP8Tensor: # Canonicalize tensor attributes @@ -99,13 +98,12 @@ def make_empty( ) # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) + data = torch.empty(shape, dtype=torch.uint8, device=device) scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), dtype=torch.uint8, device=device, - pin_memory=pin_memory, ) # Allocate FP8 data transpose if needed @@ -118,7 +116,6 @@ def make_empty( round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, device=device, - pin_memory=pin_memory, ) # Construct FP8 tensor @@ -279,6 +276,29 @@ def clone(self) -> MXFP8Tensor: "columnwise_data": columnwise_data, }, ) + + + def empty_like(self, *args, **kwargs): + """Create a new empty tensor with the same shape and type as this tensor""" + new_rowwise_data = torch.empty_like(self._rowwise_data, *args, **kwargs) \ + if self._rowwise_data is not None else None + new_columnwise_data = torch.empty_like(self._columnwise_data, *args, **kwargs) \ + if self._columnwise_data is not None else None + new_rowwise_scale_inv = torch.empty_like(self._rowwise_scale_inv, *args, **kwargs) \ + if self._rowwise_scale_inv is not None else None + new_columnwise_scale_inv = torch.empty_like(self._columnwise_scale_inv, *args, **kwargs) \ + if self._columnwise_scale_inv is not None else None + return MXFP8Tensor( + shape=self.shape, + dtype=self.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=new_rowwise_scale_inv, + fp8_dtype=self._fp8_dtype, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=new_columnwise_scale_inv, + quantizer=self._quantizer, + ) + def view(self, *shape: Tuple[int]) -> MXFP8Tensor: # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 4f50b1a50f..b7722d983e 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -314,6 +314,15 @@ def update_usage( def clear(self): """Deallocate this tensor's memory. 
Typically not needed and must be used carefully""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement clear function" + ) + + def empty_like(self, *args, **kwargs): + """Create a new empty tensor with the same shape and type as this tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement empty_like function" + ) def __repr__(self, *, tensor_contents=None) -> str: return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" @@ -370,15 +379,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Empty like op if func == torch.ops.aten.empty_like.default: tensor = args[0] - quantizer = tensor._quantizer # TODO - pgadzinski look if this makes sense - dtype = kwargs.get("dtype", tensor.dtype) - device = kwargs.get("device", tensor.device) - shape = kwargs.get("shape", tensor.shape) - pin_memory = kwargs.get("pin_memory", False) - rowwise = (getattr(tensor, "_data", None) is not None) or (getattr(tensor, "_rowwise_data", None) is not None) - columnwise = (getattr(tensor, "_transpose", None) is not None) or (getattr(tensor, "_columnwise_data", None) is not None) - - return quantizer.make_empty(shape, dtype=dtype, device=device, pin_memory=pin_memory, rowwise=rowwise, columnwise=columnwise) + return tensor.empty_like(*args[1:], **kwargs) def maybe_unwrap(arg): if isinstance(arg, QuantizedTensor): From 44f6494edbbb18733afb7279e77b53b4a17d8035 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 30 Apr 2025 15:00:55 +0000 Subject: [PATCH 03/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/cpu_offload.py | 5 ++- .../pytorch/tensor/float8_blockwise_tensor.py | 32 +++++++++++++------ .../pytorch/tensor/float8_tensor.py | 19 ++++++----- .../pytorch/tensor/mxfp8_tensor.py | 32 ++++++++++++------- .../pytorch/tensor/quantized_tensor.py | 2 +- 5 files changed, 57 insertions(+), 33 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 956c848d91..d470e370d6 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -339,9 +339,8 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: self.tensor_tag_to_state[tensor_tag] = tensor - if ( - self.current_group < self.num_offload_group - and self.tensor_need_offloading_checker(tensor) + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( + tensor ): self.tensor_tag_to_buf[tensor_tag] = tensor else: diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index db07dc6f85..89099f57a0 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -178,7 +178,7 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - requires_grad: bool = False + requires_grad: bool = False, ) -> Float8BlockwiseQTensor: """Construct quantized tensor with uninitialized data""" if device is None: @@ -379,15 +379,27 @@ def clone(self) -> Float8BlockwiseQTensor: def empty_like(self, *args, **kwargs): """Create a new empty tensor with the same shape and type as this tensor""" - new_rowwise_data = torch.empty_like(self._rowwise_data, *args, **kwargs) \ - if self._rowwise_data is not None 
else None - new_columnwise_data = torch.empty_like(self._columnwise_data, *args, **kwargs) \ - if self._columnwise_data is not None else None - rowwise_scale_inv = torch.empty_like(self._rowwise_scale_inv, *args, **kwargs) \ - if self._rowwise_scale_inv is not None else None - new_columnwise_scale_inv = torch.empty_like(self._columnwise_scale_inv, *args, **kwargs) \ - if self._columnwise_scale_inv is not None else None - + new_rowwise_data = ( + torch.empty_like(self._rowwise_data, *args, **kwargs) + if self._rowwise_data is not None + else None + ) + new_columnwise_data = ( + torch.empty_like(self._columnwise_data, *args, **kwargs) + if self._columnwise_data is not None + else None + ) + rowwise_scale_inv = ( + torch.empty_like(self._rowwise_scale_inv, *args, **kwargs) + if self._rowwise_scale_inv is not None + else None + ) + new_columnwise_scale_inv = ( + torch.empty_like(self._columnwise_scale_inv, *args, **kwargs) + if self._columnwise_scale_inv is not None + else None + ) + return Float8BlockwiseQTensor( shape=self.shape, dtype=self.dtype, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index a0ab3f5d62..fe31da5777 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -95,7 +95,7 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - requires_grad: bool = False + requires_grad: bool = False, ) -> Float8Tensor: # Canonicalize tensor attributes @@ -252,10 +252,9 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - requires_grad: bool = False + requires_grad: bool = False, ) -> Float8Tensor: - # Canonicalize tensor attributes if device is None: device = torch.device("cuda") @@ -487,10 +486,12 @@ def clone(self) -> Float8Tensor: def empty_like(self, *args, **kwargs): """Create a new empty tensor with the same shape and type as this tensor""" - new_data = torch.empty_like(self._data, *args, **kwargs) \ - if self._data is not None else None - new_transpose = torch.empty_like(self._transpose, *args, **kwargs) \ - if self._transpose is not None else None + new_data = torch.empty_like(self._data, *args, **kwargs) if self._data is not None else None + new_transpose = ( + torch.empty_like(self._transpose, *args, **kwargs) + if self._transpose is not None + else None + ) new_scale_inv = torch.empty_like(self._scale_inv, *args, **kwargs) return Float8Tensor( shape=self.shape, @@ -630,6 +631,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst, src = args[0], args[1] # Just copy FP8 attrs if copying between Float8Tensors if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor): + def copy_tensor(src, dst, tensor_name): src_is_none = getattr(src, tensor_name) is None dst_is_none = getattr(dst, tensor_name) is None @@ -639,13 +641,14 @@ def copy_tensor(src, dst, tensor_name): setattr(dst, tensor_name, None) elif not src_is_none and not dst_is_none: getattr(src, tensor_name).copy_(getattr(dst, tensor_name)) + copy_tensor(src, dst, "_data") copy_tensor(src, dst, "_scale_inv") if src._transpose is not None and dst._data is not None: dst._create_transpose() copy_tensor(src, dst, "_transpose") return - + elif func in _ops_to_preserve_subclass_in_fsdp2: # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 warnings.warn( diff --git 
a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index e991b589ef..08f7952c31 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -82,7 +82,7 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - requires_grad: bool = False + requires_grad: bool = False, ) -> MXFP8Tensor: # Canonicalize tensor attributes @@ -276,18 +276,29 @@ def clone(self) -> MXFP8Tensor: "columnwise_data": columnwise_data, }, ) - def empty_like(self, *args, **kwargs): """Create a new empty tensor with the same shape and type as this tensor""" - new_rowwise_data = torch.empty_like(self._rowwise_data, *args, **kwargs) \ - if self._rowwise_data is not None else None - new_columnwise_data = torch.empty_like(self._columnwise_data, *args, **kwargs) \ - if self._columnwise_data is not None else None - new_rowwise_scale_inv = torch.empty_like(self._rowwise_scale_inv, *args, **kwargs) \ - if self._rowwise_scale_inv is not None else None - new_columnwise_scale_inv = torch.empty_like(self._columnwise_scale_inv, *args, **kwargs) \ - if self._columnwise_scale_inv is not None else None + new_rowwise_data = ( + torch.empty_like(self._rowwise_data, *args, **kwargs) + if self._rowwise_data is not None + else None + ) + new_columnwise_data = ( + torch.empty_like(self._columnwise_data, *args, **kwargs) + if self._columnwise_data is not None + else None + ) + new_rowwise_scale_inv = ( + torch.empty_like(self._rowwise_scale_inv, *args, **kwargs) + if self._rowwise_scale_inv is not None + else None + ) + new_columnwise_scale_inv = ( + torch.empty_like(self._columnwise_scale_inv, *args, **kwargs) + if self._columnwise_scale_inv is not None + else None + ) return MXFP8Tensor( shape=self.shape, dtype=self.dtype, @@ -299,7 +310,6 @@ def empty_like(self, *args, **kwargs): quantizer=self._quantizer, ) - def view(self, *shape: Tuple[int]) -> MXFP8Tensor: # pylint: disable=missing-function-docstring return _ViewFunc.apply(self, shape) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index b7722d983e..8b02f861af 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -375,7 +375,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # View op if func == torch.ops.aten.view.default: raise NotImplementedError("{cls.__name__} class does not support tensor views") - + # Empty like op if func == torch.ops.aten.empty_like.default: tensor = args[0] From 728b174d63c24272a3e1ddc60f196cfae2507ec4 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 30 Apr 2025 15:02:56 +0000 Subject: [PATCH 04/14] fixes Signed-off-by: Pawel Gadzinski --- .../pytorch/tensor/float8_tensor.py | 34 +++++++------------ .../pytorch/tensor/quantized_tensor.py | 3 -- 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index fe31da5777..8ec99ae52d 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -103,17 +103,17 @@ def make_empty( device = torch.device("cuda") # Allocate FP8 data - data = None - if self.rowwise_usage: - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = torch.empty(shape, dtype=torch.uint8, device=device) transpose_shape = shape[-1:] + 
shape[:-1] # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: + inner_dim = data.size(-1) data_transpose = torch.empty( - transpose_shape, + inner_dim, + data.numel() // inner_dim, dtype=torch.uint8, device=device, ) @@ -123,7 +123,7 @@ def make_empty( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty((), dtype=torch.float32, device=device), + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -260,22 +260,15 @@ def make_empty( device = torch.device("cuda") # Allocate FP8 data - data = None - if self.rowwise_usage: - data = torch.empty(shape, dtype=torch.uint8, device=device) - - transpose_shape = None - if self.columnwise_usage: - if len(shape) >= 2: - transpose_shape = shape[-1:] + shape[:-1] - else: - transpose_shape = shape + data = torch.empty(shape, dtype=torch.uint8, device=device) # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: + inner_dim = data.size(-1) data_transpose = torch.empty( - transpose_shape, + inner_dim, + data.numel() // inner_dim, dtype=torch.uint8, device=device, ) @@ -285,7 +278,7 @@ def make_empty( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty((), dtype=torch.float32, device=device), + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -317,7 +310,7 @@ def create_tensor_from_data( if internal: return Float8TensorBase( data=data, - fp8_scale_inv=torch.empty((), dtype=torch.float32, device=data.device), + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=None, @@ -327,7 +320,7 @@ def create_tensor_from_data( shape=data.shape, dtype=fake_dtype, data=data, - fp8_scale_inv=torch.empty((), dtype=torch.float32, device=data.device), + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=None, @@ -647,8 +640,7 @@ def copy_tensor(src, dst, tensor_name): if src._transpose is not None and dst._data is not None: dst._create_transpose() copy_tensor(src, dst, "_transpose") - return - + return dst elif func in _ops_to_preserve_subclass_in_fsdp2: # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 warnings.warn( diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 8b02f861af..c80c4c01b3 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -148,9 +148,6 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - pin_memory: bool = False, - rowwise: bool = True, - columnwise: bool = True, ) -> QuantizedTensor: """Construct quantized tensor with uninitialized data""" From 57e78698c174e636b5831b5c08b8a79b1b53cfb1 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 30 Apr 2025 15:30:53 +0000 Subject: [PATCH 05/14] all types Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_cpu_offloading.py | 4 +++ .../pytorch/tensor/float8_blockwise_tensor.py | 13 +++++++++ .../pytorch/tensor/float8_tensor.py | 28 ++++++------------- .../pytorch/tensor/mxfp8_tensor.py | 14 ++++++++++ 4 files changed, 40 insertions(+), 19 deletions(-) diff --git 
a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index f8d6a11826..64d4847674 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -14,12 +14,14 @@ # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_available, reason_for_no_fp8_block = FP8GlobalStateManager.is_fp8_block_scaling_available() fp8_recipes = [ None, # non-fp8 recipe.MXFP8BlockScaling(), recipe.Float8CurrentScaling(), recipe.DelayedScaling(), + recipe.Float8BlockScaling(), ] SIZE = 512 @@ -124,6 +126,8 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: if fp8_recipe is not None: if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if fp8_recipe.float8_block_scaling() and not fp8_block_available: + pytest.skip(reason_for_no_fp8_block) without_offloading = _measure_memory_between_forward_and_backward( models_list, fp8_recipe, False diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 89099f57a0..129448c028 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -444,6 +444,19 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): " (scales and columnwise data untouched)." ) return Float8BlockwiseQTensor.make_like(tensor) + + if func == torch.ops.aten.copy_.default: + dst, src = args[0], args[1] + if isinstance(src, Float8BlockwiseQTensor) and isinstance(dst, Float8BlockwiseQTensor): + if dst._rowwise_data is not None: + dst._rowwise_data.copy_(src._rowwise_data) + if dst._rowwise_scale_inv is not None: + dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv) + if dst._columnwise_data is not None: + dst._columnwise_data.copy_(src._columnwise_data) + if dst._columnwise_scale_inv is not None: + dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv) + return dst # Default case return super().__torch_dispatch__(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 8ec99ae52d..dab318e3b1 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -6,7 +6,6 @@ from __future__ import annotations from typing import Optional, Tuple, Iterable import warnings -import math import torch import transformer_engine_torch as tex @@ -105,8 +104,6 @@ def make_empty( # Allocate FP8 data data = torch.empty(shape, dtype=torch.uint8, device=device) - transpose_shape = shape[-1:] + shape[:-1] - # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: @@ -624,22 +621,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst, src = args[0], args[1] # Just copy FP8 attrs if copying between Float8Tensors if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor): - - def copy_tensor(src, dst, tensor_name): - src_is_none = getattr(src, tensor_name) is None - dst_is_none = getattr(dst, tensor_name) is None - if src_is_none and dst_is_none: - return - elif src_is_none and not dst_is_none: - setattr(dst, tensor_name, None) - elif not src_is_none and not dst_is_none: - getattr(src, tensor_name).copy_(getattr(dst, tensor_name)) - - copy_tensor(src, dst, "_data") - copy_tensor(src, dst, "_scale_inv") - if src._transpose is 
not None and dst._data is not None: - dst._create_transpose() - copy_tensor(src, dst, "_transpose") + if dst._data is not None: + dst._data.copy_(src._data) + if dst._scale_inv is not None: + dst._scale_inv.copy_(src._scale_inv) + if dst._transpose is not None and not dst._transpose_invalid: + if not src._transpose_invalid: + dst._transpose.copy_(src._transpose) + else: + dst._create_transpose() return dst elif func in _ops_to_preserve_subclass_in_fsdp2: # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 08f7952c31..bf120f1de7 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -363,6 +363,20 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_dtype=tensor._fp8_dtype, ) + if func == torch.ops.aten.copy_.default: + dst, src = args[0], args[1] + # Just copy FP8 attrs if copying between Float8Tensors + if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor): + if dst._rowwise_data is not None: + dst._rowwise_data.copy_(src._rowwise_data) + if dst._rowwise_scale_inv is not None: + dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv) + if dst._columnwise_data is not None: + dst._columnwise_data.copy_(src._columnwise_data) + if dst._columnwise_scale_inv is not None: + dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv) + return dst + # Default case return super().__torch_dispatch__(func, types, args, kwargs) From ad68f9192996711fe7b24f1dbeec9c876bd57734 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 30 Apr 2025 15:31:31 +0000 Subject: [PATCH 06/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_cpu_offloading.py | 4 +++- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 64d4847674..7764511374 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -14,7 +14,9 @@ # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_available, reason_for_no_fp8_block = FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_block_available, reason_for_no_fp8_block = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) fp8_recipes = [ None, # non-fp8 diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 129448c028..d9f76f0508 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -444,7 +444,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): " (scales and columnwise data untouched)." 
) return Float8BlockwiseQTensor.make_like(tensor) - + if func == torch.ops.aten.copy_.default: dst, src = args[0], args[1] if isinstance(src, Float8BlockwiseQTensor) and isinstance(dst, Float8BlockwiseQTensor): From 001b23cc66ab6a0f6a8644454681bc8ed6714787 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 30 Apr 2025 16:43:14 +0000 Subject: [PATCH 07/14] typo Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index d9f76f0508..455d34d33c 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -389,7 +389,7 @@ def empty_like(self, *args, **kwargs): if self._columnwise_data is not None else None ) - rowwise_scale_inv = ( + new_rowwise_scale_inv = ( torch.empty_like(self._rowwise_scale_inv, *args, **kwargs) if self._rowwise_scale_inv is not None else None @@ -405,7 +405,7 @@ def empty_like(self, *args, **kwargs): dtype=self.dtype, fp8_dtype=self._fp8_dtype, rowwise_data=new_rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, + rowwise_scale_inv=new_rowwise_scale_inv, columnwise_data=new_columnwise_data, columnwise_scale_inv=new_columnwise_scale_inv, quantizer=self._quantizer, From 8b00c6f7234932005e72e7b212204f9e3c5ce252 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 5 May 2025 10:19:19 +0000 Subject: [PATCH 08/14] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/optimizers/fused_adam.py | 2 +- .../pytorch/tensor/float8_blockwise_tensor.py | 7 +++++++ transformer_engine/pytorch/tensor/float8_tensor.py | 7 +++++++ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 8 +++++++- 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 18f7e2031a..3ec24f98de 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -374,7 +374,7 @@ def _initialize_state( if store_param_remainders: data = torch.zeros_like(param, dtype=torch.int16) else: - data = torch.empty_like(param, dtype=dtype) + data = torch.empty_like(param.dequantize(), dtype=dtype) if zero_buffer: data.zero_() diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 455d34d33c..21d282d9b2 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -457,6 +457,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if dst._columnwise_scale_inv is not None: dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv) return dst + elif func == torch.ops.aten.is_pinned.default: + if args[0]._rowwise_data is not None: + return args[0]._rowwise_data.is_pinned() + elif args[0]._columnwise_data is not None: + return args[0]._columnwise_data.is_pinned() + else: + raise RuntimeError("Cannot check if pinned for Float8BlockwiseQTensor with no data.") # Default case return super().__torch_dispatch__(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index dab318e3b1..af144fcc9c 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ 
b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -617,6 +617,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return cls.detach(args[0]) if func == torch.ops.aten.clone.default: return cls.clone(args[0]) + if func == torch.ops.aten.is_pinned.default: + if args[0]._data is not None: + return args[0]._data.is_pinned() + elif args[0]._transpose is not None: + return args[0]._transpose.is_pinned() + else: + raise RuntimeError("Cannot check if pinned for Float8Tensor with no data and no transpose.") if func == torch.ops.aten.copy_.default: dst, src = args[0], args[1] # Just copy FP8 attrs if copying between Float8Tensors diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index bf120f1de7..9d745d0ed7 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -376,7 +376,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if dst._columnwise_scale_inv is not None: dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv) return dst - + elif func == torch.ops.aten.is_pinned.default: + if args[0]._rowwise_data is not None: + return args[0]._rowwise_data.is_pinned() + elif args[0]._columnwise_data is not None: + return args[0]._columnwise_data.is_pinned() + else: + raise RuntimeError("Cannot check if pinned for MXFP8Tensor with no data.") # Default case return super().__torch_dispatch__(func, types, args, kwargs) From 4c8a06d120bd5f0ccc4020a8d722bfad659b6d80 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 May 2025 10:19:47 +0000 Subject: [PATCH 09/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 4 +++- transformer_engine/pytorch/tensor/float8_tensor.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 21d282d9b2..c44ba260e3 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -463,7 +463,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): elif args[0]._columnwise_data is not None: return args[0]._columnwise_data.is_pinned() else: - raise RuntimeError("Cannot check if pinned for Float8BlockwiseQTensor with no data.") + raise RuntimeError( + "Cannot check if pinned for Float8BlockwiseQTensor with no data." + ) # Default case return super().__torch_dispatch__(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index af144fcc9c..eb827dbcad 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -623,7 +623,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): elif args[0]._transpose is not None: return args[0]._transpose.is_pinned() else: - raise RuntimeError("Cannot check if pinned for Float8Tensor with no data and no transpose.") + raise RuntimeError( + "Cannot check if pinned for Float8Tensor with no data and no transpose." 
+ ) if func == torch.ops.aten.copy_.default: dst, src = args[0], args[1] # Just copy FP8 attrs if copying between Float8Tensors From bc60990a4303789ed83fd01081888c41321ac6a1 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 8 May 2025 13:44:31 +0000 Subject: [PATCH 10/14] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/optimizers/fused_adam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 3ec24f98de..0fafe2fc66 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -374,7 +374,7 @@ def _initialize_state( if store_param_remainders: data = torch.zeros_like(param, dtype=torch.int16) else: - data = torch.empty_like(param.dequantize(), dtype=dtype) + data = torch.empty_like(param.detach().dequantize(), dtype=dtype) if zero_buffer: data.zero_() From 8bf27c19514027aff8c9a6801572b8a74969346d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 8 May 2025 14:31:24 +0000 Subject: [PATCH 11/14] lint fix Signed-off-by: Pawel Gadzinski --- .../pytorch/tensor/float8_blockwise_tensor.py | 9 ++++----- transformer_engine/pytorch/tensor/float8_tensor.py | 9 ++++----- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 5 ++--- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ad3dc0fc56..46d7252363 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -408,12 +408,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): elif func == torch.ops.aten.is_pinned.default: if args[0]._rowwise_data is not None: return args[0]._rowwise_data.is_pinned() - elif args[0]._columnwise_data is not None: + if args[0]._columnwise_data is not None: return args[0]._columnwise_data.is_pinned() - else: - raise RuntimeError( - "Cannot check if pinned for Float8BlockwiseQTensor with no data." - ) + raise RuntimeError( + "Cannot check if pinned for Float8BlockwiseQTensor with no data." + ) # Default case return super().__torch_dispatch__(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 9c9ee91fca..8481b7717c 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -583,12 +583,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.is_pinned.default: if args[0]._data is not None: return args[0]._data.is_pinned() - elif args[0]._transpose is not None: + if args[0]._transpose is not None: return args[0]._transpose.is_pinned() - else: - raise RuntimeError( - "Cannot check if pinned for Float8Tensor with no data and no transpose." - ) + raise RuntimeError( + "Cannot check if pinned for Float8Tensor with no data and no transpose." 
+ ) if func == torch.ops.aten.copy_.default: dst, src = args[0], args[1] # Just copy FP8 attrs if copying between Float8Tensors diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 53b5235ee7..3469c543bd 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -334,10 +334,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): elif func == torch.ops.aten.is_pinned.default: if args[0]._rowwise_data is not None: return args[0]._rowwise_data.is_pinned() - elif args[0]._columnwise_data is not None: + if args[0]._columnwise_data is not None: return args[0]._columnwise_data.is_pinned() - else: - raise RuntimeError("Cannot check if pinned for MXFP8Tensor with no data.") + raise RuntimeError("Cannot check if pinned for MXFP8Tensor with no data.") # Default case return super().__torch_dispatch__(func, types, args, kwargs) From ac93bb4aaed3c7e3553f26f27f23a74cc539f9c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 May 2025 14:32:05 +0000 Subject: [PATCH 12/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 46d7252363..207be9d986 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -410,9 +410,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return args[0]._rowwise_data.is_pinned() if args[0]._columnwise_data is not None: return args[0]._columnwise_data.is_pinned() - raise RuntimeError( - "Cannot check if pinned for Float8BlockwiseQTensor with no data." - ) + raise RuntimeError("Cannot check if pinned for Float8BlockwiseQTensor with no data.") # Default case return super().__torch_dispatch__(func, types, args, kwargs) From da1bbf996261f83c2f3adcfece0a7fe430ece2ac Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 9 May 2025 13:55:56 +0200 Subject: [PATCH 13/14] fixes Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_cpu_offloading.py | 10 +++++----- transformer_engine/pytorch/cpu_offload.py | 2 +- .../tensor/_internal/float8_blockwise_tensor_base.py | 4 +++- .../pytorch/tensor/_internal/float8_tensor_base.py | 4 +++- .../pytorch/tensor/_internal/mxfp8_tensor_base.py | 4 +++- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 7764511374..2b84529809 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -26,10 +26,10 @@ recipe.Float8BlockScaling(), ] -SIZE = 512 +SIZE = 64 NUM_HEADS = 8 NUM_LAYERS = 5 -EPSILON = 0.1 +EPSILON = 0.05 # Flash attention saves some internal tensor for the backward pass # that cannot be offloaded to CPU. 
@@ -48,7 +48,7 @@ SIZE, NUM_HEADS, params_dtype=torch.bfloat16 ), "transformer_layer": lambda: te.TransformerLayer( - SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 + SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 ), } @@ -97,7 +97,8 @@ def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload ), offload_context: tensor = model(tensor) tensor = sync_function(tensor) - + + import gc; gc.collect() max_mem_used = torch.cuda.memory_allocated() / (1024**2) torch.cuda.synchronize() @@ -119,7 +120,6 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: the difference being the size of the FP8 cache that is not offloaded to the CPU. We also expect this memory consumption to be smaller than in scenario (1). """ - model_cls = model_types[model_key] models_list = [model_cls() for _ in range(NUM_LAYERS)] diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index d470e370d6..6e8c170ff2 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -22,7 +22,7 @@ def mark_activation_offload(*tensors): if isinstance(tensor, torch.Tensor): tensor.activation_offloading = True else: - data_tensors = tensor.get_data_tensors() + data_tensors = tensor.get_data_tensors(scaling_factors=True) for tensor in data_tensors: if tensor is not None: tensor.activation_offloading = True diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 42a6181716..1dc2c9007f 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -112,8 +112,10 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] - def get_data_tensors(self): + def get_data_tensors(self, scaling_factors=False): """Get this Tensor's data.""" + if scaling_factors: + return self._rowwise_data, self._columnwise_data, self._rowwise_scale_inv, self._columnwise_scale_inv return self._rowwise_data, self._columnwise_data def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor: diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 4124511cd8..d022e93a8a 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -128,8 +128,10 @@ def restore_from_saved( self._scale_inv = tensors[2] return tensors[3:] - def get_data_tensors(self): + def get_data_tensors(self, scaling_factors=False): """Get this Tensor's data.""" + if scaling_factors: + return self._data, self._transpose, self._scale_inv return self._data, self._transpose def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index ae00a4d72b..826fc7c591 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -131,8 +131,10 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] - def get_data_tensors(self): + def get_data_tensors(self, scaling_factors=False): """Get this Tensor's data.""" + if 
scaling_factors: + return self._rowwise_data, self._columnwise_data, self._rowwise_scale_inv, self._columnwise_scale_inv return self._rowwise_data, self._columnwise_data def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: From b6012b13094135079edb70a298400078c1ae5390 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 May 2025 11:58:11 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_cpu_offloading.py | 8 +++++--- .../tensor/_internal/float8_blockwise_tensor_base.py | 7 ++++++- .../pytorch/tensor/_internal/mxfp8_tensor_base.py | 7 ++++++- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 2b84529809..968fe2527b 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -48,7 +48,7 @@ SIZE, NUM_HEADS, params_dtype=torch.bfloat16 ), "transformer_layer": lambda: te.TransformerLayer( - SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 + SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 ), } @@ -97,8 +97,10 @@ def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload ), offload_context: tensor = model(tensor) tensor = sync_function(tensor) - - import gc; gc.collect() + + import gc + + gc.collect() max_mem_used = torch.cuda.memory_allocated() / (1024**2) torch.cuda.synchronize() diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 1dc2c9007f..c7fdb6fc00 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -115,7 +115,12 @@ def restore_from_saved( def get_data_tensors(self, scaling_factors=False): """Get this Tensor's data.""" if scaling_factors: - return self._rowwise_data, self._columnwise_data, self._rowwise_scale_inv, self._columnwise_scale_inv + return ( + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + ) return self._rowwise_data, self._columnwise_data def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor: diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index 826fc7c591..23035346ce 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -134,7 +134,12 @@ def restore_from_saved( def get_data_tensors(self, scaling_factors=False): """Get this Tensor's data.""" if scaling_factors: - return self._rowwise_data, self._columnwise_data, self._rowwise_scale_inv, self._columnwise_scale_inv + return ( + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + ) return self._rowwise_data, self._columnwise_data def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
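A note on the `torch.ops.aten.is_pinned.default` handlers introduced earlier in this series and reflowed in patches 09-12: because the quantized classes are tensor subclasses, a plain `Tensor.is_pinned()` call lands in `__torch_dispatch__`, and the handlers simply defer to whichever backing data tensor is present. The sketch below illustrates that pattern with a toy wrapper subclass; `_ToyQuantized` and its `_data` attribute are illustrative stand-ins for the Transformer Engine classes, and the pinned-memory branch assumes a CUDA-enabled build.

import torch


class _ToyQuantized(torch.Tensor):
    """Toy stand-in for the patched quantized-tensor subclasses."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        wrapper = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )
        wrapper._data = data  # plays the role of Float8Tensor._data / MXFP8Tensor._rowwise_data
        return wrapper

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        if func == torch.ops.aten.is_pinned.default:
            # Defer to the backing data tensor, mirroring the handlers in the patches above.
            return args[0]._data.is_pinned()
        raise NotImplementedError(f"{func} is not handled in this sketch")


data = torch.zeros(2, 2, dtype=torch.uint8)
if torch.cuda.is_available():  # pinning host memory requires an accelerator build
    data = data.pin_memory()
wrapped = _ToyQuantized(data)
assert wrapped.is_pinned() == data.is_pinned()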
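The `gc.collect()` call that patch 13 adds to `_measure_memory_between_forward_and_backward` (and patch 14 reformats) makes the memory reading less noisy: Python reference cycles can keep CUDA tensors alive past their last use, so collecting first lets `torch.cuda.memory_allocated()` report only tensors that are genuinely still referenced. Below is a minimal sketch of the same measurement idiom; the helper name is illustrative and a CUDA device is assumed.

import gc

import torch


def live_cuda_mib() -> float:
    """Return CUDA memory still held by live tensors, in MiB."""
    gc.collect()  # free tensors kept alive only by unreachable reference cycles
    torch.cuda.synchronize()  # let queued work settle before reading allocator stats
    return torch.cuda.memory_allocated() / (1024**2)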
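Patches 13 and 14 also widen `get_data_tensors()` on the `*TensorBase` classes with a `scaling_factors` flag so that the `*_scale_inv` tensors are returned, and therefore tagged for offloading, together with the FP8 data (see the `cpu_offload.py` hunk in patch 13). The sketch below shows how a caller might back up everything the widened call returns; `_StubQuantized`, `offload_all_to_cpu`, and the module-level constants are illustrative stand-ins, not Transformer Engine APIs.

import torch

_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_PIN = torch.cuda.is_available()  # pinned host buffers only make sense with an accelerator


class _StubQuantized:
    """Toy object exposing the widened get_data_tensors() contract."""

    def __init__(self):
        self._rowwise_data = torch.zeros(4, 4, dtype=torch.uint8, device=_DEVICE)
        self._columnwise_data = None
        self._rowwise_scale_inv = torch.ones(4, dtype=torch.float32, device=_DEVICE)
        self._columnwise_scale_inv = None

    def get_data_tensors(self, scaling_factors: bool = False):
        if scaling_factors:
            return (
                self._rowwise_data,
                self._columnwise_data,
                self._rowwise_scale_inv,
                self._columnwise_scale_inv,
            )
        return self._rowwise_data, self._columnwise_data


def offload_all_to_cpu(quantized, pin_memory: bool = _PIN):
    """Copy data *and* scale_inv tensors to (optionally pinned) CPU buffers."""
    backups = []
    for t in quantized.get_data_tensors(scaling_factors=True):
        if t is None:
            backups.append(None)
            continue
        cpu_copy = torch.empty_like(t, device="cpu", pin_memory=pin_memory)
        cpu_copy.copy_(t, non_blocking=pin_memory)
        backups.append(cpu_copy)
    return backups


cpu_backups = offload_all_to_cpu(_StubQuantized())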