From 77e4a69eb174989563b4d96959afab4ef82b3658 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 21 Aug 2025 13:49:30 -0700 Subject: [PATCH] [test only] testing adding optioanl tensor arg to float8 tensor Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: stack-info: PR: https://github.com/pytorch/ao/pull/2840, branch: jerryzh168/stack/33 --- .../quantize_/workflows/float8/float8_tensor.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index baf6d493df..7fea5a0271 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -95,6 +95,7 @@ class Float8Tensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale"] tensor_attribute_names = [] + optional_tensor_data_names = ["test_only_data"] optional_tensor_attribute_names = [ "block_size", "mm_config", @@ -103,12 +104,14 @@ class Float8Tensor(TorchAOBaseTensor): "act_quant_kwargs", "kernel_preference", "dtype", + "new_optional_attr", ] def __new__( cls, qdata: torch.Tensor, scale: torch.Tensor, + test_only_data: Optional[torch.Tensor] = None, block_size: Optional[List[int]] = None, mm_config: Optional[Float8MMConfig] = None, hp_value_lb: Optional[float] = None, @@ -116,6 +119,7 @@ def __new__( act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, kernel_preference: KernelPreference = KernelPreference.AUTO, dtype: Optional[torch.dtype] = None, + new_optional_attr: Optional[int] = None, ): shape = qdata.shape kwargs = {} @@ -128,6 +132,7 @@ def __init__( self, qdata: torch.Tensor, scale: torch.Tensor, + test_only_data: Optional[torch.Tensor] = None, block_size: Optional[List[int]] = None, mm_config: Optional[Float8MMConfig] = None, hp_value_lb: Optional[float] = None, @@ -135,15 +140,18 @@ def __init__( act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, kernel_preference: KernelPreference = KernelPreference.AUTO, dtype: Optional[torch.dtype] = None, + new_optional_attr: Optional[int] = None, ): self.qdata = qdata self.scale = scale + self.test_only_data = test_only_data self.block_size = block_size self.mm_config = mm_config self.hp_value_lb = hp_value_lb self.hp_value_ub = hp_value_ub self.act_quant_kwargs = act_quant_kwargs self.kernel_preference = kernel_preference + self.new_optional_attr = new_optional_attr def __repr__(self): return (