From 308af038f171c34e2c05eb93e158fe3eeb1b876e Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 3 Jun 2024 17:52:04 -0700 Subject: [PATCH] Update old pretrained TorchVision API in ao tutorials (#313) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/313 For TorchVision models, pretrained parameters have been deprecated in favor of "Multi-weight support API" - see https://pytorch.org/vision/0.15/models.html Differential Revision: D58117114 --- tutorials/quantize_vit/run_vit_b.py | 5 +++-- tutorials/quantize_vit/run_vit_b_quant.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tutorials/quantize_vit/run_vit_b.py b/tutorials/quantize_vit/run_vit_b.py index a7fd78f9b2..dae7dde703 100644 --- a/tutorials/quantize_vit/run_vit_b.py +++ b/tutorials/quantize_vit/run_vit_b.py @@ -1,10 +1,11 @@ import torch -import torchvision.models.vision_transformer as models from torchao.utils import benchmark_model, profiler_runner +from torchvision import models + torch.set_float32_matmul_precision("high") # Load Vision Transformer model -model = models.vit_b_16(pretrained=True) +model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1) # Set the model to evaluation mode model.eval().cuda().to(torch.bfloat16) diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py index 0396a9dffd..8239c82423 100644 --- a/tutorials/quantize_vit/run_vit_b_quant.py +++ b/tutorials/quantize_vit/run_vit_b_quant.py @@ -1,11 +1,12 @@ import torch import torchao -import torchvision.models.vision_transformer as models from torchao.utils import benchmark_model, profiler_runner +from torchvision import models + torch.set_float32_matmul_precision("high") # Load Vision Transformer model -model = models.vit_b_16(pretrained=True) +model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1) # Set the model to evaluation mode model.eval().cuda().to(torch.bfloat16)