@@ -95,6 +95,7 @@ class Float8Tensor(TorchAOBaseTensor):
95
95
96
96
tensor_data_names = ["qdata" , "scale" ]
97
97
tensor_attribute_names = []
98
+ optional_tensor_data_names = ["test_only_data" ]
98
99
optional_tensor_attribute_names = [
99
100
"block_size" ,
100
101
"mm_config" ,
@@ -103,19 +104,22 @@ class Float8Tensor(TorchAOBaseTensor):
103
104
"act_quant_kwargs" ,
104
105
"kernel_preference" ,
105
106
"dtype" ,
107
+ "new_optional_attr" ,
106
108
]
107
109
108
110
def __new__ (
109
111
cls ,
110
112
qdata : torch .Tensor ,
111
113
scale : torch .Tensor ,
114
+ test_only_data : Optional [torch .Tensor ] = None ,
112
115
block_size : Optional [List [int ]] = None ,
113
116
mm_config : Optional [Float8MMConfig ] = None ,
114
117
hp_value_lb : Optional [float ] = None ,
115
118
hp_value_ub : Optional [float ] = None ,
116
119
act_quant_kwargs : Optional [QuantizeTensorToFloat8Kwargs ] = None ,
117
120
kernel_preference : KernelPreference = KernelPreference .AUTO ,
118
121
dtype : Optional [torch .dtype ] = None ,
122
+ new_optional_attr : Optional [int ] = None ,
119
123
):
120
124
shape = qdata .shape
121
125
kwargs = {}
@@ -128,22 +132,26 @@ def __init__(
128
132
self ,
129
133
qdata : torch .Tensor ,
130
134
scale : torch .Tensor ,
135
+ test_only_data : Optional [torch .Tensor ] = None ,
131
136
block_size : Optional [List [int ]] = None ,
132
137
mm_config : Optional [Float8MMConfig ] = None ,
133
138
hp_value_lb : Optional [float ] = None ,
134
139
hp_value_ub : Optional [float ] = None ,
135
140
act_quant_kwargs : Optional [QuantizeTensorToFloat8Kwargs ] = None ,
136
141
kernel_preference : KernelPreference = KernelPreference .AUTO ,
137
142
dtype : Optional [torch .dtype ] = None ,
143
+ new_optional_attr : Optional [int ] = None ,
138
144
):
139
145
self .qdata = qdata
140
146
self .scale = scale
147
+ self .test_only_data = test_only_data
141
148
self .block_size = block_size
142
149
self .mm_config = mm_config
143
150
self .hp_value_lb = hp_value_lb
144
151
self .hp_value_ub = hp_value_ub
145
152
self .act_quant_kwargs = act_quant_kwargs
146
153
self .kernel_preference = kernel_preference
154
+ self .new_optional_attr = new_optional_attr
147
155
148
156
def __repr__ (self ):
149
157
return (
0 commit comments