diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py
index 9171e036cfc..c4722c9f59d 100644
--- a/apps/setting/models_provider/base_model_provider.py
+++ b/apps/setting/models_provider/base_model_provider.py
@@ -93,6 +93,14 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
     def is_cache_model():
         return True
 
+    @staticmethod
+    def filter_optional_params(model_kwargs):
+        optional_params = {}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params[key] = value
+        return optional_params
+
 
 class BaseModelCredential(ABC):
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
index bfaf9f17b8d..950cd2b3f3c 100644
--- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
@@ -40,12 +40,7 @@ def __init__(self, model_id: str, region_name: str, credentials_profile_name: st
 
     @classmethod
     def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                      **model_kwargs) -> 'BedrockModel':
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            keyword = get_max_tokens_keyword(model_name)
-            optional_params[keyword] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return cls(
             model_id=model_name,
diff --git a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
index 110d72da890..0996c328909 100644
--- a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
+++ b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
@@ -26,11 +26,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return AzureChatModel(
             azure_endpoint=model_credential.get('api_base'),
diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py
index 94c7d4899d3..ac8dff4bd63 100644
--- a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py
@@ -20,11 +20,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         deepseek_chat_open_ai = DeepSeekChatModel(
             model=model_name,
diff --git a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py
index e9174d3a6df..68d5e112859 100644
--- a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py
@@ -30,11 +30,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'temperature' in model_kwargs:
-            optional_params['temperature'] = model_kwargs['temperature']
-        if 'max_tokens' in model_kwargs:
-            optional_params['max_output_tokens'] = model_kwargs['max_tokens']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         gemini_chat = GeminiChatModel(
             model=model_name,
diff --git a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py
index 2dd117fa5c3..c5f7b62b649 100644
--- a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py
@@ -20,11 +20,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         kimi_chat_open_ai = KimiChatModel(
             openai_api_base=model_credential['api_base'],
diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
index abb5fcb8006..7c98f7e5cef 100644
--- a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
@@ -34,11 +34,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
         api_base = model_credential.get('api_base', '')
         base_url = get_base_url(api_base)
         base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return OllamaChatModel(model=model_name, openai_api_base=base_url,
                                openai_api_key=model_credential.get('api_key'),
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
index a78ecc0c007..c5b5694e294 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
@@ -30,11 +30,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         azure_chat_open_ai = OpenAIChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
index 8ac8347aad9..1336cb05b91 100644
--- a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
@@ -27,11 +27,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         chat_tong_yi = QwenChatModel(
             model_name=model_name,
             dashscope_api_key=model_credential.get('api_key'),
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
index cfe673b5096..17023f32eea 100644
--- a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
@@ -18,9 +18,7 @@ def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool
         hunyuan_secret_id = credentials.get('hunyuan_secret_id')
         hunyuan_secret_key = credentials.get('hunyuan_secret_key')
 
-        optional_params = {}
-        if 'temperature' in kwargs:
-            optional_params['temperature'] = kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(kwargs)
 
         if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
             raise ValueError(
diff --git a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py
index 5c498f8e079..d03eb722935 100644
--- a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py
@@ -22,11 +22,7 @@ def is_cache_model():
 
     @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         vllm_chat_open_ai = VllmChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
index c549710e5db..181ad2971db 100644
--- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
@@ -12,11 +12,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return VolcanicEngineChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
index 8c634c6d03c..e9b69d7814f 100644
--- a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
@@ -26,11 +26,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_output_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return QianfanChatModel(model=model_name,
                                 qianfan_ak=model_credential.get('api_key'),
                                 qianfan_sk=model_credential.get('secret_key'),
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py
index 598af07acce..7c3b39d316a 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py
@@ -6,11 +6,10 @@
 @date:2024/04/19 15:55
 @desc:
 """
-import json
 from typing import List, Optional, Any, Iterator, Dict
 
-from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk, \
-    ChatSparkLLM
+from langchain_community.chat_models.sparkllm import \
+    ChatSparkLLM, _convert_message_to_dict, _convert_delta_to_message_chunk
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.messages import BaseMessage, AIMessageChunk
 from langchain_core.outputs import ChatGenerationChunk
@@ -25,11 +24,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return XFChatSparkLLM(
             spark_app_id=model_credential.get('spark_app_id'),
             spark_api_key=model_credential.get('spark_api_key'),
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
index b0fb4d16e0c..16996b90780 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
@@ -1,8 +1,12 @@
 # coding=utf-8
-from typing import Dict
+from typing import Dict, Optional, List, Any, Iterator
 from urllib.parse import urlparse, ParseResult
 
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.messages import BaseMessageChunk
+from langchain_core.runnables import RunnableConfig
+
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
@@ -26,11 +30,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
         api_base = model_credential.get('api_base', '')
         base_url = get_base_url(api_base)
         base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return XinferenceChatModel(
             model=model_name,
             openai_api_base=base_url,
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py
index c86c2e3a3d8..03699321c82 100644
--- a/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py
@@ -7,43 +7,41 @@
 @desc:
 """
 
-from langchain_community.chat_models import ChatZhipuAI
-from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
-    _convert_delta_to_message_chunk
-from setting.models_provider.base_model_provider import MaxKBBaseModel
 import json
 from collections.abc import Iterator
 from typing import Any, Dict, List, Optional
 
+from langchain_community.chat_models import ChatZhipuAI
+from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
+    _convert_delta_to_message_chunk
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
-
 from langchain_core.messages import (
     AIMessageChunk,
     BaseMessage
 )
 from langchain_core.outputs import ChatGenerationChunk
 
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
 
 class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
+    optional_params: dict
+
     @staticmethod
     def is_cache_model():
         return False
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
-
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         zhipuai_chat = ZhipuChatModel(
             api_key=model_credential.get('api_key'),
             model=model_name,
             streaming=model_kwargs.get('streaming', False),
-            **optional_params
+            optional_params=optional_params,
+            **optional_params,
         )
 
         return zhipuai_chat
@@ -71,7 +69,7 @@ def _stream(
         if self.zhipuai_api_base is None:
             raise ValueError("Did not find zhipu_api_base.")
         message_dicts, params = self._create_message_dicts(messages, stop)
-        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
+        payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
         _truncate_params(payload)
         headers = {
             "Authorization": _get_jwt_token(self.zhipuai_api_key),
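
For reference, a minimal sketch of how the new shared helper behaves. The kwargs values below are illustrative, not taken from the patch; only the model_id / use_local / streaming filter list comes from filter_optional_params itself:

    from setting.models_provider.base_model_provider import MaxKBBaseModel

    # Hypothetical kwargs, shaped like what a provider's new_instance() receives.
    model_kwargs = {'model_id': 'demo', 'use_local': False, 'streaming': True,
                    'max_tokens': 1024, 'temperature': 0.7}

    optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
    # Internal bookkeeping keys are dropped; everything else is forwarded
    # to the underlying chat model constructor via **optional_params:
    # {'max_tokens': 1024, 'temperature': 0.7}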