Skip to content

Commit 139e53a

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Check sha256 of weights (#7219)
Summary: Co-authored-by: Nicolas Hug <[email protected]> Reviewed By: vmoens Differential Revision: D45523936 fbshipit-source-id: 3febf3a0f410cc1af38cfac91d18c7e83213bd4f
1 parent c5fcd2f commit 139e53a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+49
-49
lines changed

torchvision/models/_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def verify(cls, obj: Any) -> Any:
8585
)
8686
return obj
8787

88-
def get_state_dict(self, progress: bool) -> Mapping[str, Any]:
89-
return load_state_dict_from_url(self.url, progress=progress)
88+
def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
89+
return load_state_dict_from_url(self.url, *args, **kwargs)
9090

9191
def __repr__(self) -> str:
9292
return f"{self.__class__.__name__}.{self._name_}"

torchvision/models/alexnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,6 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
114114
model = AlexNet(**kwargs)
115115

116116
if weights is not None:
117-
model.load_state_dict(weights.get_state_dict(progress=progress))
117+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
118118

119119
return model

torchvision/models/convnext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _convnext(
189189
model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
190190

191191
if weights is not None:
192-
model.load_state_dict(weights.get_state_dict(progress=progress))
192+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
193193

194194
return model
195195

torchvision/models/densenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) ->
227227
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
228228
)
229229

230-
state_dict = weights.get_state_dict(progress=progress)
230+
state_dict = weights.get_state_dict(progress=progress, check_hash=True)
231231
for key in list(state_dict.keys()):
232232
res = pattern.match(key)
233233
if res:

torchvision/models/detection/faster_rcnn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def fasterrcnn_resnet50_fpn(
571571
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
572572

573573
if weights is not None:
574-
model.load_state_dict(weights.get_state_dict(progress=progress))
574+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
575575
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
576576
overwrite_eps(model, 0.0)
577577

@@ -653,7 +653,7 @@ def fasterrcnn_resnet50_fpn_v2(
653653
)
654654

655655
if weights is not None:
656-
model.load_state_dict(weights.get_state_dict(progress=progress))
656+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
657657

658658
return model
659659

@@ -694,7 +694,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
694694
)
695695

696696
if weights is not None:
697-
model.load_state_dict(weights.get_state_dict(progress=progress))
697+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
698698

699699
return model
700700

torchvision/models/detection/fcos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,6 @@ def fcos_resnet50_fpn(
766766
model = FCOS(backbone, num_classes, **kwargs)
767767

768768
if weights is not None:
769-
model.load_state_dict(weights.get_state_dict(progress=progress))
769+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
770770

771771
return model

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def keypointrcnn_resnet50_fpn(
465465
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
466466

467467
if weights is not None:
468-
model.load_state_dict(weights.get_state_dict(progress=progress))
468+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
469469
if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
470470
overwrite_eps(model, 0.0)
471471

torchvision/models/detection/mask_rcnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def maskrcnn_resnet50_fpn(
501501
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
502502

503503
if weights is not None:
504-
model.load_state_dict(weights.get_state_dict(progress=progress))
504+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
505505
if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
506506
overwrite_eps(model, 0.0)
507507

@@ -582,6 +582,6 @@ def maskrcnn_resnet50_fpn_v2(
582582
)
583583

584584
if weights is not None:
585-
model.load_state_dict(weights.get_state_dict(progress=progress))
585+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
586586

587587
return model

torchvision/models/detection/retinanet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,7 @@ def retinanet_resnet50_fpn(
815815
model = RetinaNet(backbone, num_classes, **kwargs)
816816

817817
if weights is not None:
818-
model.load_state_dict(weights.get_state_dict(progress=progress))
818+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
819819
if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
820820
overwrite_eps(model, 0.0)
821821

@@ -894,6 +894,6 @@ def retinanet_resnet50_fpn_v2(
894894
model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)
895895

896896
if weights is not None:
897-
model.load_state_dict(weights.get_state_dict(progress=progress))
897+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
898898

899899
return model

torchvision/models/detection/ssd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,6 @@ def ssd300_vgg16(
677677
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
678678

679679
if weights is not None:
680-
model.load_state_dict(weights.get_state_dict(progress=progress))
680+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
681681

682682
return model

0 commit comments

Comments
 (0)