@@ -443,27 +443,26 @@ def __init__(
         self.config = config
         embed_dim = config.hidden_size
 
-        if (num_hidden_layers_override is None
-                or num_hidden_layers_override == config.num_hidden_layers):
-            self.need_post_layernorm = True
-        elif num_hidden_layers_override > config.num_hidden_layers:
-            raise ValueError(
-                "num_hidden_layers_override cannot be greater than "
-                "num_hidden_layers")
-        else:
-            self.need_post_layernorm = False
-
         self.embeddings = SiglipVisionEmbeddings(config)
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
         )
-        if self.need_post_layernorm:
+
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {config.num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+        elif len(self.encoder.layers) == config.num_hidden_layers:
             self.post_layernorm = nn.LayerNorm(embed_dim,
                                                eps=config.layer_norm_eps)
         else:
-            self.post_layernorm = nn.Identity()
+            # post_layernorm is unused when we extract intermediate features
+            # In this case, we can skip it to conserve memory
+            self.post_layernorm = None
+
         self.use_head = (True if not hasattr(config, "vision_use_head") else
                          config.vision_use_head)
         if self.use_head:
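For context, here is a minimal, runnable sketch of the new selection logic in isolation. The TinyConfig, TinyEncoder, and build_post_layernorm names are invented for illustration and are not vLLM APIs; the point is that validity is now derived from the number of encoder layers actually built, so post_layernorm is only created for a full-depth encoder and is None when intermediate features are extracted.

# Illustrative stand-ins only; not the real vLLM classes.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class TinyConfig:
    hidden_size: int = 8
    num_hidden_layers: int = 4
    layer_norm_eps: float = 1e-6


class TinyEncoder(nn.Module):
    def __init__(self, config, num_hidden_layers_override=None):
        super().__init__()
        num_layers = (config.num_hidden_layers
                      if num_hidden_layers_override is None
                      else num_hidden_layers_override)
        self.layers = nn.ModuleList(
            nn.Linear(config.hidden_size, config.hidden_size)
            for _ in range(num_layers))


def build_post_layernorm(config, encoder):
    # Mirrors the new constructor logic: validate against the number of
    # layers that were actually instantiated, not a precomputed flag.
    if len(encoder.layers) > config.num_hidden_layers:
        raise ValueError("requested more layers than the original encoder has")
    if len(encoder.layers) == config.num_hidden_layers:
        return nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    # Truncated encoder (intermediate features): skip the norm entirely.
    return None


config = TinyConfig()
assert build_post_layernorm(config, TinyEncoder(config)) is not None
assert build_post_layernorm(config, TinyEncoder(config, 2)) is None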
@@ -482,6 +481,9 @@ def forward(
 
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)
 
+        if self.post_layernorm is None:
+            return encoder_outputs
+
         last_hidden_state = self.post_layernorm(encoder_outputs)
         # TODO: add this back when pooled_output is used in inference
         # if self.use_head:
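A hedged sketch of the corresponding forward-path change: when post_layernorm is None, the raw encoder output is returned directly instead of being passed through an nn.Identity() as before. The encode helper below is made up for the example and is not the actual vLLM forward method.

import torch
import torch.nn as nn


def encode(layers, post_layernorm, hidden_states):
    for layer in layers:
        hidden_states = layer(hidden_states)
    # New behaviour: a truncated encoder returns its raw hidden states.
    if post_layernorm is None:
        return hidden_states
    return post_layernorm(hidden_states)


layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))
x = torch.randn(2, 3, 8)
print(encode(layers, None, x).shape)             # intermediate features
print(encode(layers, nn.LayerNorm(8), x).shape)  # full-depth path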
@@ -512,8 +514,8 @@ def __init__(
         )
 
     @property
-    def need_post_layernorm(self):
-        return self.vision_model.need_post_layernorm
+    def _require_post_layernorm(self) -> bool:
+        return self.vision_model.post_layernorm is not None
 
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding
@@ -541,7 +543,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             # post_layernorm is optional in SiglipVisionModel
             if ("vision_model.post_layernorm" in name
-                    and not self.need_post_layernorm):
+                    and not self._require_post_layernorm):
                 continue
 
             # omit layers when num_hidden_layers_override is set
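Because post_layernorm may no longer exist on the module, its checkpoint entries have to be dropped during loading. The filter below is an illustration-only sketch of that idea; the parameter names in the example checkpoint are invented and may not match the real SigLIP weight layout.

import torch


def filter_weights(weights, has_post_layernorm):
    # Skip checkpoint entries whose target module was never instantiated.
    for name, tensor in weights:
        if ("vision_model.post_layernorm" in name
                and not has_post_layernorm):
            continue
        yield name, tensor


ckpt = [
    ("vision_model.post_layernorm.weight", torch.ones(8)),
    ("vision_model.encoder.layers.0.fc.weight", torch.ones(8, 8)),
]
kept = dict(filter_weights(ckpt, has_post_layernorm=False))
assert "vision_model.post_layernorm.weight" not in kept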