Skip to content

Commit 77e4a69

Browse files
committed
[test only] testing adding optioanl tensor arg to float8 tensor
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: stack-info: PR: #2840, branch: jerryzh168/stack/33
1 parent a9ffa50 commit 77e4a69

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class Float8Tensor(TorchAOBaseTensor):
9595

9696
tensor_data_names = ["qdata", "scale"]
9797
tensor_attribute_names = []
98+
optional_tensor_data_names = ["test_only_data"]
9899
optional_tensor_attribute_names = [
99100
"block_size",
100101
"mm_config",
@@ -103,19 +104,22 @@ class Float8Tensor(TorchAOBaseTensor):
103104
"act_quant_kwargs",
104105
"kernel_preference",
105106
"dtype",
107+
"new_optional_attr",
106108
]
107109

108110
def __new__(
109111
cls,
110112
qdata: torch.Tensor,
111113
scale: torch.Tensor,
114+
test_only_data: Optional[torch.Tensor] = None,
112115
block_size: Optional[List[int]] = None,
113116
mm_config: Optional[Float8MMConfig] = None,
114117
hp_value_lb: Optional[float] = None,
115118
hp_value_ub: Optional[float] = None,
116119
act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
117120
kernel_preference: KernelPreference = KernelPreference.AUTO,
118121
dtype: Optional[torch.dtype] = None,
122+
new_optional_attr: Optional[int] = None,
119123
):
120124
shape = qdata.shape
121125
kwargs = {}
@@ -128,22 +132,26 @@ def __init__(
128132
self,
129133
qdata: torch.Tensor,
130134
scale: torch.Tensor,
135+
test_only_data: Optional[torch.Tensor] = None,
131136
block_size: Optional[List[int]] = None,
132137
mm_config: Optional[Float8MMConfig] = None,
133138
hp_value_lb: Optional[float] = None,
134139
hp_value_ub: Optional[float] = None,
135140
act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
136141
kernel_preference: KernelPreference = KernelPreference.AUTO,
137142
dtype: Optional[torch.dtype] = None,
143+
new_optional_attr: Optional[int] = None,
138144
):
139145
self.qdata = qdata
140146
self.scale = scale
147+
self.test_only_data = test_only_data
141148
self.block_size = block_size
142149
self.mm_config = mm_config
143150
self.hp_value_lb = hp_value_lb
144151
self.hp_value_ub = hp_value_ub
145152
self.act_quant_kwargs = act_quant_kwargs
146153
self.kernel_preference = kernel_preference
154+
self.new_optional_attr = new_optional_attr
147155

148156
def __repr__(self):
149157
return (

0 commit comments

Comments
 (0)