Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions apps/common/event/listener_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,16 @@ def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings):
@embedding_poxy
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
max_kb.info(f'开始--->向量化段落:{paragraph_id_list}')
status = Status.success
try:
# 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)

def is_save_function():
return QuerySet(Paragraph).filter(id__in=paragraph_id_list).exists()

# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
Expand All @@ -141,8 +146,12 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)

def is_save_function():
return QuerySet(Paragraph).filter(id=paragraph_id).exists()

# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
Expand Down Expand Up @@ -175,8 +184,12 @@ def embedding_by_document(document_id, embedding_model: Embeddings):
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 删除文档向量数据
VectorStore.get_embedding_vector().delete_by_document_id(document_id)

def is_save_function():
return QuerySet(Document).filter(id=document_id).exists()

# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
except Exception as e:
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
Expand Down
11 changes: 7 additions & 4 deletions apps/embedding/vector/base_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def save(self, text, source_type: SourceType, dataset_id: str, document_id: str,
chunk_list = chunk_data(data)
result = sub_array(chunk_list)
for child_array in result:
self._batch_save(child_array, embedding)
self._batch_save(child_array, embedding, lambda: True)

def batch_save(self, data_list: List[Dict], embedding: Embeddings):
def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function):
# 获取锁
lock.acquire()
try:
Expand All @@ -100,7 +100,10 @@ def batch_save(self, data_list: List[Dict], embedding: Embeddings):
chunk_list = chunk_data_list(data_list)
result = sub_array(chunk_list)
for child_array in result:
self._batch_save(child_array, embedding)
if is_save_function():
self._batch_save(child_array, embedding, is_save_function)
else:
break
finally:
# 释放锁
lock.release()
Expand All @@ -113,7 +116,7 @@ def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str
pass

@abstractmethod
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function):
pass

def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
Expand Down
5 changes: 3 additions & 2 deletions apps/embedding/vector/pg_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str
embedding.save()
return True

def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function):
texts = [row.get('text') for row in text_list]
embeddings = embedding.embed_documents(texts)
embedding_list = [Embedding(id=uuid.uuid1(),
Expand All @@ -68,7 +68,8 @@ def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
embedding=embeddings[index],
search_vector=to_ts_vector(text_list[index]['text'])) for index in
range(0, len(text_list))]
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
if is_save_function():
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
return True

def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
Expand Down