@@ -683,7 +683,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_base_patch16_sam_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    # NOTE original SAM weights release worked with representation_size=768
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -692,7 +693,8 @@ def vit_base_patch16_sam_224(pretrained=False, **kwargs):
 def vit_base_patch32_sam_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
     """
-    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    # NOTE original SAM weights release worked with representation_size=768
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs)
     return model
 
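The practical effect of the change: in the timm VisionTransformer of this era, a truthy representation_size inserts a Linear+Tanh pre-logits projection between the pooled features and the classifier head, while a falsy value (0 or None) leaves pre_logits as nn.Identity. A minimal sketch of checking this, assuming timm's create_model and the pre_logits attribute as they existed around this commit:

import timm
import torch.nn as nn

# With representation_size=0 (falsy), VisionTransformer skips the
# Linear+Tanh pre-logits block, so features flow straight to the head.
model = timm.create_model('vit_base_patch16_sam_224', pretrained=False)
assert isinstance(model.pre_logits, nn.Identity)  # no pre-logits layer after this change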