Add 16A8W support and test for add operation #13568

Open · wants to merge 3 commits into base: main

2 changes: 1 addition & 1 deletion backends/arm/operators/op_add.py
@@ -50,7 +50,7 @@ def define_node(
        validate_valid_dtype(
            self.target,
            [*inputs, output],
-            [ts.DType.INT8, ts.DType.INT32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
            output.tosa_spec,
        )

2 changes: 1 addition & 1 deletion backends/arm/operators/op_view.py
@@ -44,7 +44,7 @@ def define_node(
        validate_valid_dtype(
            self.target,
            [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
            output.tosa_spec,
        )

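For context, the check both hunks extend validates tensor dtypes against the operator's allow-list for the active TOSA spec. The stand-in below is inferred purely from the call sites and is only a sketch; the real helper in backends/arm/operators may differ.

# Hypothetical stand-in for validate_valid_dtype, inferred from the call
# sites above; illustration only, not the actual ExecuTorch helper.
def validate_valid_dtype(target, tensors, valid_dtypes, tosa_spec):
    for tensor in tensors:
        if tensor.dtype not in valid_dtypes:
            raise ValueError(
                f"{target}: dtype {tensor.dtype} is not supported under "
                f"{tosa_spec}; expected one of {valid_dtypes}"
            )
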
80 changes: 80 additions & 0 deletions backends/arm/quantizer/arm_quantizer.py
@@ -145,6 +145,86 @@ def get_symmetric_quantization_config(
    return quantization_config


@functools.lru_cache
def get_symmetric_a16w8_quantization_config(
    is_per_channel: bool = True,
    is_qat: bool = False,
    is_dynamic: bool = False,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
):
    """
    16A8W quantization config: 16-bit activations, 8-bit weights.

    This configuration provides better accuracy than 8A8W while maintaining
    reasonable memory usage through 8-bit weights.

    Args:
        is_per_channel: Whether to use per-channel quantization for weights
        is_qat: Whether this is for Quantization Aware Training
        is_dynamic: Whether to use dynamic quantization
        weight_qmin: Minimum quantization value for weights
        weight_qmax: Maximum quantization value for weights

    Returns:
        QuantizationConfig with 16-bit activations and 8-bit weights
    """
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    # Set up the observer/fake-quant constructor for 16-bit activations
    if is_qat:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = FakeQuantize
            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
                averaging_constant=1
            )
            extra_args["observer"] = dynamic_quant_observer
        else:
            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            # HistogramObserver works well for the 16-bit range
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    # 16-bit activation quantization spec
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int16,
        quant_min=torch.iinfo(torch.int16).min,  # -32768
        quant_max=torch.iinfo(torch.int16).max,  # 32767
        qscheme=torch.per_tensor_symmetric,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args,
        ),
    )

    # Rather than rebuilding the whole config, clone the one returned by
    # get_symmetric_quantization_config and swap in the 16-bit activation spec.
    base_config = get_symmetric_quantization_config(
        is_per_channel=is_per_channel,
        is_qat=is_qat,
        is_dynamic=is_dynamic,
    )
    if is_dynamic:
        quantization_config = QuantizationConfig(
            act_quantization_spec,  # 16-bit input activations
            None,
            base_config.weight,  # 8-bit weights from the base config
            None,
        )
    else:
        quantization_config = QuantizationConfig(
            act_quantization_spec,  # 16-bit input activations
            act_quantization_spec,  # 16-bit output activations
            base_config.weight,  # 8-bit weights from the base config
            None,
        )
    return quantization_config


NodeFilterType = Callable[[Node], bool]
"""Type for a Node Filter used by annotators. A Node filter is a function that takes
a Node and returns whether the node should be annotated or not.

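To put the docstring's accuracy/memory claim in concrete terms: under symmetric quantization, the step size is the observed range divided by the integer maximum. A back-of-envelope sketch follows; the helper name and the example range are illustrative, not part of this PR.

import torch

def symmetric_scale(observed_absmax: float, qmax: int) -> float:
    # Symmetric per-tensor quantization maps [-absmax, absmax] onto
    # [-qmax, qmax], so the step size (scale) is absmax / qmax.
    return observed_absmax / qmax

absmax = 6.0  # hypothetical activation range reported by an observer
int16_scale = symmetric_scale(absmax, torch.iinfo(torch.int16).max)  # ~1.8e-4
int8_scale = symmetric_scale(absmax, 127)                            # ~4.7e-2

# int16 activations resolve ~258x finer than int8 over the same range,
# while weights stay int8, so weight storage matches the 8A8W config.
print(int16_scale, int8_scale)
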
47 changes: 47 additions & 0 deletions backends/arm/test/ops/test_add.py
@@ -10,6 +10,10 @@
import pytest
import torch
from executorch.backends.arm.quantizer import arm_quantizer
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineINT,

@@ -216,3 +220,46 @@ def test_add_tensor_vgf_INT(test_data: input_t1):
        tosa_version="TOSA-1.0+INT",
    )
    pipeline.run()


def get_symmetric_a16w8_add_quantizer(
    u55_config=False, per_channel_quantization=False
):
    tosa_version = conftest.get_option("tosa_version")
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
    quantizer.set_global(
        get_symmetric_a16w8_quantization_config(
            is_per_channel=per_channel_quantization
        )
    )

    return Quantize(
        quantizer,
        get_symmetric_a16w8_quantization_config(
            is_per_channel=per_channel_quantization
        ),
    )


@common.parametrize("test_data", Add.test_data)
def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
    """Test add operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
    per_channel_quantization = False

    pipeline = TosaPipelineINT[input_t1](
        Add(),
        test_data(),
        aten_op,
        exir_op=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    pipeline.change_args(
        "quantize",
        get_symmetric_a16w8_add_quantizer(
            per_channel_quantization=per_channel_quantization
        ),
    )
    pipeline.run()

58 changes: 56 additions & 2 deletions backends/arm/test/ops/test_linear.py
@@ -11,7 +11,11 @@
import pytest

import torch
-from executorch.backends.arm.test import common
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest

from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineINT,

@@ -20,6 +24,8 @@
    TosaPipelineINT,
    VgfPipeline,
)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize

aten_op = "torch.ops.aten.linear.default"

@@ -143,7 +149,6 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
    pipeline.run()


-@pytest.mark.flaky(reruns=5) # TODO: Investigate flakyness.
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
def test_linear_tosa_INT(test_data: torch.Tensor):
    test_data, out_features, has_bias, per_channel_quantization = test_data()

@@ -258,3 +263,52 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
        per_channel_quantization=per_channel_quantization,
    )
    pipeline.run()


def get_symmetric_a16w8_linear_quantizer(
    u55_config=False, per_channel_quantization=False
):
    tosa_version = conftest.get_option("tosa_version")
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
    quantizer.set_global(
        get_symmetric_a16w8_quantization_config(
            is_per_channel=per_channel_quantization
        )
    )
    quantizer.set_module_type(
        torch.nn.Linear,
        get_symmetric_a16w8_quantization_config(
            is_per_channel=per_channel_quantization
        ),
    )

    return Quantize(
        quantizer,
        get_symmetric_a16w8_quantization_config(
            is_per_channel=per_channel_quantization
        ),
    )


@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
    test_data, out_features, has_bias, per_channel_quantization = test_data()
    in_features = test_data.shape[-1]

    # Create pipeline with custom 16A8W quantization config
    pipeline = TosaPipelineINT[input_t1](
        Linear(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias,
        ),
        (test_data,),
        aten_op,
        exir_op=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    pipeline.change_args(
        "quantize",
        get_symmetric_a16w8_linear_quantizer(
            per_channel_quantization=per_channel_quantization
        ),
    )
    pipeline.run()

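The tests above exercise the new config through the Arm test pipeline; outside the harness it would plug into the standard PT2E quantization flow, roughly as sketched below. The entry points (export_for_training, prepare_pt2e, convert_pt2e) have moved between PyTorch releases, so treat the exact names as assumptions rather than the PR's prescribed usage.

# Rough sketch of the usual PT2E flow this config is consumed by;
# API entry points are assumptions, verify against your torch version.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa_specification import TosaSpecification

model = torch.nn.Linear(16, 8).eval()
sample = (torch.randn(1, 16),)

quantizer = TOSAQuantizer(
    TosaSpecification.create_from_string("TOSA-1.0+INT+int16")
)
quantizer.set_global(get_symmetric_a16w8_quantization_config())

# Export, insert observers, run calibration data, then convert.
graph = torch.export.export_for_training(model, sample).module()
prepared = prepare_pt2e(graph, quantizer)
prepared(*sample)  # calibration pass
quantized = convert_pt2e(prepared)
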
1 change: 1 addition & 0 deletions backends/arm/test/targets.bzl
@@ -13,6 +13,7 @@ def define_arm_tests():

    # Operators
    test_files += [
        "ops/test_add.py",
        "ops/test_avg_pool2d.py",
        "ops/test_linear.py",
        "ops/test_slice.py",