Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions apps/common/chunk/impl/mark_chunk_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,25 @@

from common.chunk.i_chunk_handle import IChunkHandle

# Hard upper bound on the length of any emitted chunk.
max_chunk_len = 256
# Greedy match: up to max_chunk_len characters followed by one delimiter
# (full-width/half-width punctuation, space or newline).
# NOTE(review): inside a character class '|' is a LITERAL character, so text
# is also split on '|', and '\\.' matches a literal backslash or '.' —
# confirm these are intended delimiters.
split_chunk_pattern = r'.{1,%d}[。| |\\.|!|;|;|!|\n]' % max_chunk_len
# Fallback: hard cut into pieces of at most max_chunk_len characters.
max_chunk_pattern = r'.{1,%d}' % max_chunk_len


class MarkChunkHandle(IChunkHandle):
    """Split raw text chunks on sentence delimiters, capping every piece at
    max_chunk_len characters."""

    def handle(self, chunk_list: List[str]):
        """Split each chunk in chunk_list and return the flat list of pieces.

        For every input chunk:
          1. collect the delimiter-terminated pieces (each <= max_chunk_len);
          2. collect the remainders left between/after those pieces (e.g. a
             trailing sentence with no delimiter), hard-wrapping any remainder
             that is itself >= max_chunk_len.
        """
        result = []
        for chunk in chunk_list:
            # Pieces that end with a delimiter.
            for matched in re.findall(split_chunk_pattern, chunk, flags=re.DOTALL):
                result.append(matched)
            # Remainders not consumed by the matches above.
            for other_chunk in re.split(split_chunk_pattern, chunk, flags=re.DOTALL):
                if len(other_chunk) > 0:
                    if len(other_chunk) < max_chunk_len:
                        result.append(other_chunk)
                    else:
                        # Over-long remainder: cut into max_chunk_len pieces.
                        for piece in re.findall(max_chunk_pattern, other_chunk, flags=re.DOTALL):
                            result.append(piece)
        return result
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.embedding import \
AliyunBaiLianEmbeddingCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.reranker import \
AliyunBaiLianRerankerCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.stt import AliyunBaiLianSTTModelCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.tts import AliyunBaiLianTTSModelCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.stt import AliyunBaiLianSpeechToText
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech
Expand All @@ -23,6 +26,7 @@
aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential()
aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential()
aliyun_bai_lian_stt_model_credential = AliyunBaiLianSTTModelCredential()
aliyun_bai_lian_embedding_model_credential = AliyunBaiLianEmbeddingCredential()

model_info_list = [ModelInfo('gte-rerank',
'阿里巴巴通义实验室开发的GTE-Rerank文本排序系列模型,开发者可以通过LlamaIndex框架进行集成高质量文本检索、排序。',
Expand All @@ -33,10 +37,15 @@
ModelInfo('cosyvoice-v1',
'CosyVoice基于新一代生成式语音大模型,能根据上下文预测情绪、语调、韵律等,具有更好的拟人效果',
ModelTypeConst.TTS, aliyun_bai_lian_tts_model_credential, AliyunBaiLianTextToSpeech),
ModelInfo('text-embedding-v1',
'通用文本向量,是通义实验室基于LLM底座的多语言文本统一向量模型,面向全球多个主流语种,提供高水准的向量服务,帮助开发者将文本数据快速转换为高质量的向量数据。',
ModelTypeConst.EMBEDDING, aliyun_bai_lian_embedding_model_credential,
AliyunBaiLianEmbedding),
]

model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
model_info_list[1]).append_default_model_info(model_info_list[2]).build()
model_info_list[1]).append_default_model_info(model_info_list[2]).append_default_model_info(
model_info_list[3]).build()


class AliyunBaiLianModelProvider(IModelProvider):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/10/16 17:01
@desc:
"""
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding


class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form for the Aliyun BaiLian (DashScope) embedding model.

    Declares the fields the user must fill in and validates them by making a
    real embedding call against the provider.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by embedding a probe string.

        Returns True on success; on failure either returns False or raises
        AppApiException, depending on raise_exception.  An unsupported
        model_type always raises.
        """
        supported = (mt.get('value') == model_type for mt in provider.get_model_type_list())
        if not any(supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        for required_key in ('dashscope_api_key',):
            if required_key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')
                return False
        try:
            model: AliyunBaiLianEmbedding = provider.get_model(model_type, model_name, model_credential)
            # A successful round-trip proves the key and model name are usable.
            model.embed_query('你好')
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # NOTE(review): returned unchanged — sibling credential forms usually
        # encrypt the API key before persisting; confirm this is intentional.
        return model

    dashscope_api_key = forms.PasswordInputField('API Key', required=True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/10/16 16:34
@desc:
"""
from typing import Dict, List

from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.embeddings.dashscope import embed_with_retry

from setting.models_provider.base_model_provider import MaxKBBaseModel


class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings):
    """Aliyun BaiLian embedding model backed by langchain's DashScopeEmbeddings."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider framework to build a configured instance."""
        return AliyunBaiLianEmbedding(
            model=model_name,
            dashscope_api_key=model_credential.get('dashscope_api_key')
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents via DashScope.

        Args:
            texts: The list of texts to embed.

        Returns:
            One embedding vector per input text, in input order.
        """
        responses = embed_with_retry(
            self, input=texts, text_type="document", model=self.model
        )
        return [entry["embedding"] for entry in responses]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string via DashScope.

        NOTE(review): this passes text_type="document"; upstream langchain's
        DashScopeEmbeddings uses "query" for query embedding — confirm the
        symmetric treatment is intentional.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector for the text.
        """
        response = embed_with_retry(
            self, input=[text], text_type="document", model=self.model
        )
        return response[0]["embedding"]
Loading