diff --git a/tutorials/quantize_vit/run_vit_b.py b/tutorials/quantize_vit/run_vit_b.py
index a7fd78f9b2..dae7dde703 100644
--- a/tutorials/quantize_vit/run_vit_b.py
+++ b/tutorials/quantize_vit/run_vit_b.py
@@ -1,10 +1,11 @@
 import torch
-import torchvision.models.vision_transformer as models
 from torchao.utils import benchmark_model, profiler_runner
+from torchvision import models
+
 torch.set_float32_matmul_precision("high")
 
 # Load Vision Transformer model
-model = models.vit_b_16(pretrained=True)
+model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
 
 # Set the model to evaluation mode
 model.eval().cuda().to(torch.bfloat16)
diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py
index 0396a9dffd..8239c82423 100644
--- a/tutorials/quantize_vit/run_vit_b_quant.py
+++ b/tutorials/quantize_vit/run_vit_b_quant.py
@@ -1,11 +1,12 @@
 import torch
 import torchao
-import torchvision.models.vision_transformer as models
 from torchao.utils import benchmark_model, profiler_runner
+from torchvision import models
+
 torch.set_float32_matmul_precision("high")
 
 # Load Vision Transformer model
-model = models.vit_b_16(pretrained=True)
+model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
 
 # Set the model to evaluation mode
 model.eval().cuda().to(torch.bfloat16)
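
This diff migrates the tutorials off the `pretrained=True` argument, which torchvision deprecated in v0.13 in favor of an explicit weights enum; `ViT_B_16_Weights.IMAGENET1K_V1` loads the same ImageNet checkpoint that `pretrained=True` used to fetch. As a minimal sketch of the updated loading path (the CPU smoke test below is illustrative and not part of the patched tutorials, which run on CUDA in bfloat16):

```python
import torch
from torchvision import models

# Explicit weights enum replaces the deprecated pretrained=True;
# IMAGENET1K_V1 is the checkpoint pretrained=True previously loaded.
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()

# Quick sanity check: ViT-B/16 expects 224x224 inputs and, with
# IMAGENET1K_V1 weights, produces 1000 ImageNet class logits.
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```

Switching to `from torchvision import models` also keeps the import on the public API surface rather than the `torchvision.models.vision_transformer` submodule, while leaving every other use of the `models` identifier in the scripts unchanged.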