
Commit 989ecd2

[Misc] Gemma3ForConditionalGeneration supports LoRA (#14797)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 54cc46f commit 989ecd2

File tree

1 file changed: +14 -3 lines changed


vllm/model_executor/models/gemma3_mm.py

Lines changed: 14 additions & 3 deletions
@@ -12,6 +12,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
@@ -23,7 +24,8 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors

-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+                         SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -371,8 +373,8 @@ def forward(self, vision_outputs: torch.Tensor):
 @MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
                                         info=Gemma3ProcessingInfo,
                                         dummy_inputs=Gemma3DummyInputsBuilder)
-class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                     SupportsPP):
+class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
+                                     SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -614,3 +616,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector",
+            tower_model="vision_tower")
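
With SupportsLoRA declared on Gemma3ForConditionalGeneration, a LoRA adapter can be applied at inference time through vLLM's existing LoRA machinery. The following is a minimal sketch using the offline LLM API; the model ID and the local adapter path are illustrative placeholders, not part of this commit.

# Minimal sketch: running Gemma3 with a LoRA adapter via vLLM's offline API.
# "google/gemma-3-4b-it" and "/path/to/gemma3-lora" are placeholder values.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="google/gemma-3-4b-it",  # placeholder Gemma3 multimodal checkpoint
    enable_lora=True,              # enables the LoRA path gated by SupportsLoRA
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(
    ["Summarize the rules of chess in one sentence."],
    sampling_params,
    # LoRARequest(adapter name, unique integer id, local adapter path)
    lora_request=LoRARequest("gemma3-adapter", 1, "/path/to/gemma3-lora"),
)
print(outputs[0].outputs[0].text)

The new get_mm_mapping hook reports which module prefixes belong to the language model ("language_model"), the connector ("multi_modal_projector"), and the vision tower ("vision_tower"), so LoRA handling can resolve adapter weights against the correct submodules of the multimodal model.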

0 commit comments
