Skip to content

Commit c98874e

Browse files
committed
feat: Support Azure image tts stt model
1 parent 82f263b commit c98874e

File tree

9 files changed

+520
-0
lines changed

9 files changed

+520
-0
lines changed

apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,25 @@
1212
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
1313
ModelTypeConst, ModelInfoManage
1414
from setting.models_provider.impl.azure_model_provider.credential.embedding import AzureOpenAIEmbeddingCredential
15+
from setting.models_provider.impl.azure_model_provider.credential.image import AzureOpenAIImageModelCredential
1516
from setting.models_provider.impl.azure_model_provider.credential.llm import AzureLLMModelCredential
17+
from setting.models_provider.impl.azure_model_provider.credential.stt import AzureOpenAISTTModelCredential
18+
from setting.models_provider.impl.azure_model_provider.credential.tti import AzureOpenAITextToImageModelCredential
19+
from setting.models_provider.impl.azure_model_provider.credential.tts import AzureOpenAITTSModelCredential
1620
from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
1721
from setting.models_provider.impl.azure_model_provider.model.embedding import AzureOpenAIEmbeddingModel
22+
from setting.models_provider.impl.azure_model_provider.model.image import AzureOpenAIImage
23+
from setting.models_provider.impl.azure_model_provider.model.stt import AzureOpenAISpeechToText
24+
from setting.models_provider.impl.azure_model_provider.model.tti import AzureOpenAITextToImage
25+
from setting.models_provider.impl.azure_model_provider.model.tts import AzureOpenAITextToSpeech
1826
from smartdoc.conf import PROJECT_DIR
1927

2028
# One shared credential-form instance per Azure model type. The instances are
# handed to the ModelInfo presets declared further down in this module.
base_azure_llm_model_credential = AzureLLMModelCredential()
base_azure_embedding_model_credential = AzureOpenAIEmbeddingCredential()
base_azure_image_model_credential = AzureOpenAIImageModelCredential()
base_azure_tti_model_credential = AzureOpenAITextToImageModelCredential()
base_azure_tts_model_credential = AzureOpenAITTSModelCredential()
base_azure_stt_model_credential = AzureOpenAISTTModelCredential()
2234

2335
default_model_info = ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM,
2436
base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview'
@@ -36,12 +48,47 @@
3648
),
3749
]
3850

51+
# Preset model entries for the image, text-to-image, text-to-speech and
# speech-to-text model types. Every entry has an empty description and shares
# the credential instance of its model type.
# NOTE(review): these presets declare api_version='2023-05-15', which is older
# than the '2024-02-15-preview' used by default_model_info -- confirm that this
# version is accepted for these deployments.
image_model_info = [
    ModelInfo(preset_name, '', ModelTypeConst.IMAGE,
              base_azure_image_model_credential, AzureOpenAIImage, api_version='2023-05-15')
    for preset_name in ('gpt-4o', 'gpt-4o-mini')
]

tti_model_info = [
    ModelInfo('dall-e-3', '', ModelTypeConst.TTI,
              base_azure_tti_model_credential, AzureOpenAITextToImage, api_version='2023-05-15'),
]

tts_model_info = [
    ModelInfo('tts', '', ModelTypeConst.TTS,
              base_azure_tts_model_credential, AzureOpenAITextToSpeech, api_version='2023-05-15'),
]

stt_model_info = [
    ModelInfo('whisper', '', ModelTypeConst.STT,
              base_azure_stt_model_credential, AzureOpenAISpeechToText, api_version='2023-05-15'),
]
77+
3978
# Registry of every Azure preset. For each model type the presets are appended
# and the first preset of that type is also registered as the type's default.
# The call order of the chain is kept exactly as originally written.
model_info_manage = (
    ModelInfoManage.builder()
    # LLM: the generic Azure OpenAI entry is both the default and a listed preset.
    .append_default_model_info(default_model_info).append_model_info(default_model_info)
    # Embedding
    .append_model_info_list(embedding_model_info).append_default_model_info(embedding_model_info[0])
    # Image
    .append_model_info_list(image_model_info).append_default_model_info(image_model_info[0])
    # Speech-to-text
    .append_model_info_list(stt_model_info).append_default_model_info(stt_model_info[0])
    # Text-to-speech
    .append_model_info_list(tts_model_info).append_default_model_info(tts_model_info[0])
    # Text-to-image
    .append_model_info_list(tti_model_info).append_default_model_info(tti_model_info[0])
    .build()
)
4794

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm, TooltipLabel
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
13+
class AzureOpenAIImageModelParams(BaseForm):
    """Runtime parameter form offered for Azure OpenAI image models."""

    # Temperature slider: 0.1 to 1.0 in steps of 0.01, default 0.7.
    temperature = forms.SliderField(
        TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
        required=True,
        default_value=0.7,
        _min=0.1,
        _max=1.0,
        _step=0.01,
        precision=2,
    )

    # Maximum output tokens slider: whole numbers from 1 to 100000, default 800.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True,
        default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0,
    )
28+
29+
30+
31+
class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Azure OpenAI image models.

    Declares the three connection fields (API version, endpoint, API key) and
    validates them by sending a short streamed chat request through the model
    returned by the provider.
    """
    api_version = forms.TextInputField("API 版本 (api_version)", required=True)
    api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
    api_key = forms.PasswordInputField("API Key (api_key)", required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can be used to reach the model.

        Args:
            model_type: Model type value; must be one of the values listed by
                ``provider.get_model_type_list()``.
            model_name: Name passed on to ``provider.get_model``.
            model_credential: Mapping that must contain ``api_base``,
                ``api_key`` and ``api_version``.
            provider: Provider used to list model types and build the model.
            raise_exception: When true, failures are raised as
                ``AppApiException`` instead of being reported as ``False``.

        Returns:
            ``True`` when the test request completes, ``False`` when a check
            fails and ``raise_exception`` is false.

        Raises:
            AppApiException: Always for an unsupported model type, and for any
                ``AppApiException`` raised while building or calling the
                model; for other failures only when ``raise_exception`` is
                true.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
            # Drain the stream so the request is actually executed; the chunks
            # themselves are not needed (and must not be printed to stdout).
            for _ in res:
                pass
        except AppApiException:
            # Application-level errors are surfaced unchanged.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through ``encryption``."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form used for image models (same for every ``model_name``)."""
        return AzureOpenAIImageModelParams()
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# coding=utf-8
2+
from typing import Dict
3+
4+
from common import forms
5+
from common.exception.app_exception import AppApiException
6+
from common.forms import BaseForm
7+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
8+
9+
10+
class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Azure OpenAI speech-to-text models.

    Declares the three connection fields (API version, endpoint, API key) and
    validates them by calling ``check_auth()`` on the model returned by the
    provider.
    """
    api_version = forms.TextInputField("API 版本 (api_version)", required=True)
    api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
    api_key = forms.PasswordInputField("API Key (api_key)", required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can be used to reach the model.

        Args:
            model_type: Model type value; must be one of the values listed by
                ``provider.get_model_type_list()``.
            model_name: Name passed on to ``provider.get_model``.
            model_credential: Mapping that must contain ``api_base``,
                ``api_key`` and ``api_version``.
            provider: Provider used to list model types and build the model.
            raise_exception: When true, failures are raised as
                ``AppApiException`` instead of being reported as ``False``.

        Returns:
            ``True`` when ``check_auth()`` completes, ``False`` when a check
            fails and ``raise_exception`` is false.

        Raises:
            AppApiException: Always for an unsupported model type, and for any
                ``AppApiException`` raised while building or checking the
                model; for other failures only when ``raise_exception`` is
                true.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except AppApiException:
            # Application-level errors are surfaced unchanged.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through ``encryption``."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return ``None``: no parameter form is defined for speech-to-text models."""
        return None
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm, TooltipLabel
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
13+
14+
class AzureOpenAITTIModelParams(BaseForm):
    """Runtime parameter form offered for Azure OpenAI text-to-image models."""

    # Output resolution; each option uses the same string as value and label.
    size = forms.SingleSelect(
        TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
        required=True,
        default_value='1024x1024',
        option_list=[
            {'value': dimensions, 'label': dimensions}
            for dimensions in ('1024x1024', '1024x1792', '1792x1024')
        ],
        text_field='label',
        value_field='value',
    )

    # Rendering quality; each option uses the same string as value and label.
    quality = forms.SingleSelect(
        TooltipLabel('图片质量', ''),
        required=True,
        default_value='standard',
        option_list=[
            {'value': level, 'label': level}
            for level in ('standard', 'hd')
        ],
        text_field='label',
        value_field='value',
    )

    # Number of images per request: whole numbers from 1 to 10, default 1.
    # NOTE(review): confirm the deployed model accepts more than one image per
    # request before relying on the upper bound.
    n = forms.SliderField(
        TooltipLabel('图片数量', '指定生成图片的数量'),
        required=True,
        default_value=1,
        _min=1,
        _max=10,
        _step=1,
        precision=0,
    )
47+
48+
49+
class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Azure OpenAI text-to-image models.

    Declares the three connection fields (API version, endpoint, API key) and
    validates them by calling ``check_auth()`` on the model returned by the
    provider.
    """
    api_version = forms.TextInputField("API 版本 (api_version)", required=True)
    api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
    api_key = forms.PasswordInputField("API Key (api_key)", required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can be used to reach the model.

        Args:
            model_type: Model type value; must be one of the values listed by
                ``provider.get_model_type_list()``.
            model_name: Name passed on to ``provider.get_model``.
            model_credential: Mapping that must contain ``api_base``,
                ``api_key`` and ``api_version``.
            provider: Provider used to list model types and build the model.
            raise_exception: When true, failures are raised as
                ``AppApiException`` instead of being reported as ``False``.

        Returns:
            ``True`` when ``check_auth()`` completes, ``False`` when a check
            fails and ``raise_exception`` is false.

        Raises:
            AppApiException: Always for an unsupported model type, and for any
                ``AppApiException`` raised while building or checking the
                model; for other failures only when ``raise_exception`` is
                true.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Only success or failure matters here; the result is not printed
            # so validation does not write to stdout.
            model.check_auth()
        except AppApiException:
            # Application-level errors are surfaced unchanged.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through ``encryption``."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form used for text-to-image models (same for every ``model_name``)."""
        return AzureOpenAITTIModelParams()
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# coding=utf-8
2+
from typing import Dict
3+
4+
from common import forms
5+
from common.exception.app_exception import AppApiException
6+
from common.forms import BaseForm, TooltipLabel
7+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
8+
9+
class AzureOpenAITTSModelGeneralParams(BaseForm):
    """Runtime parameter form offered for Azure OpenAI text-to-speech models."""

    # Voice selector; choices are alloy, echo, fable, onyx, nova and shimmer.
    # NOTE(review): text_field is 'value' although every option also carries a
    # 'text' key; both hold the same string, so presumably the displayed label
    # is unaffected -- confirm whether 'text' was intended.
    voice = forms.SingleSelect(
        TooltipLabel('Voice', '尝试不同的声音(合金、回声、寓言、缟玛瑙、新星和闪光),找到一种适合您所需的音调和听众的声音。当前的语音针对英语进行了优化。'),
        required=True,
        default_value='alloy',
        text_field='value',
        value_field='value',
        option_list=[
            {'text': voice_name, 'value': voice_name}
            for voice_name in ('alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer')
        ],
    )
24+
25+
26+
class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Azure OpenAI text-to-speech models.

    Declares the three connection fields (API version, endpoint, API key) and
    validates them by calling ``check_auth()`` on the model returned by the
    provider.
    """
    api_version = forms.TextInputField("API 版本 (api_version)", required=True)
    api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
    api_key = forms.PasswordInputField("API Key (api_key)", required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can be used to reach the model.

        Args:
            model_type: Model type value; must be one of the values listed by
                ``provider.get_model_type_list()``.
            model_name: Name passed on to ``provider.get_model``.
            model_credential: Mapping that must contain ``api_base``,
                ``api_key`` and ``api_version``.
            provider: Provider used to list model types and build the model.
            raise_exception: When true, failures are raised as
                ``AppApiException`` instead of being reported as ``False``.

        Returns:
            ``True`` when ``check_auth()`` completes, ``False`` when a check
            fails and ``raise_exception`` is false.

        Raises:
            AppApiException: Always for an unsupported model type, and for any
                ``AppApiException`` raised while building or checking the
                model; for other failures only when ``raise_exception`` is
                true.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except AppApiException:
            # Application-level errors are surfaced unchanged.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through ``encryption``."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form used for text-to-speech models (same for every ``model_name``)."""
        return AzureOpenAITTSModelGeneralParams()
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Dict
2+
3+
from langchain_openai import AzureChatOpenAI
4+
from langchain_openai.chat_models import ChatOpenAI
5+
6+
from common.config.tokenizer_manage_config import TokenizerManage
7+
from setting.models_provider.base_model_provider import MaxKBBaseModel
8+
9+
10+
def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer supplied by ``TokenizerManage`` and return the result."""
    return TokenizerManage.get_tokenizer().encode(text)
13+
14+
15+
class AzureOpenAIImage(MaxKBBaseModel, AzureChatOpenAI):
    """Azure OpenAI chat model wrapper used for image models."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a streaming ``AzureOpenAIImage`` from a stored credential.

        Args:
            model_type: Accepted for interface compatibility; not used here.
            model_name: Passed to the model as ``model_name``.
            model_credential: Mapping read for ``api_key``, ``api_base`` and
                ``api_version``.
            **model_kwargs: Extra settings; only those kept by
                ``MaxKBBaseModel.filter_optional_params`` are forwarded.

        Returns:
            A new ``AzureOpenAIImage`` instance with ``streaming=True``.
        """
        extra_settings = MaxKBBaseModel.filter_optional_params(model_kwargs)
        connection_settings = {
            'model_name': model_name,
            'openai_api_key': model_credential.get('api_key'),
            'azure_endpoint': model_credential.get('api_base'),
            'openai_api_version': model_credential.get('api_version'),
            'openai_api_type': "azure",
            'streaming': True,
        }
        # Two separate splats (not a merged dict) so that a key present in both
        # mappings still raises TypeError instead of being silently overridden.
        return AzureOpenAIImage(**connection_settings, **extra_settings)

0 commit comments

Comments
 (0)