|
1 | | -import itertools |
2 | 1 | from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, |
3 | 2 | TypedDict, Union) |
4 | 3 |
|
|
30 | 29 | from .siglip import (SiglipVisionModel, dummy_image_for_siglip, |
31 | 30 | dummy_seq_data_for_siglip, get_siglip_image_feature_size, |
32 | 31 | get_siglip_patch_grid_length, input_processor_for_siglip) |
33 | | -from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, |
34 | | - merge_multimodal_embeddings) |
| 32 | +from .utils import (flatten_bn, group_weights_with_prefix, |
| 33 | + init_vllm_registered_model, merge_multimodal_embeddings) |
35 | 34 |
|
36 | 35 | logger = init_logger(__name__) |
37 | 36 |
|
@@ -637,31 +636,26 @@ def sample( |
637 | 636 |
|
638 | 637 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): |
639 | 638 | # prepare weight iterators for components |
640 | | - vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( |
641 | | - weights, 4) |
| 639 | + weights_group = group_weights_with_prefix(weights) |
642 | 640 |
|
643 | 641 | # load vision encoder |
644 | | - vit_weights = filter_weights(vit_weights, "vision_tower") |
645 | | - self.vision_tower.load_weights(vit_weights) |
| 642 | + self.vision_tower.load_weights(weights_group["vision_tower"]) |
646 | 643 |
|
647 | 644 | # load mlp projector |
648 | | - mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") |
649 | 645 | mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) |
650 | | - for name, loaded_weight in mlp_weights: |
| 646 | + for name, loaded_weight in weights_group["multi_modal_projector"]: |
651 | 647 | param = mlp_params_dict[name] |
652 | 648 | weight_loader = getattr(param, "weight_loader", |
653 | 649 | default_weight_loader) |
654 | 650 | weight_loader(param, loaded_weight) |
655 | 651 |
|
656 | 652 | # load newline |
657 | | - newline_weights = filter_weights(newline_weights, "image_newline") |
658 | | - for name, loaded_weight in newline_weights: |
| 653 | + for name, loaded_weight in weights_group["image_newline"]: |
659 | 654 | assert name == "" |
660 | 655 | param = self.image_newline |
661 | 656 | weight_loader = getattr(param, "weight_loader", |
662 | 657 | default_weight_loader) |
663 | 658 | weight_loader(param, loaded_weight) |
664 | 659 |
|
665 | 660 | # load llm backbone |
666 | | - llm_weights = filter_weights(llm_weights, "language_model") |
667 | | - self.language_model.load_weights(llm_weights) |
| 661 | + self.language_model.load_weights(weights_group["language_model"]) |
0 commit comments