Commit 36efb58

feat: support custom parameters for large language models
1 parent cbdf4b7 commit 36efb58

File tree

16 files changed (+40, -86 lines)

apps/setting/models_provider/base_model_provider.py

Lines changed: 8 additions & 0 deletions
@@ -93,6 +93,14 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
     def is_cache_model():
         return True
 
+    @staticmethod
+    def filter_optional_params(model_kwargs):
+        optional_params = {}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params[key] = value
+        return optional_params
+
 
 class BaseModelCredential(ABC):
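The new helper is the core of this commit: rather than each provider copying a fixed whitelist of max_tokens and temperature, every keyword argument is now forwarded except the internal keys 'model_id', 'use_local', and 'streaming'. A minimal sketch of its behavior, with illustrative kwargs:

    model_kwargs = {'model_id': '42', 'use_local': False, 'streaming': True,
                    'max_tokens': 1024, 'temperature': 0.7, 'top_p': 0.9}
    MaxKBBaseModel.filter_optional_params(model_kwargs)
    # -> {'max_tokens': 1024, 'temperature': 0.7, 'top_p': 0.9}
    # 'top_p', or any other custom parameter, now passes through instead of being dropped.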

apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py

Lines changed: 1 addition & 6 deletions
@@ -40,12 +40,7 @@ def __init__(self, model_id: str, region_name: str, credentials_profile_name: st
     @classmethod
     def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                      **model_kwargs) -> 'BedrockModel':
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            keyword = get_max_tokens_keyword(model_name)
-            optional_params[keyword] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         return cls(
             model_id=model_name,
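One behavioral consequence of this hunk: the old Bedrock code translated max_tokens into a model-specific keyword via get_max_tokens_keyword(model_name), a mapping the shared filter does not perform. Presumably callers now pass the provider-specific key themselves; a hypothetical sketch (the key name 'max_tokens_to_sample' is illustrative):

    MaxKBBaseModel.filter_optional_params({'max_tokens_to_sample': 1024, 'temperature': 0.5})
    # -> {'max_tokens_to_sample': 1024, 'temperature': 0.5}; keys pass through verbatim.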

apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py

Lines changed: 1 addition & 5 deletions
@@ -26,11 +26,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         return AzureChatModel(
             azure_endpoint=model_credential.get('api_base'),

apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py

Lines changed: 1 addition & 5 deletions
@@ -20,11 +20,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         deepseek_chat_open_ai = DeepSeekChatModel(
             model=model_name,

apps/setting/models_provider/impl/gemini_model_provider/model/llm.py

Lines changed: 1 addition & 5 deletions
@@ -30,11 +30,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'temperature' in model_kwargs:
-            optional_params['temperature'] = model_kwargs['temperature']
-        if 'max_tokens' in model_kwargs:
-            optional_params['max_output_tokens'] = model_kwargs['max_tokens']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         gemini_chat = GeminiChatModel(
             model=model_name,
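Likewise for Gemini: the deleted branch renamed max_tokens to the Gemini-specific max_output_tokens, and the shared filter applies no such renaming, so the final key name presumably has to arrive in model_kwargs already. A one-line sketch with an illustrative value:

    MaxKBBaseModel.filter_optional_params({'max_output_tokens': 2048})
    # -> {'max_output_tokens': 2048}; no renaming is applied.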

apps/setting/models_provider/impl/kimi_model_provider/model/llm.py

Lines changed: 1 addition & 5 deletions
@@ -20,11 +20,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         kimi_chat_open_ai = KimiChatModel(
             openai_api_base=model_credential['api_base'],

apps/setting/models_provider/impl/ollama_model_provider/model/llm.py

Lines changed: 1 addition & 5 deletions
@@ -34,11 +34,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
         api_base = model_credential.get('api_base', '')
         base_url = get_base_url(api_base)
         base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         return OllamaChatModel(model=model_name, openai_api_base=base_url,
                                openai_api_key=model_credential.get('api_key'),
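Taken together, the change delivers what the commit title states: any custom parameter supplied at creation time now reaches optional_params. A hypothetical end-to-end call against the Ollama provider (credential values are illustrative, and it assumes new_instance is exposed on OllamaChatModel as in the other providers):

    model = OllamaChatModel.new_instance(
        'LLM', 'llama3',
        {'api_base': 'http://localhost:11434', 'api_key': 'none'},
        temperature=0.2,
        top_p=0.9,       # custom parameter: previously dropped, now forwarded
        streaming=True,  # internal key: still stripped by filter_optional_params
    )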

apps/setting/models_provider/impl/openai_model_provider/model/llm.py

Lines changed: 1 addition & 5 deletions
@@ -30,11 +30,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         azure_chat_open_ai = OpenAIChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),

apps/setting/models_provider/impl/qwen_model_provider/model/llm.py

Lines changed: 1 addition & 5 deletions
@@ -27,11 +27,7 @@ def is_cache_model():
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         chat_tong_yi = QwenChatModel(
             model_name=model_name,
             dashscope_api_key=model_credential.get('api_key'),

apps/setting/models_provider/impl/tencent_model_provider/model/llm.py

Lines changed: 1 addition & 3 deletions
@@ -18,9 +18,7 @@ def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool
         hunyuan_secret_id = credentials.get('hunyuan_secret_id')
         hunyuan_secret_key = credentials.get('hunyuan_secret_key')
 
-        optional_params = {}
-        if 'temperature' in kwargs:
-            optional_params['temperature'] = kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(kwargs)
 
         if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
             raise ValueError(
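Tencent is the one provider in this commit where the filter runs inside __init__ on the constructor's **kwargs rather than in new_instance, but the exclusion list behaves identically; a small sketch with illustrative kwargs:

    MaxKBBaseModel.filter_optional_params({'temperature': 0.8, 'top_p': 0.9, 'streaming': True})
    # -> {'temperature': 0.8, 'top_p': 0.9}; 'streaming' is excluded, which fits here
    # because the Tencent __init__ already takes streaming as a named parameter of its own.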
