diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index cd8b08a974a..535560b5fcd 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -18,6 +18,7 @@ from .document_extract_node import * from .image_understand_step_node import * +from .image_generate_step_node import * from .search_dataset_node import * from .start_node import * @@ -25,7 +26,7 @@ node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode, BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode, - BaseImageUnderstandNode, BaseFormNode] + BaseImageUnderstandNode, BaseImageGenerateNode, BaseFormNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/image_generate_step_node/__init__.py b/apps/application/flow/step_node/image_generate_step_node/__init__.py new file mode 100644 index 00000000000..f3feecc9ce2 --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .impl import * diff --git a/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py new file mode 100644 index 00000000000..1ce1af46c1b --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py @@ -0,0 +1,40 @@ +# coding=utf-8 + +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class ImageGenerateNodeSerializer(serializers.Serializer): + model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id")) + + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词(正向)")) + + negative_prompt = 
serializers.CharField(required=False, default='', error_messages=ErrMessage.char("提示词(负向)")) + # 多轮对话数量 + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + + dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型")) + + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) + + model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置")) + + +class IImageGenerateNode(INode): + type = 'image-generate-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ImageGenerateNodeSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id, + model_params_setting, + chat_record_id, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py b/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py new file mode 100644 index 00000000000..14a21a9159c --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .base_image_generate_node import BaseImageGenerateNode diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py new file mode 100644 index 00000000000..77231b70d24 --- /dev/null +++ b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py @@ -0,0 +1,117 @@ +# coding=utf-8 +from functools import reduce +from typing import List + +import requests +from langchain_core.messages import BaseMessage, HumanMessage, AIMessage + +from 
application.flow.i_step_node import NodeResult +from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode +from common.util.common import bytes_to_uploaded_file +from dataset.serializers.file_serializers import FileSerializer +from setting.models_provider.tools import get_model_instance_by_model_user_id + + +class BaseImageGenerateNode(IImageGenerateNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + self.answer_text = details.get('answer') + + def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id, + model_params_setting, + chat_record_id, + **kwargs) -> NodeResult: + print(model_params_setting) + application = self.workflow_manage.work_flow_post_handler.chat_info.application + tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) + history_message = self.get_history_message(history_chat_record, dialogue_number) + self.context['history_message'] = history_message + question = self.generate_prompt_question(prompt) + self.context['question'] = question + message_list = self.generate_message_list(question, history_message) + self.context['message_list'] = message_list + self.context['dialogue_type'] = dialogue_type + print(message_list) + image_urls = tti_model.generate_image(question, negative_prompt) + # 保存图片 + file_urls = [] + for image_url in image_urls: + file_name = 'generated_image.png' + file = bytes_to_uploaded_file(requests.get(image_url).content, file_name) + meta = { + 'debug': False if application.id else True, + 'chat_id': chat_id, + 'application_id': str(application.id) if application.id else None, + } + file_url = FileSerializer(data={'file': file, 'meta': meta}).upload() + file_urls.append(file_url) + self.context['image_list'] = file_urls + answer = 
'\n'.join([f"![Image]({path})" for path in file_urls]) + return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list, + 'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls], + 'history_message': history_message, 'question': question}, {}) + + def generate_history_ai_message(self, chat_record): + for val in chat_record.details.values(): + if self.node.id == val['node_id'] and 'image_list' in val: + if val['dialogue_type'] == 'WORKFLOW': + return chat_record.get_ai_message() + return AIMessage(content=val['answer']) + return chat_record.get_ai_message() + + def get_history_message(self, history_chat_record, dialogue_number): + start_index = len(history_chat_record) - dialogue_number + history_message = reduce(lambda x, y: [*x, *y], [ + [self.generate_history_human_message(history_chat_record[index]), + self.generate_history_ai_message(history_chat_record[index])] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))], []) + return history_message + + def generate_history_human_message(self, chat_record): + + for data in chat_record.details.values(): + if self.node.id == data['node_id'] and 'image_list' in data: + image_list = data['image_list'] + if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW': + return HumanMessage(content=chat_record.problem_text) + return HumanMessage(content=data['question']) + return HumanMessage(content=chat_record.problem_text) + + def generate_prompt_question(self, prompt): + return self.workflow_manage.generate_prompt(prompt) + + def generate_message_list(self, question: str, history_message): + return [ + *history_message, + question + ] + + @staticmethod + def reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return 
result + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'history_message': [{'content': message.content, 'role': message.type} for message in + (self.context.get('history_message') if self.context.get( + 'history_message') is not None else [])], + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message, + 'image_list': self.context.get('image_list'), + 'dialogue_type': self.context.get('dialogue_type') + } diff --git a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py index 31aaa020527..8fce0b1761a 100644 --- a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py @@ -22,6 +22,9 @@ class ImageUnderstandNodeSerializer(serializers.Serializer): image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片")) + model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置")) + + class IImageUnderstandNode(INode): type = 'image-understand-node' @@ -35,6 +38,7 @@ def _run(self): return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data) def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, + model_params_setting, chat_record_id, image, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py 
b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py index 1c2536e0c86..3b96f15cd6f 100644 --- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -12,6 +12,7 @@ from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode from dataset.models import File from setting.models_provider.tools import get_model_instance_by_model_user_id +from imghdr import what def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): @@ -59,8 +60,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor def file_id_to_base64(file_id: str): file = QuerySet(File).filter(id=file_id).first() - base64_image = base64.b64encode(file.get_byte()).decode("utf-8") - return base64_image + file_bytes = file.get_byte() + base64_image = base64.b64encode(file_bytes).decode("utf-8") + return [base64_image, what(None, file_bytes.tobytes())] class BaseImageUnderstandNode(IImageUnderstandNode): @@ -70,14 +72,15 @@ def save_context(self, details, workflow_manage): self.answer_text = details.get('answer') def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, + model_params_setting, chat_record_id, image, **kwargs) -> NodeResult: # 处理不正确的参数 if image is None or not isinstance(image, list): image = [] - - image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) + print(model_params_setting) + image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) # 执行详情中的历史消息不需要图片内容 history_message = self.get_history_message_for_details(history_chat_record, dialogue_number) self.context['history_message'] = history_message @@ -151,7 
+154,7 @@ def generate_history_human_message(self, chat_record): return HumanMessage( content=[ {'type': 'text', 'text': data['question']}, - *[{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}} for + *[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for base64_image in image_base64_list] ]) return HumanMessage(content=chat_record.problem_text) @@ -166,8 +169,10 @@ def generate_message_list(self, image_model, system: str, prompt: str, history_m for img in image: file_id = img['file_id'] file = QuerySet(File).filter(id=file_id).first() - base64_image = base64.b64encode(file.get_byte()).decode("utf-8") - images.append({'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}}) + image_bytes = file.get_byte() + base64_image = base64.b64encode(image_bytes).decode("utf-8") + image_format = what(None, image_bytes.tobytes()) + images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}}) messages = [HumanMessage( content=[ {'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)}, diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index c62719b6e3f..4a8e0b92279 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -54,7 +54,7 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node', - 'image-understand-node'] + 'image-understand-node', 'image-generate-node'] class Flow: diff --git a/apps/common/forms/text_input_field.py b/apps/common/forms/text_input_field.py index 28a821e1570..2b8b2ce04a5 100644 --- a/apps/common/forms/text_input_field.py +++ b/apps/common/forms/text_input_field.py @@ -8,6 +8,7 @@ """ from typing import Dict +from common.forms import BaseLabel from 
common.forms.base_field import BaseField, TriggerType @@ -16,7 +17,7 @@ class TextInputField(BaseField): 文本输入框 """ - def __init__(self, label: str, + def __init__(self, label: str or BaseLabel, required: bool = False, default_value=None, relation_show_field_dict: Dict = None, diff --git a/apps/common/util/common.py b/apps/common/util/common.py index 8571c91e33c..230727622a7 100644 --- a/apps/common/util/common.py +++ b/apps/common/util/common.py @@ -8,9 +8,12 @@ """ import hashlib import importlib +import mimetypes +import io from functools import reduce from typing import Dict, List +from django.core.files.uploadedfile import InMemoryUploadedFile from django.db.models import QuerySet from ..exception.app_exception import AppApiException @@ -111,3 +114,25 @@ def bulk_create_in_batches(model, data, batch_size=1000): batch = data[i:i + batch_size] model.objects.bulk_create(batch) + +def bytes_to_uploaded_file(file_bytes, file_name="file.txt"): + content_type, _ = mimetypes.guess_type(file_name) + if content_type is None: + # 如果未能识别,设置为默认的二进制文件类型 + content_type = "application/octet-stream" + # 创建一个内存中的字节流对象 + file_stream = io.BytesIO(file_bytes) + + # 获取文件大小 + file_size = len(file_bytes) + + # 创建 InMemoryUploadedFile 对象 + uploaded_file = InMemoryUploadedFile( + file=file_stream, + field_name=None, + name=file_name, + content_type=content_type, + size=file_size, + charset=None, + ) + return uploaded_file diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index c86068c68a8..39a759a6548 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -150,6 +150,7 @@ class ModelTypeConst(Enum): STT = {'code': 'STT', 'message': '语音识别'} TTS = {'code': 'TTS', 'message': '语音合成'} IMAGE = {'code': 'IMAGE', 'message': '图片理解'} + TTI = {'code': 'TTI', 'message': '图片生成'} RERANKER = {'code': 'RERANKER', 'message': '重排模型'} diff --git 
a/apps/setting/models_provider/impl/base_tti.py b/apps/setting/models_provider/impl/base_tti.py new file mode 100644 index 00000000000..5e34d12cd11 --- /dev/null +++ b/apps/setting/models_provider/impl/base_tti.py @@ -0,0 +1,14 @@ +# coding=utf-8 +from abc import abstractmethod + +from pydantic import BaseModel + + +class BaseTextToImage(BaseModel): + @abstractmethod + def check_auth(self): + pass + + @abstractmethod + def generate_image(self, prompt: str, negative_prompt: str = None): + pass diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/image.py b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py index e6063695a25..83c1e70d195 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py @@ -7,9 +7,26 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class OpenAIImageModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + class OpenAIImageModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) @@ -45,4 +62,4 @@ def encryption_dict(self, model: Dict[str, object]): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} def get_model_params_setting_form(self, model_name): - pass + return OpenAIImageModelParams() diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py 
b/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py new file mode 100644 index 00000000000..668ebca8a9c --- /dev/null +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py @@ -0,0 +1,82 @@ +# coding=utf-8 +import base64 +import os +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class OpenAITTIModelParams(BaseForm): + size = forms.SingleSelect( + TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'), + required=True, + default_value='1024x1024', + option_list=[ + {'value': '1024x1024', 'label': '1024x1024'}, + {'value': '1024x1792', 'label': '1024x1792'}, + {'value': '1792x1024', 'label': '1792x1024'}, + ], + text_field='label', + value_field='value' + ) + + quality = forms.SingleSelect( + TooltipLabel('图片质量', ''), + required=True, + default_value='standard', + option_list=[ + {'value': 'standard', 'label': 'standard'}, + {'value': 'hd', 'label': 'hd'}, + ], + text_field='label', + value_field='value' + ) + + n = forms.SliderField( + TooltipLabel('图片数量', '指定生成图片的数量'), + required=True, default_value=1, + _min=1, + _max=10, + _step=1, + precision=0) + + +class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if 
raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + res = model.check_auth() + print(res) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return OpenAITTIModelParams() diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/tti.py b/apps/setting/models_provider/impl/openai_model_provider/model/tti.py new file mode 100644 index 00000000000..942afcf9f0d --- /dev/null +++ b/apps/setting/models_provider/impl/openai_model_provider/model/tti.py @@ -0,0 +1,58 @@ +from typing import Dict + +from openai import OpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tti import BaseTextToImage + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage): + api_base: str + api_key: str + model: str + params: dict + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.api_base = kwargs.get('api_base') + self.model = kwargs.get('model') + self.params = kwargs.get('params') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + 
optional_params['params'][key] = value + return OpenAITextToImage( + model=model_name, + api_base=model_credential.get('api_base'), + api_key=model_credential.get('api_key'), + **optional_params, + ) + + def is_cache_model(self): + return False + + def check_auth(self): + chat = OpenAI(api_key=self.api_key, base_url=self.api_base) + response_list = chat.models.with_raw_response.list() + + # self.generate_image('生成一个小猫图片') + + def generate_image(self, prompt: str, negative_prompt: str = None): + chat = OpenAI(api_key=self.api_key, base_url=self.api_base) + res = chat.images.generate(model=self.model, prompt=prompt, **self.params) + file_urls = [] + for content in res.data: + url = content.url + file_urls.append(url) + + return file_urls diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index 974599cc89c..be659291efc 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -15,11 +15,13 @@ from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential +from setting.models_provider.impl.openai_model_provider.credential.tti import OpenAITextToImageModelCredential from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel from 
setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText +from setting.models_provider.impl.openai_model_provider.model.tti import OpenAITextToImage from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech from smartdoc.conf import PROJECT_DIR @@ -27,6 +29,7 @@ openai_stt_model_credential = OpenAISTTModelCredential() openai_tts_model_credential = OpenAITTSModelCredential() openai_image_model_credential = OpenAIImageModelCredential() +openai_tti_model_credential = OpenAITextToImageModelCredential() model_info_list = [ ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel @@ -37,8 +40,8 @@ ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel), ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新', - ModelTypeConst.LLM, openai_llm_model_credential, - OpenAIChatModel), + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel), @@ -100,11 +103,27 @@ OpenAIImage), ] -model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( - ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, - openai_llm_model_credential, OpenAIChatModel - )).append_model_info_list(model_info_embedding_list).append_default_model_info( - model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build() +model_info_tti_list = [ + ModelInfo('dall-e-2', '', + ModelTypeConst.TTI, openai_tti_model_credential, + OpenAITextToImage), + ModelInfo('dall-e-3', '', + ModelTypeConst.TTI, openai_tti_model_credential, + OpenAITextToImage), +] + +model_info_manage = ( + ModelInfoManage.builder() + .append_model_info_list(model_info_list) + .append_default_model_info(ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', 
ModelTypeConst.LLM, + openai_llm_model_credential, OpenAIChatModel + )) + .append_model_info_list(model_info_embedding_list) + .append_default_model_info(model_info_embedding_list[0]) + .append_model_info_list(model_info_image_list) + .append_model_info_list(model_info_tti_list) + .build() +) class OpenAIModelProvider(IModelProvider): diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py new file mode 100644 index 00000000000..395d94db9e1 --- /dev/null +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py @@ -0,0 +1,94 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: tti.py + @date:2024/7/11 18:41 + @desc: +""" +import base64 +import os +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class QwenModelParams(BaseForm): + size = forms.SingleSelect( + TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'), + required=True, + default_value='1024*1024', + option_list=[ + {'value': '1024*1024', 'label': '1024*1024'}, + {'value': '720*1280', 'label': '720*1280'}, + {'value': '768*1152', 'label': '768*1152'}, + {'value': '1280*720', 'label': '1280*720'}, + ], + text_field='label', + value_field='value') + n = forms.SliderField( + TooltipLabel('图片数量', '指定生成图片的数量'), + required=True, default_value=1, + _min=1, + _max=4, + _step=1, + precision=0) + style = forms.SingleSelect( + TooltipLabel('风格', '指定生成图片的风格'), + required=True, + default_value='<auto>', + option_list=[ + {'value': '<auto>', 'label': '默认值,由模型随机输出图像风格'}, + {'value': '<photography>', 'label': '摄影'}, + {'value': '<portrait>', 'label': '人像写真'}, + {'value': '<3d cartoon>', 'label': '3D卡通'}, + {'value': '<anime>', 'label': '动画'}, + {'value': '<oil painting>', 'label': '油画'}, + {'value': 
'<watercolor>', 'label': '水彩'}, + {'value': '<sketch>', 'label': '素描'}, + {'value': '<chinese painting>', 'label': '中国画'}, + {'value': '<flat illustration>', 'label': '扁平插画'}, + ], + text_field='label', + value_field='value' + ) + + +class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + res = model.check_auth() + print(res) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, model_name): + return QwenModelParams() diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/tti.py b/apps/setting/models_provider/impl/qwen_model_provider/model/tti.py new file mode 100644 index 00000000000..c2fd32877e4 --- /dev/null +++ b/apps/setting/models_provider/impl/qwen_model_provider/model/tti.py @@ -0,0 +1,58 @@ +# coding=utf-8 +from http import HTTPStatus +from typing import Dict + +from dashscope import ImageSynthesis +from langchain_community.chat_models import ChatTongyi +from langchain_core.messages import HumanMessage + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from 
setting.models_provider.impl.base_tti import BaseTextToImage + + +class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage): + api_key: str + model_name: str + params: dict + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.model_name = kwargs.get('model_name') + self.params = kwargs.get('params') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': {'size': '1024*1024', 'style': '<auto>', 'n': 1}} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + optional_params['params'][key] = value + chat_tong_yi = QwenTextToImageModel( + model_name=model_name, + api_key=model_credential.get('api_key'), + **optional_params, + ) + return chat_tong_yi + + def is_cache_model(self): + return False + + def check_auth(self): + chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max') + chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])]) + + def generate_image(self, prompt: str, negative_prompt: str = None): + # api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', + rsp = ImageSynthesis.call(api_key=self.api_key, + model=self.model_name, + prompt=prompt, + negative_prompt=negative_prompt, + **self.params) + file_urls = [] + if rsp.status_code == HTTPStatus.OK: + for result in rsp.output.results: + file_urls.append(result.url) + else: + print('sync_call Failed, status_code: %s, code: %s, message: %s' % + (rsp.status_code, rsp.code, rsp.message)) + return file_urls diff --git a/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py b/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py index 0a24ca35ce8..fc2506a9c59 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py @@ -13,13 +13,16 @@ ModelInfoManage
from setting.models_provider.impl.qwen_model_provider.credential.image import QwenVLModelCredential from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential +from setting.models_provider.impl.qwen_model_provider.credential.tti import QwenTextToImageModelCredential from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLChatModel from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel +from setting.models_provider.impl.qwen_model_provider.model.tti import QwenTextToImageModel from smartdoc.conf import PROJECT_DIR qwen_model_credential = OpenAILLMModelCredential() qwenvl_model_credential = QwenVLModelCredential() +qwentti_model_credential = QwenTextToImageModelCredential() module_info_list = [ ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel), @@ -31,13 +34,21 @@ ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel), ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel), ] +module_info_tti_list = [ + ModelInfo('wanx-v1', + '通义万相-文本生成图像大模型,支持中英文双语输入,支持输入参考图片进行参考内容或者参考风格迁移,重点风格包括但不限于水彩、油画、中国画、素描、扁平插画、二次元、3D卡通。', + ModelTypeConst.TTI, qwentti_model_credential, QwenTextToImageModel), +] -model_info_manage = (ModelInfoManage.builder() - .append_model_info_list(module_info_list) - .append_default_model_info( - ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)) - .append_model_info_list(module_info_vl_list) - .build()) +model_info_manage = ( + ModelInfoManage.builder() + .append_model_info_list(module_info_list) + .append_default_model_info( + ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)) + .append_model_info_list(module_info_vl_list) + .append_model_info_list(module_info_tti_list) + .build() +) class QwenModelProvider(IModelProvider): diff --git 
# === credential/tti.py ===
# coding=utf-8
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class TencentTTIModelParams(BaseForm):
    """Generation parameter form for Tencent Hunyuan text-to-image."""

    # Hunyuan style codes; see the TextToImageLite API for the full table.
    Style = forms.SingleSelect(
        TooltipLabel('绘画风格', '不传默认使用201(日系动漫风格)'),
        required=True,
        default_value='201',
        option_list=[
            {'value': '000', 'label': '不限定风格'},
            {'value': '101', 'label': '水墨画'},
            {'value': '102', 'label': '概念艺术'},
            {'value': '103', 'label': '油画1'},
            {'value': '118', 'label': '油画2(梵高)'},
            {'value': '104', 'label': '水彩画'},
            {'value': '105', 'label': '像素画'},
            {'value': '106', 'label': '厚涂风格'},
            {'value': '107', 'label': '插图'},
            {'value': '108', 'label': '剪纸风格'},
            {'value': '109', 'label': '印象派1(莫奈)'},
            {'value': '119', 'label': '印象派2'},
            {'value': '110', 'label': '2.5D'},
            {'value': '111', 'label': '古典肖像画'},
            {'value': '112', 'label': '黑白素描画'},
            {'value': '113', 'label': '赛博朋克'},
            {'value': '114', 'label': '科幻风格'},
            {'value': '115', 'label': '暗黑风格'},
            {'value': '116', 'label': '3D'},
            {'value': '117', 'label': '蒸汽波'},
            {'value': '201', 'label': '日系动漫'},
            {'value': '202', 'label': '怪兽风格'},
            {'value': '203', 'label': '唯美古风'},
            {'value': '204', 'label': '复古动漫'},
            {'value': '301', 'label': '游戏卡通手绘'},
            {'value': '401', 'label': '通用写实风格'},
        ],
        value_field='value',
        text_field='label'
    )

    Resolution = forms.SingleSelect(
        TooltipLabel('生成图分辨率', '不传默认使用768:768。'),
        required=True,
        default_value='768:768',
        option_list=[
            {'value': '768:768', 'label': '768:768(1:1)'},
            {'value': '768:1024', 'label': '768:1024(3:4)'},
            {'value': '1024:768', 'label': '1024:768(4:3)'},
            {'value': '1024:1024', 'label': '1024:1024(1:1)'},
            {'value': '720:1280', 'label': '720:1280(9:16)'},
            {'value': '1280:720', 'label': '1280:720(16:9)'},
            {'value': '768:1280', 'label': '768:1280(3:5)'},
            {'value': '1280:768', 'label': '1280:768(5:3)'},
            {'value': '1080:1920', 'label': '1080:1920(9:16)'},
            {'value': '1920:1080', 'label': '1920:1080(16:9)'},
        ],
        value_field='value',
        text_field='label'
    )


class TencentTTIModelCredential(BaseForm, BaseModelCredential):
    """Credential form (SecretId/SecretKey) for Tencent Hunyuan text-to-image."""

    REQUIRED_FIELDS = ['hunyuan_secret_id', 'hunyuan_secret_key']

    hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
    hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)

    @classmethod
    def _validate_model_type(cls, model_type, provider, raise_exception=False):
        # Returns False (or raises) when the provider does not offer model_type.
        if not any(mt['value'] == model_type for mt in provider.get_model_type_list()):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False
        return True

    @classmethod
    def _validate_credential_fields(cls, model_credential, raise_exception=False):
        # All REQUIRED_FIELDS must be present in the submitted credential.
        missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential]
        if missing_keys:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{", ".join(missing_keys)} 字段为必填字段')
            return False
        return True

    def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
        """Validate model type, credential fields, then live API access."""
        if not (self._validate_model_type(model_type, provider, raise_exception) and
                self._validate_credential_fields(model_credential, raise_exception)):
            return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model):
        # Only the secret key needs encrypting before persistence.
        return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return TencentTTIModelParams()


# === model/tti.py ===
# coding=utf-8

import json
from typing import Dict

from tencentcloud.common import credential
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models

from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage
from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan


class TencentTextToImageModel(MaxKBBaseModel, BaseTextToImage):
    """Tencent Hunyuan text-to-image model (TextToImageLite API)."""

    hunyuan_secret_id: str
    hunyuan_secret_key: str
    model: str
    params: dict

    @staticmethod
    def is_cache_model():
        return False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hunyuan_secret_id = kwargs.get('hunyuan_secret_id')
        self.hunyuan_secret_key = kwargs.get('hunyuan_secret_key')
        self.model = kwargs.get('model_name')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
                     **model_kwargs) -> 'TencentTextToImageModel':
        # Defaults first, then caller overrides; bookkeeping keys excluded.
        optional_params = {'params': {'Style': '201', 'Resolution': '768:768'}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return TencentTextToImageModel(
            model=model_name,
            hunyuan_secret_id=model_credential.get('hunyuan_secret_id'),
            hunyuan_secret_key=model_credential.get('hunyuan_secret_key'),
            **optional_params
        )

    def check_auth(self):
        # NOTE(review): uses a dummy app id and a chat call as the auth probe —
        # presumably only the secret pair matters here; confirm with the SDK docs.
        chat = ChatHunyuan(hunyuan_app_id='111111',
                           hunyuan_secret_id=self.hunyuan_secret_id,
                           hunyuan_secret_key=self.hunyuan_secret_key,
                           model="hunyuan-standard")
        chat.invoke('你好')

    def generate_image(self, prompt: str, negative_prompt: str = None):
        """Generate images; returns a list of URLs (empty on SDK failure).

        Bug fix: previously the except branch printed the error and fell
        through, implicitly returning None while the success path returned a
        list — callers iterating the result would crash.
        """
        try:
            # 密钥获取与保密注意事项见:https://cloud.tencent.com/document/product/1278/85305
            cred = credential.Credential(self.hunyuan_secret_id, self.hunyuan_secret_key)
            http_profile = HttpProfile()
            http_profile.endpoint = "hunyuan.tencentcloudapi.com"
            client_profile = ClientProfile()
            client_profile.httpProfile = http_profile
            client = hunyuan_client.HunyuanClient(cred, "ap-guangzhou", client_profile)

            req = models.TextToImageLiteRequest()
            req.from_json_string(json.dumps({
                "Prompt": prompt,
                "NegativePrompt": negative_prompt,
                "RspImgType": "url",
                **self.params
            }))
            resp = client.TextToImageLite(req)
            return [resp.ResultImage]
        except TencentCloudSDKException as err:
            print(err)
            return []
a/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py @@ -9,9 +9,11 @@ from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential from setting.models_provider.impl.tencent_model_provider.credential.image import TencentVisionModelCredential from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential +from setting.models_provider.impl.tencent_model_provider.credential.tti import TencentTTIModelCredential from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel from setting.models_provider.impl.tencent_model_provider.model.image import TencentVision from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel +from setting.models_provider.impl.tencent_model_provider.model.tti import TencentTextToImageModel from smartdoc.conf import PROJECT_DIR @@ -87,11 +89,19 @@ def _initialize_model_info(): TencentVisionModelCredential, TencentVision)] + model_info_tti_list = [_create_model_info( + 'hunyuan-dit', + '混元生图模型', + ModelTypeConst.TTI, + TencentTTIModelCredential, + TencentTextToImageModel)] + model_info_manage = ModelInfoManage.builder() \ .append_model_info_list(model_info_list) \ .append_model_info_list(model_info_embedding_list) \ .append_model_info_list(model_info_vision_list) \ + .append_model_info_list(model_info_tti_list) \ .append_default_model_info(model_info_list[0]) \ .build() diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/image.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/image.py new file mode 100644 index 00000000000..ff31b5ef06c --- /dev/null +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/image.py @@ -0,0 +1,63 @@ +# coding=utf-8 +import base64 +import os +from 
# === volcanic_engine credential/image.py ===
# coding=utf-8
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class VolcanicEngineImageModelParams(BaseForm):
    """Generation parameter form for Volcano Engine vision chat models."""

    temperature = forms.SliderField(
        TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
        required=True, default_value=0.95,
        _min=0.1,
        _max=1.0,
        _step=0.01,
        precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=1024,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential):
    """API key / endpoint credential form for Volcano Engine vision models."""

    api_key = forms.PasswordInputField('API Key', required=True)
    api_base = forms.TextInputField('API 域名', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate model type, required fields, then a live streaming probe.

        The unsupported-model-type path now honors ``raise_exception``
        (previously it always raised), matching the other failure paths.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False
        for key in ['api_key', 'api_base']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Consume the stream so auth errors actually surface; no debug prints.
            for _ in model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]):
                pass
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return VolcanicEngineImageModelParams()


# === volcanic_engine credential/tti.py ===
# coding=utf-8

class VolcanicEngineTTIModelGeneralParams(BaseForm):
    """Generation parameter form for Volcano Engine text-to-image."""

    size = forms.SingleSelect(
        TooltipLabel('图片尺寸',
                     '宽、高与512差距过大,则出图效果不佳、延迟过长概率显著增加。超分前建议比例及对应宽高:width*height'),
        required=True,
        default_value='512*512',
        option_list=[
            {'value': '512*512', 'label': '512*512'},
            {'value': '512*384', 'label': '512*384'},
            {'value': '384*512', 'label': '384*512'},
            {'value': '512*341', 'label': '512*341'},
            {'value': '341*512', 'label': '341*512'},
            {'value': '512*288', 'label': '512*288'},
            {'value': '288*512', 'label': '288*512'},
        ],
        text_field='label',
        value_field='value')


class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential):
    """Access key / secret key credential form for Volcano Engine text-to-image."""

    access_key = forms.PasswordInputField('Access Key', required=True)
    secret_key = forms.PasswordInputField('Secret Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate model type, required fields, then a live auth check."""
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            # Honor raise_exception here too, like the other failure paths.
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False
        for key in ['access_key', 'secret_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'secret_key': super().encryption(model.get('secret_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return VolcanicEngineTTIModelGeneralParams()


# === volcanic_engine model/image.py ===
from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    """Tokenize with the shared project tokenizer (token counting helper)."""
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class VolcanicEngineImage(MaxKBBaseModel, ChatOpenAI):
    """OpenAI-compatible chat wrapper for Volcano Engine vision endpoints."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return VolcanicEngineImage(
            model_name=model_name,
            openai_api_key=model_credential.get('api_key'),
            openai_api_base=model_credential.get('api_base'),
            streaming=True,
            **optional_params,
        )
# coding=utf-8

"""Volcano Engine text-to-image model.

Implements the Volcano Engine visual service's V4 request signing by hand
and exposes it through the project's ``BaseTextToImage`` interface.
"""

import datetime
import hashlib
import hmac
import json
from typing import Dict

import requests

from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage

# Fixed request parameters for the Volcano Engine visual service.
method = 'POST'
host = 'visual.volcengineapi.com'
region = 'cn-north-1'
endpoint = 'https://visual.volcengineapi.com'
service = 'cv'

# Maps the externally visible model version to the API's req_key.
req_key_dict = {
    'general_v1.4': 'high_aes_general_v14',
    'general_v2.0': 'high_aes_general_v20',
    'general_v2.0_L': 'high_aes_general_v20_L',
    'anime_v1.3': 'high_aes',
    'anime_v1.3.1': 'high_aes',
}


def sign(key, msg):
    """Return the raw HMAC-SHA256 digest of *msg* (str) under *key* (bytes)."""
    return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()


def getSignatureKey(key, dateStamp, regionName, serviceName):
    """Derive the V4 signing key by chaining HMACs over date/region/service."""
    kDate = sign(key.encode('utf-8'), dateStamp)
    kRegion = sign(kDate, regionName)
    kService = sign(kRegion, serviceName)
    return sign(kService, 'request')


def formatQuery(parameters):
    """Serialize query params sorted by key (required for a stable signature)."""
    return '&'.join(key + '=' + parameters[key] for key in sorted(parameters))


def signV4Request(access_key, secret_key, service, req_query, req_body):
    """POST *req_body* to the visual endpoint with a V4 signature.

    Returns the list of generated image URLs from the response.
    Raises ValueError when keys are missing and Exception on a non-200
    response. Bug fix: previously called sys.exit() on missing keys, which
    would terminate the whole host process from library code.
    """
    if access_key is None or secret_key is None:
        raise ValueError('No access key is available.')

    # Aware UTC timestamp; datetime.utcnow() is deprecated and naive.
    t = datetime.datetime.now(datetime.timezone.utc)
    current_date = t.strftime('%Y%m%dT%H%M%SZ')
    datestamp = t.strftime('%Y%m%d')  # date w/o time, used in credential scope
    canonical_uri = '/'
    canonical_querystring = req_query
    signed_headers = 'content-type;host;x-content-sha256;x-date'
    payload_hash = hashlib.sha256(req_body.encode('utf-8')).hexdigest()
    content_type = 'application/json'
    # canonical_headers itself ends with '\n', giving the blank line the
    # canonical-request format requires before signed_headers.
    canonical_headers = ('content-type:' + content_type + '\n'
                         + 'host:' + host + '\n'
                         + 'x-content-sha256:' + payload_hash + '\n'
                         + 'x-date:' + current_date + '\n')
    canonical_request = '\n'.join([method, canonical_uri, canonical_querystring,
                                   canonical_headers, signed_headers, payload_hash])
    algorithm = 'HMAC-SHA256'
    credential_scope = datestamp + '/' + region + '/' + service + '/' + 'request'
    string_to_sign = '\n'.join([
        algorithm, current_date, credential_scope,
        hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()])
    signing_key = getSignatureKey(secret_key, datestamp, region, service)
    signature = hmac.new(signing_key, string_to_sign.encode('utf-8'),
                         hashlib.sha256).hexdigest()

    authorization_header = (algorithm + ' ' + 'Credential=' + access_key + '/'
                            + credential_scope + ', ' + 'SignedHeaders='
                            + signed_headers + ', ' + 'Signature=' + signature)
    headers = {'X-Date': current_date,
               'Authorization': authorization_header,
               'X-Content-Sha256': payload_hash,
               'Content-Type': content_type}

    request_url = endpoint + '?' + canonical_querystring
    r = requests.post(request_url, headers=headers, data=req_body)
    # The API escapes '&' as \u0026 inside returned URLs; undo before parsing.
    resp_str = r.text.replace("\\u0026", "&")
    if r.status_code != 200:
        raise Exception(f'Error: {resp_str}')
    return json.loads(resp_str)['data']['image_urls']


class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
    """Text-to-image model calling the Volcano Engine CVProcess action."""

    access_key: str
    secret_key: str
    model_version: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.access_key = kwargs.get('access_key')
        self.secret_key = kwargs.get('secret_key')
        self.model_version = kwargs.get('model_version')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {'params': {}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return VolcanicEngineTextToImage(
            model_version=model_name,
            access_key=model_credential.get('access_key'),
            secret_key=model_credential.get('secret_key'),
            **optional_params
        )

    def check_auth(self):
        # No dedicated auth endpoint: issue a minimal real generation request.
        self.generate_image('生成一张小猫图片')

    def generate_image(self, prompt: str, negative_prompt: str = None):
        # NOTE(review): negative_prompt is accepted but never sent — confirm
        # whether the CVProcess action supports it before wiring it through.
        query_params = {
            'Action': 'CVProcess',
            'Version': '2022-08-31',
        }
        formatted_query = formatQuery(query_params)
        # Bug fix: the previous pop('size') mutated self.params, so the size
        # setting was silently lost on every call after the first.
        size = str(self.params.get('size', '512*512')).split('*')
        extra = {k: v for k, v in self.params.items() if k != 'size'}
        body_params = {
            "req_key": req_key_dict[self.model_version],
            "prompt": prompt,
            "model_version": self.model_version,
            "return_url": True,
            "width": int(size[0]),
            "height": int(size[1]),
            **extra
        }
        return signV4Request(self.access_key, self.secret_key, service,
                             formatted_query, json.dumps(body_params))

    def is_cache_model(self):
        return False
setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech from smartdoc.conf import PROJECT_DIR @@ -25,6 +30,8 @@ volcanic_engine_llm_model_credential = OpenAILLMModelCredential() volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential() volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential() +volcanic_engine_image_model_credential = VolcanicEngineImageModelCredential() +volcanic_engine_tti_model_credential = VolcanicEngineTTIModelCredential() model_info_list = [ ModelInfo('ep-xxxxxxxxxx-yyyy', @@ -32,6 +39,11 @@ ModelTypeConst.LLM, volcanic_engine_llm_model_credential, VolcanicEngineChatModel ), + ModelInfo('ep-xxxxxxxxxx-yyyy', + '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用', + ModelTypeConst.IMAGE, + volcanic_engine_image_model_credential, VolcanicEngineImage + ), ModelInfo('asr', '', ModelTypeConst.STT, @@ -42,6 +54,31 @@ ModelTypeConst.TTS, volcanic_engine_tts_model_credential, VolcanicEngineTextToSpeech ), + ModelInfo('general_v2.0', + '通用2.0-文生图', + ModelTypeConst.TTI, + volcanic_engine_tti_model_credential, VolcanicEngineTextToImage + ), + ModelInfo('general_v2.0_L', + '通用2.0Pro-文生图', + ModelTypeConst.TTI, + volcanic_engine_tti_model_credential, VolcanicEngineTextToImage + ), + ModelInfo('general_v1.4', + '通用1.4-文生图', + ModelTypeConst.TTI, + volcanic_engine_tti_model_credential, VolcanicEngineTextToImage + ), + ModelInfo('anime_v1.3', + '动漫1.3.0-文生图', + ModelTypeConst.TTI, + volcanic_engine_tti_model_credential, VolcanicEngineTextToImage + ), + ModelInfo('anime_v1.3.1', + '动漫1.3.1-文生图', + ModelTypeConst.TTI, + volcanic_engine_tti_model_credential, VolcanicEngineTextToImage + ), ] open_ai_embedding_credential = OpenAIEmbeddingCredential() @@ -51,8 +88,13 @@ ModelTypeConst.EMBEDDING, open_ai_embedding_credential, OpenAIEmbeddingModel)] -model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( - 
model_info_list[0]).build() +model_info_manage = ( + ModelInfoManage.builder() + .append_model_info_list(model_info_list) + .append_default_model_info(model_info_list[0]) + .append_default_model_info(model_info_list[1]) + .build() +) class VolcanicEngineModelProvider(IModelProvider): diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/image.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/image.py new file mode 100644 index 00000000000..e2cbbb1948c --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/image.py @@ -0,0 +1,65 @@ +# coding=utf-8 +import base64 +import os +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + +class XinferenceImageModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + + +class XinferenceImageModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if 
# === xinference credential/tti.py ===
# coding=utf-8
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class XinferenceTTIModelParams(BaseForm):
    """Generation parameter form for Xinference text-to-image models."""

    size = forms.SingleSelect(
        TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
        required=True,
        default_value='1024x1024',
        option_list=[
            {'value': '1024x1024', 'label': '1024x1024'},
            {'value': '1024x1792', 'label': '1024x1792'},
            {'value': '1792x1024', 'label': '1792x1024'},
        ],
        text_field='label',
        value_field='value'
    )

    quality = forms.SingleSelect(
        TooltipLabel('图片质量', ''),
        required=True,
        default_value='standard',
        option_list=[
            {'value': 'standard', 'label': 'standard'},
            {'value': 'hd', 'label': 'hd'},
        ],
        text_field='label',
        value_field='value'
    )

    n = forms.SliderField(
        TooltipLabel('图片数量', '指定生成图片的数量'),
        required=True, default_value=1,
        _min=1,
        _max=10,
        _step=1,
        precision=0)


class XinferenceTextToImageModelCredential(BaseForm, BaseModelCredential):
    """API base / key credential form for Xinference text-to-image models."""

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate model type, required fields, then the model's auth probe.

        The unsupported-model-type path now honors ``raise_exception``
        (previously it always raised), matching the other failure paths.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False
        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # check_auth raises on failure; no debug print of its result.
            model.check_auth()
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return XinferenceTTIModelParams()
setting.models_provider.base_model_provider import MaxKBBaseModel + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class XinferenceImage(MaxKBBaseModel, ChatOpenAI): + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return XinferenceImage( + model_name=model_name, + openai_api_base=model_credential.get('api_base'), + openai_api_key=model_credential.get('api_key'), + # stream_options={"include_usage": True}, + streaming=True, + **optional_params, + ) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/tti.py b/apps/setting/models_provider/impl/xinference_model_provider/model/tti.py new file mode 100644 index 00000000000..ee5b655f4bd --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/tti.py @@ -0,0 +1,66 @@ +import base64 +from typing import Dict + +from openai import OpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from common.util.common import bytes_to_uploaded_file +from dataset.serializers.file_serializers import FileSerializer +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tti import BaseTextToImage + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class XinferenceTextToImage(MaxKBBaseModel, BaseTextToImage): + api_base: str + api_key: str + model: str + params: dict + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.api_base = kwargs.get('api_base') + self.model = kwargs.get('model') + self.params = kwargs.get('params') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': 
{'size': '1024x1024', 'quality': 'standard', 'n': 1}}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params['params'][key] = value
+        return XinferenceTextToImage(
+            model=model_name,
+            api_base=model_credential.get('api_base'),
+            api_key=model_credential.get('api_key'),
+            **optional_params,
+        )
+
+    def is_cache_model(self):
+        return False
+
+    def check_auth(self):
+        # Listing the models raises if the key / base URL are wrong, which is
+        # exactly the auth probe we need; the response itself is unused.
+        chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
+        chat.models.with_raw_response.list()
+
+    def generate_image(self, prompt: str, negative_prompt: str = None):
+        chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
+        res = chat.images.generate(model=self.model, prompt=prompt, response_format='b64_json', **self.params)
+        file_urls = []
+        # Persist each generated image as a temporary uploaded file.
+        for img in res.data:
+            file = bytes_to_uploaded_file(base64.b64decode(img.b64_json), 'file_name.jpg')
+            meta = {
+                'debug': True,
+            }
+            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
+            # Return the server-relative URL from the upload as-is, instead of
+            # hard-coding a development origin ('http://localhost:8080'), so the
+            # links work behind any host / reverse proxy.
+            file_urls.append(file_url)
+
+        return file_urls
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py
index 0da07f6d3c7..4c7b9c0c7db 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py
@@ -9,20 +9,26 @@
 ModelInfoManage
 from setting.models_provider.impl.xinference_model_provider.credential.embedding import \
     XinferenceEmbeddingModelCredential
+from setting.models_provider.impl.xinference_model_provider.credential.image import XinferenceImageModelCredential
 from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential
 from 
setting.models_provider.impl.xinference_model_provider.credential.reranker import XInferenceRerankerModelCredential from setting.models_provider.impl.xinference_model_provider.credential.stt import XInferenceSTTModelCredential +from setting.models_provider.impl.xinference_model_provider.credential.tti import XinferenceTextToImageModelCredential from setting.models_provider.impl.xinference_model_provider.credential.tts import XInferenceTTSModelCredential from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding +from setting.models_provider.impl.xinference_model_provider.model.image import XinferenceImage from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel from setting.models_provider.impl.xinference_model_provider.model.reranker import XInferenceReranker from setting.models_provider.impl.xinference_model_provider.model.stt import XInferenceSpeechToText +from setting.models_provider.impl.xinference_model_provider.model.tti import XinferenceTextToImage from setting.models_provider.impl.xinference_model_provider.model.tts import XInferenceTextToSpeech from smartdoc.conf import PROJECT_DIR xinference_llm_model_credential = XinferenceLLMModelCredential() xinference_stt_model_credential = XInferenceSTTModelCredential() xinference_tts_model_credential = XInferenceTTSModelCredential() +xinference_image_model_credential = XinferenceImageModelCredential() +xinference_tti_model_credential = XinferenceTextToImageModelCredential() model_info_list = [ ModelInfo( @@ -296,6 +302,159 @@ ), ] +image_model_info = [ + ModelInfo( + 'qwen-vl-chat', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'deepseek-vl-chat', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'yi-vl-chat', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 
'omnilmm', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'internvl-chat', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'cogvlm2', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'MiniCPM-Llama3-V-2_5', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'GLM-4V', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'MiniCPM-V-2.6', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'internvl2', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'qwen2-vl-instruct', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'llama-3.2-vision', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'llama-3.2-vision-instruct', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), + ModelInfo( + 'glm-edge-v', + '', + ModelTypeConst.IMAGE, + xinference_image_model_credential, + XinferenceImage + ), +] + +tti_model_info = [ + ModelInfo( + 'sd-turbo', + '', + ModelTypeConst.TTI, + xinference_tti_model_credential, + XinferenceTextToImage + ), + ModelInfo( + 'sdxl-turbo', + '', + ModelTypeConst.TTI, + xinference_tti_model_credential, + XinferenceTextToImage + ), + ModelInfo( + 'stable-diffusion-v1.5', + '', + ModelTypeConst.TTI, + xinference_tti_model_credential, + XinferenceTextToImage + ), + ModelInfo( + 'stable-diffusion-xl-base-1.0', + '', + ModelTypeConst.TTI, + xinference_tti_model_credential, + XinferenceTextToImage + ), + ModelInfo( + 'sd3-medium', + '', + ModelTypeConst.TTI, + xinference_tti_model_credential, + XinferenceTextToImage + ), + ModelInfo( + 
'FLUX.1-schnell', + '', + ModelTypeConst.TTI, + xinference_tti_model_credential, + XinferenceTextToImage + ), + ModelInfo( + 'FLUX.1-dev', + '', + ModelTypeConst.TTI, + xinference_tti_model_credential, + XinferenceTextToImage + ), +] + xinference_embedding_model_credential = XinferenceEmbeddingModelCredential() # 生成embedding_model_info列表 @@ -377,6 +536,8 @@ ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, XinferenceEmbedding)) .append_model_info_list(rerank_list) + .append_model_info_list(image_model_info) + .append_model_info_list(tti_model_info) .append_default_model_info(rerank_list[0]) .build()) diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py index 0eb05bb91af..54bd19e14bc 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py @@ -7,9 +7,24 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class ZhiPuImageModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.95, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=1024, + _min=1, + _max=100000, + _step=1, + precision=0) class ZhiPuImageModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) @@ -44,4 +59,4 @@ def encryption_dict(self, model: Dict[str, object]): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} def get_model_params_setting_form(self, model_name): - pass + return 
ZhiPuImageModelParams() diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py new file mode 100644 index 00000000000..f951efd9ecb --- /dev/null +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py @@ -0,0 +1,61 @@ +# coding=utf-8 +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class ZhiPuTTIModelParams(BaseForm): + size = forms.SingleSelect( + TooltipLabel('图片尺寸', + '图片尺寸,仅 cogview-3-plus 支持该参数。可选范围:[1024x1024,768x1344,864x1152,1344x768,1152x864,1440x720,720x1440],默认是1024x1024。'), + required=True, + default_value='1024x1024', + option_list=[ + {'value': '1024x1024', 'label': '1024x1024'}, + {'value': '768x1344', 'label': '768x1344'}, + {'value': '864x1152', 'label': '864x1152'}, + {'value': '1344x768', 'label': '1344x768'}, + {'value': '1152x864', 'label': '1152x864'}, + {'value': '1440x720', 'label': '1440x720'}, + {'value': '720x1440', 'label': '720x1440'}, + ], + text_field='label', + value_field='value') + + +class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential): + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + 
res = model.check_auth()
+            # check_auth() raising is the only signal needed here; the debug
+            # print of its result has been removed.
+        except Exception as e:
+            if isinstance(e, AppApiException):
+                raise e
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+            else:
+                return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+    def get_model_params_setting_form(self, model_name):
+        return ZhiPuTTIModelParams()
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/tti.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/tti.py
new file mode 100644
index 00000000000..e2c59b85aa5
--- /dev/null
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/model/tti.py
@@ -0,0 +1,68 @@
+from typing import Dict
+
+from langchain_community.chat_models import ChatZhipuAI
+from langchain_core.messages import HumanMessage
+from zhipuai import ZhipuAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_tti import BaseTextToImage
+
+
+def custom_get_token_ids(text: str):
+    tokenizer = TokenizerManage.get_tokenizer()
+    return tokenizer.encode(text)
+
+
+class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
+    api_key: str
+    model: str
+    params: dict
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.api_key = kwargs.get('api_key')
+        self.model = kwargs.get('model')
+        self.params = kwargs.get('params')
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        optional_params = {'params': {'size': '1024x1024'}}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params['params'][key] = value
+        return ZhiPuTextToImage(
+            model=model_name,
+            api_key=model_credential.get('api_key'),
+            **optional_params,
+        )
+
+    def is_cache_model(self):
+        return False
+
+    def check_auth(self):
+        # A minimal chat round-trip is the cheapest credential probe; any auth
+        # problem surfaces as an exception from invoke().
+        chat = ChatZhipuAI(
+            zhipuai_api_key=self.api_key,
+            model_name=self.model,
+        )
+        chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])
+
+    def generate_image(self, prompt: str, negative_prompt: str = None):
+        # negative_prompt is accepted for interface parity with other TTI
+        # models, but the ZhiPu images API does not support it, so it is
+        # intentionally unused here.
+        chat = ZhipuAI(api_key=self.api_key)
+        response = chat.images.generations(
+            model=self.model,  # model code to invoke
+            prompt=prompt,  # text description of the desired image
+            **self.params  # extra generation parameters (e.g. size)
+        )
+        file_urls = []
+        for content in response.data:
+            url = content.url
+            file_urls.append(url)
+
+        return file_urls
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py b/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py
index b24c8dd0d86..0fd0b3f2524 100644
--- a/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py
@@ -13,12 +13,15 @@
 ModelInfoManage
 from setting.models_provider.impl.zhipu_model_provider.credential.image import ZhiPuImageModelCredential
 from setting.models_provider.impl.zhipu_model_provider.credential.llm import ZhiPuLLMModelCredential
+from setting.models_provider.impl.zhipu_model_provider.credential.tti import ZhiPuTextToImageModelCredential
 from setting.models_provider.impl.zhipu_model_provider.model.image import ZhiPuImage
 from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
+from setting.models_provider.impl.zhipu_model_provider.model.tti import ZhiPuTextToImage
 from smartdoc.conf import PROJECT_DIR
 
 qwen_model_credential = ZhiPuLLMModelCredential()
 zhipu_image_model_credential = ZhiPuImageModelCredential()
+zhipu_tti_model_credential = ZhiPuTextToImageModelCredential()
 
 model_info_list = [
     ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, 
ZhipuChatModel), @@ -38,11 +41,21 @@ ZhiPuImage), ] +model_info_tti_list = [ + ModelInfo('cogview-3', '根据用户文字描述快速、精准生成图像。分辨率支持1024x1024', + ModelTypeConst.TTI, zhipu_tti_model_credential, + ZhiPuTextToImage), + ModelInfo('cogview-3-plus', '根据用户文字描述生成高质量图像,支持多图片尺寸', + ModelTypeConst.TTI, zhipu_tti_model_credential, + ZhiPuTextToImage), +] + model_info_manage = ( ModelInfoManage.builder() .append_model_info_list(model_info_list) .append_default_model_info(ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)) .append_model_info_list(model_info_image_list) + .append_model_info_list(model_info_tti_list) .build() ) diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index e732087b0ba..b8f314127b4 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -28,6 +28,13 @@ from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus from setting.models_provider.constants.model_provider_constants import ModelProvideConstants +def get_default_model_params_setting(provider, model_type, model_name): + credential = get_model_credential(provider, model_type, model_name) + setting_form = credential.get_model_params_setting_form(model_name) + if setting_form is not None: + return setting_form.to_form_list() + return [] + class ModelPullManage: @@ -173,6 +180,8 @@ class Create(serializers.Serializer): model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型")) + model_params_form = serializers.ListField(required=False, default=list, error_messages=ErrMessage.char("参数配置")) + credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息")) def is_valid(self, *, raise_exception=False): @@ -202,10 +211,12 @@ def insert(self, user_id, with_valid=False): model_type = self.data.get('model_type') model_name = self.data.get('model_name') permission_type = 
self.data.get('permission_type') + model_params_form = self.data.get('model_params_form') model_credential_str = json.dumps(credential) model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, credential=rsa_long_encrypt(model_credential_str), provider=provider, model_type=model_type, model_name=model_name, + model_params_form=model_params_form, permission_type=permission_type) model.save() if status == Status.DOWNLOAD: diff --git a/apps/setting/urls.py b/apps/setting/urls.py index 42e80592cc1..73fe9ba12db 100644 --- a/apps/setting/urls.py +++ b/apps/setting/urls.py @@ -14,6 +14,8 @@ path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"), path('provider/model_list', views.Provide.ModelList.as_view(), name="provider/model_name_list"), + path('provider/model_params_form', views.Provide.ModelParamsForm.as_view(), + name="provider/model_params_form"), path('provider/model_form', views.Provide.ModelForm.as_view(), name="provider/model_form"), path('model', views.Model.as_view(), name='model'), diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py index b5abf919668..965f68b1bde 100644 --- a/apps/setting/views/model.py +++ b/apps/setting/views/model.py @@ -16,7 +16,7 @@ from common.response import result from common.util.common import query_params_to_single_dict from setting.models_provider.constants.model_provider_constants import ModelProvideConstants -from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer +from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer, get_default_model_params_setting from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi, ModelEditApi @@ -207,6 +207,24 @@ def get(self, request: Request): ModelProvideConstants[provider].value.get_model_list( model_type)) + class ModelParamsForm(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], 
detail=False) + @swagger_auto_schema(operation_summary="获取模型默认参数", + operation_id="获取模型创建表单", + manual_parameters=ProvideApi.ModelList.get_request_params_api(), + responses=result.get_api_array_response(ProvideApi.ModelList.get_response_body_api()) + , tags=["模型"] + ) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request): + provider = request.query_params.get('provider') + model_type = request.query_params.get('model_type') + model_name = request.query_params.get('model_name') + + return result.success(get_default_model_params_setting(provider, model_type, model_name)) + class ModelForm(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index bf384ad49ff..404303aebd9 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -293,6 +293,13 @@ const getApplicationImageModel: ( return get(`${prefix}/${application_id}/model`, { model_type: 'IMAGE' }, loading) } +const getApplicationTTIModel: ( + application_id: string, + loading?: Ref +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading) +} + /** * 发布应用 @@ -523,6 +530,7 @@ export default { getApplicationSTTModel, getApplicationTTSModel, getApplicationImageModel, + getApplicationTTIModel, postSpeechToText, postTextToSpeech, getPlatformStatus, diff --git a/ui/src/api/model.ts b/ui/src/api/model.ts index 6519f1bc0f3..5129dd05572 100644 --- a/ui/src/api/model.ts +++ b/ui/src/api/model.ts @@ -98,6 +98,15 @@ const listBaseModel: ( return get(`${prefix_provider}/model_list`, { provider, model_type }, loading) } +const listBaseModelParamsForm: ( + provider: string, + model_type: string, + model_name: string, + loading?: Ref +) => Promise>> = (provider, model_type, model_name, loading) => { + return get(`${prefix_provider}/model_params_form`, { provider, model_type, model_name}, loading) +} + /** * 创建模型 * @param request 请求对象 @@ -187,6 +196,7 @@ 
export default { getModelCreateForm, listModelType, listBaseModel, + listBaseModelParamsForm, createModel, updateModel, deleteModel, diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index 5c55f1bf379..39a4f2617cc 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -32,6 +32,7 @@ item.type === WorkflowType.Question || item.type === WorkflowType.AiChat || item.type === WorkflowType.ImageUnderstandNode || + item.type === WorkflowType.ImageGenerateNode || item.type === WorkflowType.Application " >{{ item?.message_tokens + item?.answer_tokens }} tokens + + - - - diff --git a/ui/src/workflow/nodes/image-generate/index.ts b/ui/src/workflow/nodes/image-generate/index.ts new file mode 100644 index 00000000000..5afc2e57194 --- /dev/null +++ b/ui/src/workflow/nodes/image-generate/index.ts @@ -0,0 +1,14 @@ +import ImageGenerateNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' + +class RerankerNode extends AppNode { + constructor(props: any) { + super(props, ImageGenerateNodeVue) + } +} + +export default { + type: 'image-generate-node', + model: AppNodeModel, + view: RerankerNode +} diff --git a/ui/src/workflow/nodes/image-generate/index.vue b/ui/src/workflow/nodes/image-generate/index.vue new file mode 100644 index 00000000000..ae4817204f1 --- /dev/null +++ b/ui/src/workflow/nodes/image-generate/index.vue @@ -0,0 +1,323 @@ + + + + + diff --git a/ui/src/workflow/nodes/image-understand/index.vue b/ui/src/workflow/nodes/image-understand/index.vue index ecde494346d..40f23847c3a 100644 --- a/ui/src/workflow/nodes/image-understand/index.vue +++ b/ui/src/workflow/nodes/image-understand/index.vue @@ -25,6 +25,15 @@
图片理解模型*
+ + {{ $t('views.application.applicationForm.form.paramSetting') }} + + @@ -197,6 +207,7 @@ import { app } from '@/main' import useStore from '@/stores' import NodeCascader from '@/workflow/common/NodeCascader.vue' import type { FormInstance } from 'element-plus' +import AIModeParamSettingDialog from '@/views/application/component/AIModeParamSettingDialog.vue' const { model } = useStore() @@ -207,6 +218,7 @@ const { const props = defineProps<{ nodeModel: any }>() const modelOptions = ref(null) const providerOptions = ref>([]) +const AIModeParamSettingDialogRef = ref>() const aiChatNodeFormRef = ref() const validate = () => { @@ -281,6 +293,16 @@ function submitDialog(val: string) { set(props.nodeModel.properties.node_data, 'prompt', val) } +const openAIParamSettingDialog = (modelId: string) => { + if (modelId) { + AIModeParamSettingDialogRef.value?.open(modelId, id, form_data.value.model_params_setting) + } +} + +function refreshParam(data: any) { + set(props.nodeModel.properties.node_data, 'model_params_setting', data) +} + onMounted(() => { getModel() getProvider()