
Commit bc2c83e

[reland] Refactor TorchAOBaseTensor for better BC (#2793) (#2855)
Summary:

After this PR, tensors inheriting from `TorchAOBaseTensor` have better backward compatibility (BC) support: if a subclass later adds an optional tensor data attribute or an optional non-tensor attribute, old checkpoints still load without any additional changes.

More details: the BC story we are aiming for is that, after we land some tensor, e.g. `Int4Tensor` or `Float8Tensor`, future changes should only add optional Tensor data attributes and optional non-Tensor attributes to the tensor (bigger changes will require a version bump, which we still need to add). The current `TorchAOBaseTensor` does not support this well. Also see #2840 for a real test that adds both an optional tensor and an optional non-tensor attribute to `Float8Tensor`; the BC test for `Float8Tensor` in https://github.com/pytorch/ao/blob/main/test/integration/test_load_and_run_checkpoint.py does not fail.

Docs for the current `TorchAOBaseTensor` (https://github.com/pytorch/ao/blob/e6b38bb0e1477ae6aaca0a3d30de70598be43290/torchao/utils.py#L726-L731); a sketch of how these lists drive the generated boilerplate follows the test plan below:

- `tensor_data_names` (List[str]): names of all required tensor data attributes; the order should match the `__init__` argument list of the tensor subclass.
- `optional_tensor_data_names` (List[str]): optional to define, but needed if there are optional Tensor attributes; when defined, it lists the names of Tensors that can be None, and the boilerplate functions are still implemented for you.
- `tensor_attribute_names` (List[str]): names of non-Tensor attributes; the order should match the `__init__` argument list, following all the `tensor_data_names` and `optional_tensor_data_names` arguments.

Problem: the current `optional_tensor_data_names` is not truly optional, since it is followed by `tensor_attribute_names`, which contains both required and optional attributes. So adding a tensor data attribute to a tensor breaks BC. Here are a few options.

Option 1: keep a single `tensor_attribute_names` list and give every trailing argument a default:

```python
class Int4Tensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point"]
    optional_tensor_data_names = ["act_scale"]
    tensor_attribute_names = ["block_size", "shape", "_demo_only_optional_attr"]

    def __init__(self, qdata, scale, zero_point, act_scale=None, block_size=None, shape=None, _demo_only_optional_attr=None):
        ...

    # for BC
    def __setstate__(self, state):
        torch._utils._set_obj_state(self, state)
        if "act_scale" not in self.__dict__:
            self.act_scale = None
```

Option 2: split the attribute list into required and optional parts, with all optional arguments after the required ones:

```python
class Int4Tensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point"]
    optional_tensor_data_names = ["act_scale"]
    required_tensor_attribute_names = ["block_size", "shape"]
    optional_tensor_attribute_names = ["_demo_only_optional_attr"]

    def __init__(self, qdata, scale, zero_point, block_size, shape, act_scale=None, _demo_only_optional_attr=None):
        ...

    # for BC
    def __setstate__(self, state):
        torch._utils._set_obj_state(self, state)
        if "act_scale" not in self.__dict__:
            self.act_scale = None
```

Option 3: keep `tensor_attribute_names` as-is and move the optional tensor data to the end of the argument list:

```python
class Int4Tensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point"]
    tensor_attribute_names = ["block_size", "shape", "_demo_only_optional_attr"]
    optional_tensor_data_names = ["act_scale"]

    def __init__(self, qdata, scale, zero_point, block_size, shape, _demo_only_optional_attr=None, act_scale=None):
        ...

    # for BC
    def __setstate__(self, state):
        torch._utils._set_obj_state(self, state)
        if "act_scale" not in self.__dict__:
            self.act_scale = None
```

The diffs below follow the second layout, keeping the existing `tensor_attribute_names` name for the required attributes and adding `optional_tensor_attribute_names`, so the `__init__` order is: required tensor data, required attributes, optional tensor data, optional attributes.

Test Plan:

```
python test/integration/test_load_and_run_checkpoint.py
```
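For context, here is a minimal sketch (an illustration only, not the actual implementation in torchao/utils.py) of how the default `__tensor_flatten__` boilerplate can be derived from these name lists, mirroring the pattern used by the test helper in test/test_utils.py below. Optional entries are appended only when present and non-None, which is what lets a checkpoint saved before an optional attribute existed still round-trip:

```python
# Illustrative sketch, assuming the four class-level name lists described above.
def __tensor_flatten__(self):
    # required tensor data always participates, in declaration order
    tensor_data_names = list(self.tensor_data_names)
    # optional tensor data participates only when it is set
    for name in getattr(self, "optional_tensor_data_names", []):
        if getattr(self, name) is not None:
            tensor_data_names.append(name)
    # required attributes first, then optional ones (None when absent),
    # so older checkpoints missing a newly added attribute still load
    attributes = [getattr(self, name) for name in self.tensor_attribute_names]
    for name in getattr(self, "optional_tensor_attribute_names", []):
        attributes.append(getattr(self, name, None))
    return tensor_data_names, attributes
```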
1 parent 98e406d · commit bc2c83e

File tree: 8 files changed, +309 −110 lines

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -907,7 +907,7 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert NVFP4Tensor.tensor_attribute_names[2] == "_is_swizzled_scales"
+    assert NVFP4Tensor.optional_tensor_attribute_names[0] == "_is_swizzled_scales"
     assert ctx[2] == True
 
     # Test deserialization
```

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -307,7 +307,7 @@ def test_nvfp4_swizzled_scales_serialization():
     tensor_list, ctx = original_tensor.__tensor_flatten__()
 
     # Verify swizzled flag is preserved in context
-    assert NVFP4Tensor.tensor_attribute_names[2] == "_is_swizzled_scales"
+    assert NVFP4Tensor.optional_tensor_attribute_names[0] == "_is_swizzled_scales"
     assert ctx[2] == True
 
     # Test deserialization
```

test/test_utils.py

Lines changed: 126 additions & 29 deletions

```diff
@@ -106,6 +106,18 @@ def __init__(self, data):
         l.weight = torch.nn.Parameter(MyTensor(l.weight))
 
     def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
+        # get `all_tensor_data_names` and `all_tensor_attribute_names`
+        all_tensor_data_names = lp_tensor.tensor_data_names.copy()
+        if hasattr(lp_tensor, "optional_tensor_data_names"):
+            for tensor_data_name in lp_tensor.optional_tensor_data_names:
+                if getattr(lp_tensor, tensor_data_name) is not None:
+                    all_tensor_data_names.append(tensor_data_name)
+        all_tensor_attribute_names = lp_tensor.tensor_attribute_names.copy()
+        if hasattr(lp_tensor, "optional_tensor_attribute_names"):
+            for tensor_attribute_name in lp_tensor.optional_tensor_attribute_names:
+                if getattr(lp_tensor, tensor_attribute_name) is not None:
+                    all_tensor_attribute_names.append(tensor_attribute_name)
+
         # test __tensor_flatten__ and __tensor_unflatten__
         tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__()
         tensor_data_dict = {
@@ -116,6 +128,19 @@ def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
         reconstructed = type(lp_tensor).__tensor_unflatten__(
             tensor_data_dict, tensor_attributes, outer_size, outer_stride
         )
+        for tensor_data_name in all_tensor_data_names:
+            self.assertTrue(
+                torch.equal(
+                    getattr(lp_tensor, tensor_data_name),
+                    getattr(reconstructed, tensor_data_name),
+                )
+            )
+        for tensor_attribute_name in all_tensor_attribute_names:
+            self.assertEqual(
+                getattr(lp_tensor, tensor_attribute_name),
+                getattr(reconstructed, tensor_attribute_name),
+            )
+
         self.assertTrue(torch.equal(lp_tensor.qdata, reconstructed.qdata))
         self.assertEqual(lp_tensor.attr, reconstructed.attr)
 
@@ -129,52 +154,81 @@ def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
         # __repr__
         _ = str(lp_tensor)
 
-        # other ops
+        # op test: detach
         lp_tensor = lp_tensor.detach()
-        # explicitly testing aten.alias
+        # op test: alias
         lp_tensor = torch.ops.aten.alias(lp_tensor)
-        lp_tensor = lp_tensor.clone()
-        # get all tensor_data_names for both
+
+        # op test: clone
+        lp_tensor_clone = lp_tensor.clone()
+
+        for tensor_data_name in all_tensor_data_names:
+            self.assertTrue(
+                torch.equal(
+                    getattr(lp_tensor_clone, tensor_data_name),
+                    getattr(lp_tensor, tensor_data_name),
+                )
+            )
+        for tensor_attribute_name in all_tensor_attribute_names:
+            self.assertEqual(
+                getattr(lp_tensor_clone, tensor_attribute_name),
+                getattr(lp_tensor, tensor_attribute_name),
+            )
+
+        # op test: transpose
         # non optional and valid optional tensors
-        tensor_data_names = lp_tensor.tensor_data_names.copy()
-        if hasattr(lp_tensor, "optional_tensor_data_names"):
-            for tensor_data_name in lp_tensor.optional_tensor_data_names:
-                if getattr(lp_tensor, tensor_data_name) is not None:
-                    tensor_data_names.append(tensor_data_name)
 
         # for each of the tensor data, we try to
         # make it non-contiguous and then use
         # lp_tensor.contiguous() call to make sure
         # contiguous() works
-        for tensor_data_name in tensor_data_names:
+        for tensor_data_name in all_tensor_data_names:
             tensor = getattr(lp_tensor, tensor_data_name)
             # making qdata not contiguous
             tensor = tensor.transpose(0, 1).contiguous()
             tensor = tensor.transpose(0, 1)
             setattr(lp_tensor, tensor_data_name, tensor)
             self.assertFalse(getattr(lp_tensor, tensor_data_name).is_contiguous())
-            lp_tensor = lp_tensor.contiguous()
-            # making sure contiguous call works
-            self.assertTrue(getattr(lp_tensor, tensor_data_name).is_contiguous())
 
-        # copy_
+        lp_tensor_t = lp_tensor.contiguous()
+
+        # making sure contiguous call works
+        for tensor_data_name in all_tensor_data_names:
+            self.assertTrue(getattr(lp_tensor_t, tensor_data_name).is_contiguous())
+
+        # making sure transpose does not change attributes
+        for tensor_attribute_name in all_tensor_attribute_names:
+            self.assertEqual(
+                getattr(lp_tensor_t, tensor_attribute_name),
+                getattr(lp_tensor, tensor_attribute_name),
+            )
+
+        # op test: copy_
         # making sure that initially tensor values are not the same so we can test copy_
         self.assertNotEqual(lp_tensor.qdata[0][0], lp_tensor_for_copy.qdata[0][0])
         # copy_ requires the attributes to be the same
-        for tensor_attr_name in lp_tensor.tensor_attribute_names:
+        for tensor_attribute_name in all_tensor_attribute_names:
             self.assertEqual(
-                getattr(lp_tensor, tensor_attr_name),
-                getattr(lp_tensor_for_copy, tensor_attr_name),
+                getattr(lp_tensor_for_copy, tensor_attribute_name),
+                getattr(lp_tensor, tensor_attribute_name),
             )
+
         lp_tensor.copy_(lp_tensor_for_copy)
         # after copy_, the tensor values should match
-        for tensor_data_name in tensor_data_names:
+        for tensor_data_name in all_tensor_data_names:
             self.assertTrue(
                 torch.equal(
                     getattr(lp_tensor, tensor_data_name),
                     getattr(lp_tensor_for_copy, tensor_data_name),
                 )
             )
+        # after copy_, the tensor attributes still matches
+        # copy_ requires the attributes to be the same
+        for tensor_attribute_name in all_tensor_attribute_names:
+            self.assertEqual(
+                getattr(lp_tensor_for_copy, tensor_attribute_name),
+                getattr(lp_tensor, tensor_attribute_name),
+            )
 
     @skip_if_no_cuda()
     def test_default_impls(self):
@@ -186,60 +240,103 @@ class MyTensor(TorchAOBaseTensor):
             tensor_data_names = ["qdata"]
             tensor_attribute_names = ["attr", "device"]
 
-            def __new__(cls, qdata, attr, device=None):
+            def __new__(cls, qdata, attr, device):
                 shape = qdata.shape
                 if device is None:
                     device = qdata.device
                 kwargs = {"device": device}
                 return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-            def __init__(self, qdata, attr, device=None):
+            def __init__(self, qdata, attr, device):
                 self.qdata = qdata
                 self.attr = attr
 
         l = torch.nn.Linear(2, 3)
-        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
+        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr", None))
         lp_tensor = l.weight
 
         another_tensor = torch.nn.Linear(2, 3).weight
         # attribute has to be the same
-        lp_tensor_for_copy = MyTensor(another_tensor, "attr")
+        lp_tensor_for_copy = MyTensor(another_tensor, "attr", None)
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
 
     @skip_if_no_cuda()
     def test_default_impls_with_optional_data(self):
         class MyTensorWithOptionalData(TorchAOBaseTensor):
             tensor_data_names = ["qdata"]
-            optional_tensor_data_names = ["zero_point"]
             tensor_attribute_names = ["attr", "device"]
+            optional_tensor_data_names = ["zero_point"]
 
-            def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
+            def __new__(cls, qdata, attr, device, zero_point=None):
                 shape = qdata.shape
                 if device is None:
                     device = qdata.device
                 kwargs = {"device": device}
                 return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-            def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
+            def __init__(self, qdata, attr, device, zero_point=None):
                 self.qdata = qdata
+                self.attr = attr
                 self.zero_point = zero_point
+
+        # test both the optional Tensor is None
+        # and not None
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, None)
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, "attr", None, None)
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(
+            l.weight, "attr", None, torch.zeros_like(l.weight)
+        )
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(
+            l.weight, "attr", None, torch.zeros_like(l.weight)
+        )
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+    @skip_if_no_cuda()
+    def test_default_impls_with_optional_attr(self):
+        class MyTensorWithOptionalData(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+            optional_tensor_data_names = ["zero_point"]
+            optional_tensor_attribute_names = ["optional_attr"]
+
+            def __new__(cls, qdata, attr, device, zero_point=None, optional_attr=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+            def __init__(
+                self, qdata, attr, device, zero_point=None, optional_attr=None
+            ):
+                self.qdata = qdata
                 self.attr = attr
+                self.zero_point = zero_point
+                self.optional_attr = optional_attr
 
         # test both the optional Tensor is None
         # and not None
         l = torch.nn.Linear(2, 3)
-        lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
+        lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, zero_point=None)
         l = torch.nn.Linear(2, 3)
-        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
+        lp_tensor_for_copy = MyTensorWithOptionalData(
+            l.weight, "attr", None, zero_point=None
+        )
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
 
         l = torch.nn.Linear(2, 3)
         lp_tensor = MyTensorWithOptionalData(
-            l.weight, torch.zeros_like(l.weight), "attr"
+            l.weight, "attr", None, zero_point=None, optional_attr="value"
        )
         l = torch.nn.Linear(2, 3)
         lp_tensor_for_copy = MyTensorWithOptionalData(
-            l.weight, torch.zeros_like(l.weight), "attr"
+            l.weight, "attr", None, zero_point=None, optional_attr="value"
         )
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
```

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 19 additions & 17 deletions

```diff
@@ -79,10 +79,12 @@ class NVFP4Tensor(TorchAOBaseTensor):
    """
 
    tensor_data_names = ["qdata", "_scale_e4m3"]
-    optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
    tensor_attribute_names = [
        "_block_size",
        "_orig_dtype",
+    ]
+    optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
+    optional_tensor_attribute_names = [
        "_is_swizzled_scales",
        "use_triton_kernel",
        "act_quant_kwargs",
@@ -92,10 +94,10 @@ def __new__(
        cls,
        qdata,
        blockwise_scales,
-        per_tensor_scale,
-        act_per_tensor_scale,
        block_size,
        orig_dtype,
+        per_tensor_scale,
+        act_per_tensor_scale,
        is_swizzled_scales=False,
        use_triton_kernel=False,
        act_quant_kwargs=None,
@@ -116,13 +118,13 @@ def __new__(
            requires_grad=False,
        )
 
-        self._scale_e4m3 = blockwise_scales
-        self._is_swizzled_scales = is_swizzled_scales
-        self._per_tensor_scale = per_tensor_scale
-        self._act_per_tensor_scale = act_per_tensor_scale
        self.qdata = qdata
+        self._scale_e4m3 = blockwise_scales
        self._block_size = block_size
        self._orig_dtype = orig_dtype
+        self._per_tensor_scale = per_tensor_scale
+        self._act_per_tensor_scale = act_per_tensor_scale
+        self._is_swizzled_scales = is_swizzled_scales
        self.use_triton_kernel = use_triton_kernel
        self.act_quant_kwargs = act_quant_kwargs
        return self
@@ -184,10 +186,10 @@ def to_nvfp4(
        return NVFP4Tensor(
            data_lp,
            blockwise_scales,
-            per_tensor_scale,
-            act_per_tensor_scale,
            block_size,
            data_hp.dtype,
+            per_tensor_scale,
+            act_per_tensor_scale,
            is_swizzled_scales,
            use_triton_kernel,
            act_quant_kwargs,
@@ -312,10 +314,10 @@ def nvfp4_to_copy(func, types, args, kwargs):
        res = NVFP4Tensor(
            tensor.qdata,
            tensor._scale_e4m3,
-            tensor._per_tensor_scale,
-            tensor._act_per_tensor_scale,
            tensor._block_size,
            dtype,
+            tensor._per_tensor_scale,
+            tensor._act_per_tensor_scale,
            tensor._is_swizzled_scales,
            tensor.use_triton_kernel,
            tensor.act_quant_kwargs,
@@ -513,10 +515,10 @@ def nvfp4_slice(func, types, args, kwargs):
    result = NVFP4Tensor(
        sliced_data,
        sliced_scale,
-        x._per_tensor_scale,
-        x._act_per_tensor_scale,
        x._block_size,
        x._orig_dtype,
+        x._per_tensor_scale,
+        x._act_per_tensor_scale,
        x._is_swizzled_scales,
        x.use_triton_kernel,
        x.act_quant_kwargs,
@@ -532,10 +534,10 @@ def nvfp4_t(func, types, args, kwargs):
    new = NVFP4Tensor(
        old.qdata.t(),
        old._scale_e4m3,
-        old._per_tensor_scale,
-        old._act_per_tensor_scale,
        old._block_size,
        old._orig_dtype,
+        old._per_tensor_scale,
+        old._act_per_tensor_scale,
        old._is_swizzled_scales,
        old.use_triton_kernel,
        old.act_quant_kwargs,
@@ -552,10 +554,10 @@ def nvfp4_view_op(func, types, args, kwargs):
    return NVFP4Tensor(
        new_data,
        args[0]._scale_e4m3,
-        args[0]._per_tensor_scale,
-        args[0]._act_per_tensor_scale,
        args[0]._block_size,
        args[0]._orig_dtype,
+        args[0]._per_tensor_scale,
+        args[0]._act_per_tensor_scale,
        args[0]._is_swizzled_scales,
        args[0].use_triton_kernel,
        args[0].act_quant_kwargs,
```
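Every `NVFP4Tensor` call site in this diff applies the same reordering: required tensor data first, then required attributes, then optional tensor data, then optional attributes. A hypothetical construction call under the new ordering (the argument values are placeholders; only the positions are the point):

```python
# Placeholder arguments, shown only to illustrate the new ordering.
t = NVFP4Tensor(
    qdata,                 # required tensor data
    blockwise_scales,      # required tensor data (_scale_e4m3)
    block_size,            # required attribute
    orig_dtype,            # required attribute
    per_tensor_scale,      # optional tensor data
    act_per_tensor_scale,  # optional tensor data
    is_swizzled_scales=False,  # optional attributes follow
    use_triton_kernel=False,
    act_quant_kwargs=None,
)
```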

torchao/quantization/autoquant.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -344,7 +344,7 @@ def do_autoquant_bench(op, *args, **kwargs):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            op(*args, **kwargs)
-        if torch_version_at_least("2.8.0"):
+        if torch_version_at_least("2.9.0.dev"):
            from statistics import median
 
            res = benchmarker.benchmark_gpu(
```
