28 changes: 19 additions & 9 deletions vllm/model_executor/models/adapters.py
@@ -186,15 +186,21 @@ def __init__(
     def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
         raise NotImplementedError

-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+    def load_weights(
+        self,
+        weights: Iterable[tuple[str, torch.Tensor]],
+        load_lm_head: bool = False,
+    ):
         # TODO: Support uninitialized params tracking

-        # We have deleted this attribute, so don't load it
-        weights = (
-            (name, data)
-            for name, data in weights
-            if not name.startswith("lm_head.")
-        )
+        # For most pooling models: We have deleted this attribute, so don't load it.
+        # For converting an LLM into a seq cls model, we need the lm_head.
+        if not load_lm_head:
+            weights = (
+                (name, data)
+                for name, data in weights
+                if not name.startswith("lm_head.")
+            )

         # If `*ForCausalLM` defines `load_weights` on the inner model
         # and there are no other inner modules with parameters,
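
For reference, a minimal, self-contained sketch of the lazy filtering pattern this hunk gates behind the new flag (filter_lm_head and fake_weights are illustrative names, not vLLM API):

from collections.abc import Iterable

import torch


def filter_lm_head(
    weights: Iterable[tuple[str, torch.Tensor]],
    load_lm_head: bool = False,
) -> Iterable[tuple[str, torch.Tensor]]:
    # Drop lm_head.* entries unless the caller explicitly asks for them.
    if load_lm_head:
        return weights
    # Generator expression: entries are filtered lazily, so a large
    # checkpoint is never materialized in memory all at once.
    return ((name, data) for name, data in weights if not name.startswith("lm_head."))


fake_weights = [
    ("lm_head.weight", torch.zeros(2, 2)),
    ("model.embed_tokens.weight", torch.zeros(2, 2)),
]
assert [n for n, _ in filter_lm_head(fake_weights)] == ["model.embed_tokens.weight"]
assert len(list(filter_lm_head(fake_weights, load_lm_head=True))) == 2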
@@ -431,8 +437,12 @@ def load_weights_using_from_2_way_softmax(
     )
     model.lm_head = model.lm_head.tie_weights(embed_tokens)

-    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
-    loaded_weights = type(model).__mro__[1].load_weights(model, weights)
+    # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
+    # function, so we need to use this hacky method to obtain it.
+    pooling_model_cls = [
+        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+    ][0]
+    loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)

     from vllm.transformers_utils.tokenizer import get_tokenizer

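A minimal, runnable sketch of the MRO-lookup trick in this hunk; everything except the ModelForPooling name is a stand-in for vLLM's real hierarchy, and next() is used in place of the diff's equivalent [...][0]:

def make_pooling_model_cls(base):
    # Stand-in for _create_pooling_model_cls: the class is created inside
    # a function, so it cannot be imported or referenced by name.
    class ModelForPooling(base):
        def load_weights(self, weights, load_lm_head=False):
            return f"loaded(load_lm_head={load_lm_head})"

    return ModelForPooling


class CausalLM:
    pass


class ModelForSequenceClassification(make_pooling_model_cls(CausalLM)):
    def load_weights(self, weights):
        # The old code hard-coded type(self).__mro__[1], which silently
        # breaks if another class is inserted into the hierarchy; searching
        # the MRO by class name is the "hacky method" the diff refers to.
        pooling_model_cls = next(
            x for x in type(self).__mro__ if x.__name__ == "ModelForPooling"
        )
        # Unbound call with an explicit self, passing the new flag through.
        return pooling_model_cls.load_weights(self, weights, load_lm_head=True)


model = ModelForSequenceClassification()
print(model.load_weights([]))  # -> loaded(load_lm_head=True)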