Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def verify(cls, obj: Any) -> Any:
)
return obj

def get_state_dict(self, progress: bool) -> Mapping[str, Any]:
return load_state_dict_from_url(self.url, progress=progress)
def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
return load_state_dict_from_url(self.url, *args, **kwargs)

def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self._name_}"
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,6 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
model = AlexNet(**kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
2 changes: 1 addition & 1 deletion torchvision/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _convnext(
model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) ->
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

state_dict = weights.get_state_dict(progress=progress)
state_dict = weights.get_state_dict(progress=progress, check_hash=True)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
Expand Down
6 changes: 3 additions & 3 deletions torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def fasterrcnn_resnet50_fpn(
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

Expand Down Expand Up @@ -653,7 +653,7 @@ def fasterrcnn_resnet50_fpn_v2(
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down Expand Up @@ -694,7 +694,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/fcos.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,6 @@ def fcos_resnet50_fpn(
model = FCOS(backbone, num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
2 changes: 1 addition & 1 deletion torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def keypointrcnn_resnet50_fpn(
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def maskrcnn_resnet50_fpn(
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

Expand Down Expand Up @@ -582,6 +582,6 @@ def maskrcnn_resnet50_fpn_v2(
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
4 changes: 2 additions & 2 deletions torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def retinanet_resnet50_fpn(
model = RetinaNet(backbone, num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

Expand Down Expand Up @@ -894,6 +894,6 @@ def retinanet_resnet50_fpn_v2(
model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
2 changes: 1 addition & 1 deletion torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,6 @@ def ssd300_vgg16(
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
2 changes: 1 addition & 1 deletion torchvision/models/detection/ssdlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,6 @@ def ssdlite320_mobilenet_v3_large(
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
2 changes: 1 addition & 1 deletion torchvision/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def _efficientnet(
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T
model = GoogLeNet(**kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo
model = Inception3(**kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/maxvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ def _maxvit(
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/mnasnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa
model = MNASNet(alpha, **kwargs)

if weights:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,6 @@ def mobilenet_v2(
model = MobileNetV2(**kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
2 changes: 1 addition & 1 deletion torchvision/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def _mobilenet_v3(
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ def _raft(
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/quantization/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def googlenet(
quantize_model(model, backend)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/quantization/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def inception_v3(
if quantize and not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not quantize and not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/quantization/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,6 @@ def mobilenet_v2(
quantize_model(model, backend)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
2 changes: 1 addition & 1 deletion torchvision/models/quantization/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _mobilenet_v3_model(
torch.ao.quantization.prepare_qat(model, inplace=True)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

if quantize:
torch.ao.quantization.convert(model, inplace=True)
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/quantization/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _resnet(
quantize_model(model, backend)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/quantization/shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _shufflenetv2(
quantize_model(model, backend)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _regnet(
model = RegNet(block_params, norm_layer=norm_layer, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _resnet(
model = ResNet(block, layers, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
6 changes: 3 additions & 3 deletions torchvision/models/segmentation/deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def deeplabv3_resnet50(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down Expand Up @@ -331,7 +331,7 @@ def deeplabv3_resnet101(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down Expand Up @@ -385,6 +385,6 @@ def deeplabv3_mobilenet_v3_large(
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
4 changes: 2 additions & 2 deletions torchvision/models/segmentation/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def fcn_resnet50(
model = _fcn_resnet(backbone, num_classes, aux_loss)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down Expand Up @@ -227,6 +227,6 @@ def fcn_resnet101(
model = _fcn_resnet(backbone, num_classes, aux_loss)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
2 changes: 1 addition & 1 deletion torchvision/models/segmentation/lraspp.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,6 @@ def lraspp_mobilenet_v3_large(
model = _lraspp_mobilenetv3(backbone, num_classes)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
2 changes: 1 addition & 1 deletion torchvision/models/shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _shufflenetv2(
model = ShuffleNetV2(*args, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/squeezenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _squeezenet(
model = SqueezeNet(version, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def _swin_transformer(
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model


Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/video/mvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def _mvit(
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/video/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _video_resnet(
model = VideoResNet(block, conv_makers, layers, stem, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/video/s3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,6 @@ def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwarg
model = S3D(**kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model
2 changes: 1 addition & 1 deletion torchvision/models/video/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def _swin_transformer3d(
)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model

Expand Down
Loading