From acc6f6da8148237c23706fc3866c559d474ce81d Mon Sep 17 00:00:00 2001
From: winskuo-quic
Date: Mon, 14 Oct 2024 18:03:57 +0800
Subject: [PATCH] Observer Fix and remove unused passes

---
Review notes (this region is ignored by `git am`):
 * AnnotateQuantAttrs.__init__ takes `skip_advanced_requat` (sic) but stores
   it as `self.skip_advanced_requant`; the only call site in this patch passes
   it positionally, so the spelling mismatch is cosmetic here.
 * The skip flag is driven by key *presence*:
   `QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config`; the value stored
   under that key is not consulted by this patch.
 * The non-skip path indexes q_attrs/dq_attrs with "scale" and "zero_point";
   presumably both dicts are per-tensor encodings at this point - TODO confirm
   that per-channel attrs can never reach this comparison.
 * quantizer/utils.py starts using MovingAverageMinMaxObserver without adding
   an import in this patch; presumably it is already imported there - confirm.
 * Context-line indentation was re-derived from Python syntax; verify the
   hunks against the tree before applying.

 .../qualcomm/_passes/annotate_quant_attrs.py  | 28 ++++++++++++++++---
 backends/qualcomm/qnn_preprocess.py           |  2 --
 backends/qualcomm/quantizer/utils.py          |  2 +-
 backends/qualcomm/utils/constants.py          |  1 +
 backends/qualcomm/utils/utils.py              |  5 +++-
 5 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py
index 0dc39d2a4de..632e67569f7 100644
--- a/backends/qualcomm/_passes/annotate_quant_attrs.py
+++ b/backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -27,9 +27,12 @@ class AnnotateQuantAttrs(ExportPass):
     generated after quatization process.
     """
 
-    def __init__(self, edge_program: torch.export.ExportedProgram):
+    def __init__(
+        self, edge_program: torch.export.ExportedProgram, skip_advanced_requat: bool
+    ):
         super(AnnotateQuantAttrs, self).__init__()
         self.edge_program = edge_program
+        self.skip_advanced_requant = skip_advanced_requat
 
     def _annotate_source_nodes(
         self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
@@ -68,9 +71,26 @@ def _annotate_requant(self, n):
 
             # TODO: Store multiple pairs of requantize attributes when we have an op builder
             # that has multiple outputs that requires quant attributes.
-            if q_attrs["dtype"] != dq_attrs["dtype"]:
-                dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
-                n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
+            if self.skip_advanced_requant:
+                if q_attrs["dtype"] != dq_attrs["dtype"]:
+                    dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
+                    n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
+            else:
+                # When dtype is the same but other specs such as scale and offset are different,
+                # insert requant to improve accuracy.
+                # Users can turn this feature off if any inference speed drop is observed.
+                if any(
+                    q_attrs[attr] != dq_attrs[attr]
+                    for attr in [
+                        "scale",
+                        "zero_point",
+                        "quant_min",
+                        "quant_max",
+                        "dtype",
+                    ]
+                ):
+                    dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
+                    n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
 
     # Dequant all the fold_quant parameters back to fp32.
     # If an operation is not supported by QNN and got fallback, it will expect a fp32 param.
diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py
index 417934acbd4..f13d3fb55ae 100644
--- a/backends/qualcomm/qnn_preprocess.py
+++ b/backends/qualcomm/qnn_preprocess.py
@@ -11,7 +11,6 @@
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
 
 import torch  # noqa: F401
-from executorch.backends.qualcomm._passes.convert_to_linear import ConvertToLinear
 from executorch.backends.qualcomm._passes.fuse_consecutive_transpose import (
     FuseConsecutiveTranspose,
 )
@@ -49,7 +48,6 @@ def preprocess(
         # QNN Delegate Specific Passes
         qnn_compiler_passes = PassManager(
             passes=[
-                ConvertToLinear(),
                 InsertRequantize(edge_program),
                 InsertIOQDQ(edge_program),
                 LayoutTransform(edge_program, insert_permute=True),
diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py
index d1ea35fa190..46a048c36b7 100644
--- a/backends/qualcomm/quantizer/utils.py
+++ b/backends/qualcomm/quantizer/utils.py
@@ -364,7 +364,7 @@ def get_ptq_per_channel_quant_config(
         quant_min=torch.iinfo(act_dtype).min,
         quant_max=torch.iinfo(act_dtype).max,
         qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
+        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
     )
 
     weight_quantization_spec = QuantizationSpec(
diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py
index 8a37b2bd8ca..c54770e5423 100644
--- a/backends/qualcomm/utils/constants.py
+++ b/backends/qualcomm/utils/constants.py
@@ -26,6 +26,7 @@
 QCOM_ZERO_POINT = "zero_point"
 QCOM_ZERO_POINTS = "zero_points"
 QCOM_PASS_EXPAND_BROADCAST_SHAPE = "expand_broadcast_shape"
+QCOM_PASS_SKIP_ADVANCED_REQUANT = "skip_advanced_requant"
 
 # constants in backends/qualcomm/tests
 QCOM_ANNOTATION = "annotation"
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index d93f7fcb4bc..298664e2c96 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -69,6 +69,7 @@
 )
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_PASS_EXPAND_BROADCAST_SHAPE,
+    QCOM_PASS_SKIP_ADVANCED_REQUANT,
     QCOM_QNN_COMPILE_SPEC,
 )
 
@@ -305,7 +306,9 @@ def _transform(
     ConvertBmmToMatmul()(graph_module)
     ConvertInterpolateWithUpsample2D()(graph_module)
     I64toI32(edge_program)(graph_module)
-    AnnotateQuantAttrs(edge_program)(graph_module)
+    AnnotateQuantAttrs(
+        edge_program, QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config
+    )(graph_module)
     AnnotateAndQuantScalar(edge_program)(graph_module)
     AnnotateDecomposed(edge_program)(graph_module)
     FoldQDQ()(graph_module)