
Commit da1a844

[Bugfix] Fix missing post_layernorm in CLIP (#8155)
1 parent: a1d8742

2 files changed: +42 -19 lines


vllm/model_executor/models/clip.py

Lines changed: 25 additions & 4 deletions
@@ -355,6 +355,19 @@ def __init__(self,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override)
 
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {config.num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+        elif len(self.encoder.layers) == config.num_hidden_layers:
+            self.post_layernorm = nn.LayerNorm(embed_dim,
+                                               eps=config.layer_norm_eps)
+        else:
+            # post_layernorm is unused when we extract intermediate features
+            # In this case, we can skip it to conserve memory
+            self.post_layernorm = None
+
     def forward(
         self,
         pixel_values: torch.Tensor,
@@ -364,7 +377,10 @@ def forward(
         hidden_states = self.pre_layrnorm(hidden_states)
         hidden_states = self.encoder(inputs_embeds=hidden_states)
 
-        return hidden_states
+        if self.post_layernorm is None:
+            return hidden_states
+
+        return self.post_layernorm(hidden_states)
 
 
 class CLIPVisionModel(nn.Module):
@@ -386,9 +402,12 @@ def __init__(self,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override)
 
-    def forward(self, pixel_values: Optional[torch.Tensor] = None):
+    @property
+    def _require_post_layernorm(self) -> bool:
+        return self.vision_model.post_layernorm is not None
 
-        return self.vision_model(pixel_values=pixel_values)
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        return self.vision_model(pixel_values)
 
     @property
     def device(self):
@@ -408,8 +427,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
         for name, loaded_weight in weights:
             # post_layernorm is not needed in CLIPVisionModel
-            if "vision_model.post_layernorm" in name:
+            if ("vision_model.post_layernorm" in name
+                    and not self._require_post_layernorm):
                 continue
+
             # omit layers when num_hidden_layers_override is set
             if "vision_model.encoder.layers." in name:
                 layer_idx = int(name.split(".")[3])
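
For orientation, a minimal sketch of the three-way branch the new __init__ above introduces. This is not vLLM API: make_post_layernorm and its arguments are hypothetical stand-ins for the config values and the constructed encoder.

import torch.nn as nn


def make_post_layernorm(num_loaded_layers: int, num_hidden_layers: int,
                        hidden_size: int, layer_norm_eps: float = 1e-5):
    # Same branch as the new CLIPVisionTransformer.__init__: reject an
    # impossible layer count, build the LayerNorm only when the full
    # encoder is kept, and return None for a truncated encoder that
    # only extracts intermediate features.
    if num_loaded_layers > num_hidden_layers:
        raise ValueError(
            f"The original encoder only has {num_hidden_layers} "
            f"layers, but you requested {num_loaded_layers} layers.")
    if num_loaded_layers == num_hidden_layers:
        return nn.LayerNorm(hidden_size, eps=layer_norm_eps)
    return None  # skipped to conserve memory


# A full-depth encoder keeps post_layernorm; a truncated one drops it.
assert isinstance(make_post_layernorm(24, 24, 1024), nn.LayerNorm)
assert make_post_layernorm(23, 24, 1024) is None

Keying the decision off len(self.encoder.layers) rather than the raw num_hidden_layers_override argument (as siglip.py did before this commit) means a None override and an explicit full-depth override take the same path.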

vllm/model_executor/models/siglip.py

Lines changed: 17 additions & 15 deletions
@@ -443,27 +443,26 @@ def __init__(
         self.config = config
         embed_dim = config.hidden_size
 
-        if (num_hidden_layers_override is None
-                or num_hidden_layers_override == config.num_hidden_layers):
-            self.need_post_layernorm = True
-        elif num_hidden_layers_override > config.num_hidden_layers:
-            raise ValueError(
-                "num_hidden_layers_override cannot be greater than "
-                "num_hidden_layers")
-        else:
-            self.need_post_layernorm = False
-
         self.embeddings = SiglipVisionEmbeddings(config)
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
         )
-        if self.need_post_layernorm:
+
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {config.num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+        elif len(self.encoder.layers) == config.num_hidden_layers:
             self.post_layernorm = nn.LayerNorm(embed_dim,
                                                eps=config.layer_norm_eps)
         else:
-            self.post_layernorm = nn.Identity()
+            # post_layernorm is unused when we extract intermediate features
+            # In this case, we can skip it to conserve memory
+            self.post_layernorm = None
+
         self.use_head = (True if not hasattr(config, "vision_use_head") else
                          config.vision_use_head)
         if self.use_head:
@@ -482,6 +481,9 @@ def forward(
 
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)
 
+        if self.post_layernorm is None:
+            return encoder_outputs
+
         last_hidden_state = self.post_layernorm(encoder_outputs)
         # TODO: add this back when pooled_output is used in inference
         # if self.use_head:
@@ -512,8 +514,8 @@ def __init__(
         )
 
     @property
-    def need_post_layernorm(self):
-        return self.vision_model.need_post_layernorm
+    def _require_post_layernorm(self) -> bool:
+        return self.vision_model.post_layernorm is not None
 
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding
@@ -541,7 +543,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             # post_layernorm is optional in SiglipVisionModel
             if ("vision_model.post_layernorm" in name
-                    and not self.need_post_layernorm):
+                    and not self._require_post_layernorm):
                 continue
 
             # omit layers when num_hidden_layers_override is set
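
The same guard drives weight loading in both files. Below is a hedged sketch of that filter; filter_weight_names and has_post_layernorm are made-up illustrations, not vLLM symbols. Before this commit, clip.py skipped these checkpoint entries unconditionally and its forward never applied the final LayerNorm, which is the "missing post_layernorm" the title refers to.

from typing import List


def filter_weight_names(names: List[str],
                        has_post_layernorm: bool) -> List[str]:
    # Mirrors the updated load_weights guard: post_layernorm weights in
    # the checkpoint are skipped only when the constructed model has no
    # such module (i.e. the encoder was truncated).
    kept = []
    for name in names:
        if ("vision_model.post_layernorm" in name
                and not has_post_layernorm):
            continue
        kept.append(name)
    return kept


names = ["vision_model.post_layernorm.weight",
         "vision_model.encoder.layers.0.self_attn.q_proj.weight"]
assert filter_weight_names(names, has_post_layernorm=True) == names
assert filter_weight_names(names, has_post_layernorm=False) == names[1:]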
