From f47b773f91ad6037d1296f82503ee49c2ebb6ec6 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 23 Sep 2024 18:51:38 -0700 Subject: [PATCH] Add temporary workaround to recover the perf for quantized vit under torch.compile Summary: Recently we found a perf drop in quantized vit due to https://github.com/pytorch/ao/issues/898#issuecomment-2364540055 This PR adds a temporary fix until we figure out the longer-term fix. I think ideally we should figure out why the tensor subclass check failed in torch.compile (https://github.com/pytorch/pytorch/blob/e4d294221b140fdbb49a64f297bc60c9fcc2f80e/torch/nn/modules/activation.py#L1286) and fix that. Test Plan: python tutorials/quantize_vit/run_vit_b_quant.py Reviewers: Subscribers: Tasks: Tags: --- tutorials/quantize_vit/run_vit_b_quant.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py index 5c30762099..06113bcd68 100644 --- a/tutorials/quantize_vit/run_vit_b_quant.py +++ b/tutorials/quantize_vit/run_vit_b_quant.py @@ -36,6 +36,9 @@ if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) +# temporary workaround to recover the perf with quantized model under torch.compile +torch.backends.mha.set_fastpath_enabled(False) + model = torch.compile(model, mode='max-autotune') # Must run with no_grad when optimizing for inference