Skip to content

Commit 78f9ad4

Browse files
committed
Always check hash
1 parent 9acc3ee commit 78f9ad4

17 files changed

+17
-17
lines changed

torchvision/models/alexnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,6 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
114114
model = AlexNet(**kwargs)
115115

116116
if weights is not None:
117-
model.load_state_dict(weights.get_state_dict(progress=progress))
117+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
118118

119119
return model

torchvision/models/convnext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _convnext(
189189
model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
190190

191191
if weights is not None:
192-
model.load_state_dict(weights.get_state_dict(progress=progress))
192+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
193193

194194
return model
195195

torchvision/models/densenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) ->
227227
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
228228
)
229229

230-
state_dict = weights.get_state_dict(progress=progress)
230+
state_dict = weights.get_state_dict(progress=progress, check_hash=True)
231231
for key in list(state_dict.keys()):
232232
res = pattern.match(key)
233233
if res:

torchvision/models/efficientnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def _efficientnet(
357357
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
358358

359359
if weights is not None:
360-
model.load_state_dict(weights.get_state_dict(progress=progress))
360+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
361361

362362
return model
363363

torchvision/models/googlenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T
332332
model = GoogLeNet(**kwargs)
333333

334334
if weights is not None:
335-
model.load_state_dict(weights.get_state_dict(progress=progress))
335+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
336336
if not original_aux_logits:
337337
model.aux_logits = False
338338
model.aux1 = None # type: ignore[assignment]

torchvision/models/inception.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo
470470
model = Inception3(**kwargs)
471471

472472
if weights is not None:
473-
model.load_state_dict(weights.get_state_dict(progress=progress))
473+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
474474
if not original_aux_logits:
475475
model.aux_logits = False
476476
model.AuxLogits = None

torchvision/models/maxvit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,7 @@ def _maxvit(
763763
)
764764

765765
if weights is not None:
766-
model.load_state_dict(weights.get_state_dict(progress=progress))
766+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
767767

768768
return model
769769

torchvision/models/mnasnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa
317317
model = MNASNet(alpha, **kwargs)
318318

319319
if weights:
320-
model.load_state_dict(weights.get_state_dict(progress=progress))
320+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
321321

322322
return model
323323

torchvision/models/mobilenetv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,6 @@ def mobilenet_v2(
255255
model = MobileNetV2(**kwargs)
256256

257257
if weights is not None:
258-
model.load_state_dict(weights.get_state_dict(progress=progress))
258+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
259259

260260
return model

torchvision/models/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def _mobilenet_v3(
282282
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
283283

284284
if weights is not None:
285-
model.load_state_dict(weights.get_state_dict(progress=progress))
285+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
286286

287287
return model
288288

0 commit comments

Comments
 (0)