|
94 | 94 | _EMBEDDING_MODELS = { |
95 | 95 | # [Text-only] |
96 | 96 | "BertModel": ("bert", "BertEmbeddingModel"), |
| 97 | + "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), |
97 | 98 | "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), |
98 | 99 | "LlamaModel": ("llama", "LlamaEmbeddingModel"), |
| 100 | + **{ |
| 101 | + # Several text-generation entries map to the LlamaForCausalLM architecture, so include every one of them |
| 102 | + k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items() |
| 103 | + if arch == "LlamaForCausalLM" |
| 104 | + }, |
99 | 105 | "MistralModel": ("llama", "LlamaEmbeddingModel"), |
100 | | - "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), |
101 | | - "Qwen2ForSequenceClassification": ( |
102 | | - "qwen2_cls", "Qwen2ForSequenceClassification"), |
103 | | - "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), |
104 | 106 | "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), |
105 | | - "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), |
| 107 | + "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), |
| 108 | + "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501 |
106 | 109 | # [Multimodal] |
107 | 110 | "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 |
108 | 111 | "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), |
109 | 112 | } |
110 | 113 |
|
111 | | -def add_embedding_models(base_models, embedding_models): |
112 | | - with_pooler_method_models = {} |
113 | | - embedding_models_name = embedding_models.keys() |
114 | | - for name, (path, arch) in base_models.items(): |
115 | | - if arch in embedding_models_name: |
116 | | - with_pooler_method_models[name] = (path, arch) |
117 | | - return with_pooler_method_models |
118 | | - |
119 | | -_EMBEDDING_MODELS = { |
120 | | - **add_embedding_models(_TEXT_GENERATION_MODELS, _EMBEDDING_MODELS), |
121 | | - **_EMBEDDING_MODELS, |
122 | | -} |
123 | | - |
124 | 114 | _MULTIMODAL_MODELS = { |
125 | 115 | # [Decoder-only] |
126 | 116 | "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), |
|
0 commit comments