
Qualcomm AI Engine Direct - Observer Fix and remove unused passes #6225

Closed
backends/qualcomm/_passes/annotate_quant_attrs.py (24 additions & 4 deletions)

```diff
@@ -27,9 +27,12 @@ class AnnotateQuantAttrs(ExportPass):
     generated after quantization process.
     """

-    def __init__(self, edge_program: torch.export.ExportedProgram):
+    def __init__(
+        self, edge_program: torch.export.ExportedProgram, skip_advanced_requant: bool
+    ):
         super(AnnotateQuantAttrs, self).__init__()
         self.edge_program = edge_program
+        self.skip_advanced_requant = skip_advanced_requant

     def _annotate_source_nodes(
         self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
@@ -68,9 +71,26 @@ def _annotate_requant(self, n):

         # TODO: Store multiple pairs of requantize attributes when we have an op builder
         # that has multiple outputs that require quant attributes.
-        if q_attrs["dtype"] != dq_attrs["dtype"]:
-            dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
-            n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
+        if self.skip_advanced_requant:
+            if q_attrs["dtype"] != dq_attrs["dtype"]:
+                dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
+                n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
+        else:
+            # When the dtype is the same but other specs such as scale and offset differ,
+            # insert a requantize op to improve accuracy.
+            # Users can turn this feature off if any inference speed drop is observed.
+            if any(
+                q_attrs[attr] != dq_attrs[attr]
+                for attr in [
+                    "scale",
+                    "zero_point",
+                    "quant_min",
+                    "quant_max",
+                    "dtype",
+                ]
+            ):
+                dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
+                n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs

         # Dequantize all the fold_quant parameters back to fp32.
         # If an operation is not supported by QNN and falls back, it will expect an fp32 param.
```
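The behavioral change in `_annotate_requant` is easiest to see with concrete numbers. Below is a minimal, self-contained sketch, not code from this PR: the attribute dicts hold hypothetical values, and `needs_requant` simply distills the two branches above into a standalone function.

```python
import torch

# Hypothetical quant/dequant attribute pairs: same dtype, different scale/zero_point.
q_attrs = {"scale": 0.02, "zero_point": 128, "quant_min": 0, "quant_max": 255, "dtype": torch.uint8}
dq_attrs = {"scale": 0.05, "zero_point": 120, "quant_min": 0, "quant_max": 255, "dtype": torch.uint8}

def needs_requant(q_attrs, dq_attrs, skip_advanced_requant):
    if skip_advanced_requant:
        # Legacy behavior: requantize only across differing dtypes.
        return q_attrs["dtype"] != dq_attrs["dtype"]
    # New default: any mismatched quantization spec triggers a requantize.
    keys = ["scale", "zero_point", "quant_min", "quant_max", "dtype"]
    return any(q_attrs[k] != dq_attrs[k] for k in keys)

print(needs_requant(q_attrs, dq_attrs, skip_advanced_requant=True))   # False: dtypes match
print(needs_requant(q_attrs, dq_attrs, skip_advanced_requant=False))  # True: scale/zero_point differ
```

In other words, with the flag set the pass reproduces the old dtype-only check; by default it now also requantizes when scale, zero point, or quant range disagree, trading a possible inference-speed drop for accuracy.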
backends/qualcomm/qnn_preprocess.py (0 additions & 2 deletions)

```diff
@@ -11,7 +11,6 @@
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

 import torch  # noqa: F401
-from executorch.backends.qualcomm._passes.convert_to_linear import ConvertToLinear
 from executorch.backends.qualcomm._passes.fuse_consecutive_transpose import (
     FuseConsecutiveTranspose,
 )
@@ -49,7 +48,6 @@ def preprocess(
         # QNN Delegate Specific Passes
         qnn_compiler_passes = PassManager(
             passes=[
-                ConvertToLinear(),
                 InsertRequantize(edge_program),
                 InsertIOQDQ(edge_program),
                 LayoutTransform(edge_program, insert_permute=True),
```
backends/qualcomm/quantizer/utils.py (1 addition & 1 deletion)

```diff
@@ -364,7 +364,7 @@ def get_ptq_per_channel_quant_config(
         quant_min=torch.iinfo(act_dtype).min,
         quant_max=torch.iinfo(act_dtype).max,
         qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
+        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
     )

     weight_quantization_spec = QuantizationSpec(
```
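This observer swap is the "Observer Fix" in the PR title: for the activation spec of the per-channel PTQ config, range estimation moves from a global min/max to an exponential moving average. A standalone sketch of the difference follows; the calibration loop and synthetic data are illustrative assumptions, not taken from the PR.

```python
# MinMaxObserver remembers the global min/max over every calibration batch, so
# one early outlier permanently widens the quantization range. In contrast,
# MovingAverageMinMaxObserver updates the range with an exponential moving
# average, so the outlier's influence decays as calibration proceeds.
import torch
from torch.ao.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver

minmax = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
moving = MovingAverageMinMaxObserver(
    averaging_constant=0.01, dtype=torch.quint8, qscheme=torch.per_tensor_affine
)

torch.manual_seed(0)
for step in range(300):
    batch = torch.randn(32, 16)
    if step == 0:
        batch[0, 0] = 50.0  # a single early outlier activation
    minmax(batch)  # observers record statistics in forward()
    moving(batch)

# MinMaxObserver's scale stays stretched by the outlier; the moving-average
# range has drifted back toward the typical batch statistics.
print(minmax.calculate_qparams())
print(moving.calculate_qparams())
```

Whether the moving average is preferable depends on how representative the calibration batches are; the smoothing mainly protects the quantization range from outlier batches.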
backends/qualcomm/utils/constants.py (1 addition & 0 deletions)

```diff
@@ -26,6 +26,7 @@
 QCOM_ZERO_POINT = "zero_point"
 QCOM_ZERO_POINTS = "zero_points"
 QCOM_PASS_EXPAND_BROADCAST_SHAPE = "expand_broadcast_shape"
+QCOM_PASS_SKIP_ADVANCED_REQUANT = "skip_advanced_requant"

 # constants in backends/qualcomm/tests
 QCOM_ANNOTATION = "annotation"
```
backends/qualcomm/utils/utils.py (4 additions & 1 deletion)

```diff
@@ -69,6 +69,7 @@
 )
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_PASS_EXPAND_BROADCAST_SHAPE,
+    QCOM_PASS_SKIP_ADVANCED_REQUANT,
     QCOM_QNN_COMPILE_SPEC,
 )

@@ -305,7 +306,9 @@ def _transform(
     ConvertBmmToMatmul()(graph_module)
     ConvertInterpolateWithUpsample2D()(graph_module)
     I64toI32(edge_program)(graph_module)
-    AnnotateQuantAttrs(edge_program)(graph_module)
+    AnnotateQuantAttrs(
+        edge_program, QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config
+    )(graph_module)
     AnnotateAndQuantScalar(edge_program)(graph_module)
     AnnotateDecomposed(edge_program)(graph_module)
     FoldQDQ()(graph_module)
```
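Taken together, the new constant gives callers a switch for the requantization behavior. A hedged usage sketch follows: only the constant and the `in custom_pass_config` membership test appear in this diff, so `capture_program` (the entry point that forwards `custom_pass_config` to `_transform`), `TinyModel`, and the example inputs are assumptions for illustration.

```python
# Hedged usage sketch: opting back into the legacy dtype-only requant check by
# putting QCOM_PASS_SKIP_ADVANCED_REQUANT into the custom pass config. The
# capture_program signature is an assumption, not shown in this diff.
import torch

from executorch.backends.qualcomm.utils.constants import QCOM_PASS_SKIP_ADVANCED_REQUANT
from executorch.backends.qualcomm.utils.utils import capture_program

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

model = TinyModel()
example_inputs = (torch.randn(1, 8),)

edge_prog = capture_program(
    model,
    example_inputs,
    custom_pass_config=frozenset({QCOM_PASS_SKIP_ADVANCED_REQUANT}),
)
```

Omitting the constant keeps the new default, where AnnotateQuantAttrs also requantizes on scale, zero-point, or quant-range mismatches.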