Commit 708d87a

Fix ViT SAM weight compat as weights at URL changed to not use repr layer. Fix #825. Tweak optim test.
1 parent acd6c68

2 files changed: +5 −3 lines

tests/test_optim.py

Lines changed: 1 addition & 1 deletion

@@ -320,7 +320,7 @@ def test_sgd(optimizer):
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1)
     )
     _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=1)
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=.1)
     )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
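
The lowered weight_decay (1 → .1) presumably keeps the SGD basic-case test numerically stable when combined with momentum=1. A minimal sketch of the tweaked factory call, assuming timm's create_optimizer_v2 from timm.optim; the tensor shapes are illustrative, not the test's own fixtures:

import torch
from timm.optim import create_optimizer_v2

# Stand-in parameters for the test's weight/bias tensors (shapes are illustrative)
weight = torch.randn(10, 5, requires_grad=True)
bias = torch.randn(10, requires_grad=True)

# The tweaked call: SGD with momentum=1 and the milder weight_decay=.1
opt = create_optimizer_v2([weight, bias], 'sgd', lr=1e-3, momentum=1, weight_decay=.1)

# One optimization step to show the optimizer is wired up correctly
loss = (weight.sum() + bias.sum()) ** 2
loss.backward()
opt.step()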

timm/models/vision_transformer.py

Lines changed: 4 additions & 2 deletions

@@ -683,7 +683,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_base_patch16_sam_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    # NOTE original SAM weights release worked with representation_size=768
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs)
     return model

@@ -692,7 +693,8 @@ def vit_base_patch16_sam_224(pretrained=False, **kwargs):
 def vit_base_patch32_sam_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
     """
-    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    # NOTE original SAM weights release worked with representation_size=768
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs)
     return model
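
With representation_size=0 the ViT is built without the pre-logits projection, so its state dict lines up with the re-released SAM checkpoints (#825). A quick sanity check, assuming this timm version exposes the head as a pre_logits attribute that becomes nn.Identity when no representation size is set:

import timm
import torch
import torch.nn as nn

# Build the architecture only (no weight download needed for a shape check)
model = timm.create_model('vit_base_patch16_sam_224', pretrained=False)

# With representation_size=0 there is no repr layer between pooling and the classifier
assert isinstance(model.pre_logits, nn.Identity)

# Forward pass at the model's native 224x224 resolution
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])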
