diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py
index 7a022b54395..2165adf49ed 100644
--- a/backends/arm/operators/op_add.py
+++ b/backends/arm/operators/op_add.py
@@ -50,7 +50,7 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            [ts.DType.INT8, ts.DType.INT32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
             output.tosa_spec,
         )
 
diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py
index 1e8c06b691f..01e791e7324 100644
--- a/backends/arm/operators/op_view.py
+++ b/backends/arm/operators/op_view.py
@@ -44,7 +44,13 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
+            [
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.INT32,
+                ts.DType.FP32,
+                ts.DType.BOOL,
+            ],
             output.tosa_spec,
         )
 
diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py
index 9fa15568cc4..e60cb667c13 100644
--- a/backends/arm/quantizer/arm_quantizer.py
+++ b/backends/arm/quantizer/arm_quantizer.py
@@ -145,6 +145,96 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
+@functools.lru_cache
+def get_symmetric_a16w8_quantization_config(
+    is_per_channel: bool = True,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    """
+    16A8W quantization config: 16-bit activations, 8-bit weights.
+
+    This configuration provides better accuracy than 8A8W while maintaining
+    reasonable memory usage through 8-bit weights.
+
+    Args:
+        is_per_channel: Whether to use per-channel quantization for weights
+        is_qat: Whether this is for Quantization Aware Training
+        is_dynamic: Whether to use dynamic quantization
+        weight_qmin: Minimum quantization value for weights
+        weight_qmax: Maximum quantization value for weights
+
+    Returns:
+        QuantizationConfig with 16-bit activations and 8-bit weights
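+
+    Example (illustrative sketch; mirrors the usage in the op tests and
+    assumes a TOSA spec with the int16 extension):
+
+        quantizer = TOSAQuantizer(
+            TosaSpecification.create_from_string("TOSA-1.0+INT+int16")
+        )
+        quantizer.set_global(get_symmetric_a16w8_quantization_config())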
+    """
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    # Set up the observer/fake-quantize constructor for 16-bit activations
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            # HistogramObserver works well for the 16-bit range
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # 16-bit activation quantization spec
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int16,
+        quant_min=torch.iinfo(torch.int16).min,  # -32768
+        quant_max=torch.iinfo(torch.int16).max,  # 32767
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+
+    # Rather than rebuilding the config from scratch, clone the base 8A8W
+    # config from get_symmetric_quantization_config and update the activation specs
+    base_config = get_symmetric_quantization_config(
+        is_per_channel=is_per_channel,
+        is_qat=is_qat,
+        is_dynamic=is_dynamic,
+        weight_qmin=weight_qmin,
+        weight_qmax=weight_qmax,
+    )
+    # Replace the activation quantization specs with the 16-bit version
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            None,
+            base_config.weight,  # 8-bit weights from the base config
+            None,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,  # 16-bit input activations
+            act_quantization_spec,  # 16-bit output activations
+            base_config.weight,  # 8-bit weights from the base config
+            None,
+        )
+    return quantization_config
+
+
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function
 that takes a Node and returns whether the node should be annotated or not.
diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py
index 6bf3830d038..25dc8c27ce8 100644
--- a/backends/arm/test/ops/test_add.py
+++ b/backends/arm/test/ops/test_add.py
@@ -10,6 +10,12 @@
 import pytest
 import torch
 from executorch.backends.arm.quantizer import arm_quantizer
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
 from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -216,3 +222,53 @@ def test_add_tensor_vgf_INT(test_data: input_t1):
         tosa_version="TOSA-1.0+INT",
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_add_quantizer(
+    u55_config=False, per_channel_quantization=False
+):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        )
+    )
+
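+    # Return the config via the XNNPACK tester's Quantize stage so that
+    # pipeline.change_args("quantize", ...) can swap it in for the pipeline's
+    # default quantize step (see the test below).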
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", Add.test_data)
+def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
+    """Test add operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Add(),
+        test_data(),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_add_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py
index 57ce490dae8..1d7981c3155 100644
--- a/backends/arm/test/ops/test_linear.py
+++ b/backends/arm/test/ops/test_linear.py
@@ -11,7 +11,11 @@
 
 import pytest
 
 import torch
-from executorch.backends.arm.test import common
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -20,6 +24,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.linear.default"
 
@@ -143,7 +149,6 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
     pipeline.run()
 
 
-@pytest.mark.flaky(reruns=5)  # TODO: Investigate flakyness.
 @common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
 def test_linear_tosa_INT(test_data: torch.Tensor):
     test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -258,3 +263,62 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
         per_channel_quantization=per_channel_quantization,
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_linear_quantizer(
+    u55_config=False, per_channel_quantization=False
+):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        )
+    )
+    quantizer.set_module_type(
+        torch.nn.Linear,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
+def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
+    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    test_data, out_features, has_bias, per_channel_quantization = test_data()
+    in_features = test_data.shape[-1]
+
+    # Create pipeline with custom 16A8W quantization config
+    pipeline = TosaPipelineINT[input_t1](
+        Linear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=has_bias,
+        ),
+        (test_data,),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_linear_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl
index acb27f13798..405f1bbf081 100644
--- a/backends/arm/test/targets.bzl
+++ b/backends/arm/test/targets.bzl
@@ -13,6 +13,7 @@ def define_arm_tests():
 
     # Operators
     test_files += [
+        "ops/test_add.py",
        "ops/test_avg_pool2d.py",
        "ops/test_linear.py",
        "ops/test_slice.py",