Skip to content

Commit 54381ff

Browse files
authored
feat: Document vectorization supports processing based on status (#1984)
1 parent 9a310bf commit 54381ff

File tree

9 files changed

+140
-45
lines changed

9 files changed

+140
-45
lines changed

apps/common/event/listener_manage.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,22 @@
66
@date:2023/10/20 14:01
77
@desc:
88
"""
9-
import datetime
109
import logging
1110
import os
1211
import threading
13-
import time
1412
import traceback
1513
from typing import List
1614

1715
import django.db.models
18-
from django.db import models, transaction
1916
from django.db.models import QuerySet
2017
from django.db.models.functions import Substr, Reverse
2118
from langchain_core.embeddings import Embeddings
2219

2320
from common.config.embedding_config import VectorStore
2421
from common.db.search import native_search, get_dynamics_model, native_update
25-
from common.db.sql_execute import sql_execute, update_execute
2622
from common.util.file_util import get_file_content
2723
from common.util.lock import try_lock, un_lock
28-
from common.util.page_utils import page
24+
from common.util.page_utils import page_desc
2925
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
3026
from embedding.models import SourceType, SearchMode
3127
from smartdoc.conf import PROJECT_DIR
@@ -162,7 +158,7 @@ def embedding_paragraph_apply(paragraph_list):
162158
if is_the_task_interrupted():
163159
break
164160
ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model)
165-
post_apply()
161+
post_apply()
166162

167163
return embedding_paragraph_apply
168164

@@ -241,13 +237,16 @@ def update_status(query_set: QuerySet, taskType: TaskType, state: State):
241237
lock.release()
242238

243239
@staticmethod
244-
def embedding_by_document(document_id, embedding_model: Embeddings):
240+
def embedding_by_document(document_id, embedding_model: Embeddings, state_list=None):
245241
"""
246242
向量化文档
243+
@param state_list:
247244
@param document_id: 文档id
248245
@param embedding_model 向量模型
249246
:return: None
250247
"""
248+
if state_list is None:
249+
state_list = [State.PENDING, State.SUCCESS, State.FAILURE, State.REVOKE, State.REVOKED]
251250
if not try_lock('embedding' + str(document_id)):
252251
return
253252
try:
@@ -268,11 +267,17 @@ def is_the_task_interrupted():
268267
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
269268

270269
# 根据段落进行向量化处理
271-
page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5,
272-
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
273-
ListenerManagement.get_aggregation_document_status(
274-
document_id)),
275-
is_the_task_interrupted)
270+
page_desc(QuerySet(Paragraph)
271+
.annotate(
272+
reversed_status=Reverse('status'),
273+
task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value,
274+
1),
275+
).filter(task_type_status__in=state_list, document_id=document_id)
276+
.values('id'), 5,
277+
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
278+
ListenerManagement.get_aggregation_document_status(
279+
document_id)),
280+
is_the_task_interrupted)
276281
except Exception as e:
277282
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
278283
finally:

apps/common/util/page_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,22 @@ def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
2626
offset = i * page_size
2727
paragraph_list = query.all()[offset: offset + page_size]
2828
handler(paragraph_list)
29+
30+
31+
def page_desc(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
32+
"""
33+
34+
@param query_set: 查询query_set
35+
@param page_size: 每次查询大小
36+
@param handler: 数据处理器
37+
@param is_the_task_interrupted: 任务是否被中断
38+
@return:
39+
"""
40+
query = query_set.order_by("id")
41+
count = query_set.count()
42+
for i in sorted(range(0, ceil(count / page_size)), reverse=True):
43+
if is_the_task_interrupted():
44+
return
45+
offset = i * page_size
46+
paragraph_list = query.all()[offset: offset + page_size]
47+
handler(paragraph_list)

apps/dataset/serializers/document_serializers.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -700,20 +700,24 @@ def edit(self, instance: Dict, with_valid=False):
700700
_document.save()
701701
return self.one()
702702

703-
@transaction.atomic
704-
def refresh(self, with_valid=True):
703+
def refresh(self, state_list, with_valid=True):
705704
if with_valid:
706705
self.is_valid(raise_exception=True)
707706
document_id = self.data.get("document_id")
708707
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
709708
State.PENDING)
710-
ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
709+
ListenerManagement.update_status(QuerySet(Paragraph).annotate(
710+
reversed_status=Reverse('status'),
711+
task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value,
712+
1),
713+
).filter(task_type_status__in=state_list, document_id=document_id)
714+
.values('id'),
711715
TaskType.EMBEDDING,
712716
State.PENDING)
713717
ListenerManagement.get_aggregation_document_status(document_id)()
714718
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
715719
try:
716-
embedding_by_document.delay(document_id, embedding_model_id)
720+
embedding_by_document.delay(document_id, embedding_model_id, state_list)
717721
except AlreadyQueued as e:
718722
raise AppApiException(500, "任务正在执行中,请勿重复下发")
719723

@@ -1122,14 +1126,14 @@ def batch_refresh(self, instance: Dict, with_valid=True):
11221126
if with_valid:
11231127
self.is_valid(raise_exception=True)
11241128
document_id_list = instance.get("id_list")
1125-
with transaction.atomic():
1126-
dataset_id = self.data.get('dataset_id')
1127-
for document_id in document_id_list:
1128-
try:
1129-
DocumentSerializers.Operate(
1130-
data={'dataset_id': dataset_id, 'document_id': document_id}).refresh()
1131-
except AlreadyQueued as e:
1132-
pass
1129+
state_list = instance.get("state_list")
1130+
dataset_id = self.data.get('dataset_id')
1131+
for document_id in document_id_list:
1132+
try:
1133+
DocumentSerializers.Operate(
1134+
data={'dataset_id': dataset_id, 'document_id': document_id}).refresh(state_list)
1135+
except AlreadyQueued as e:
1136+
pass
11331137

11341138
class GenerateRelated(ApiMixin, serializers.Serializer):
11351139
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

apps/dataset/swagger_api/document_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,16 @@ def get_request_body_api():
5151
description="1|2|3 1:向量化|2:生成问题|3:同步文档", default=1)
5252
}
5353
)
54+
55+
class EmbeddingState(ApiMixin):
56+
@staticmethod
57+
def get_request_body_api():
58+
return openapi.Schema(
59+
type=openapi.TYPE_OBJECT,
60+
properties={
61+
'state_list': openapi.Schema(type=openapi.TYPE_ARRAY,
62+
items=openapi.Schema(type=openapi.TYPE_STRING),
63+
title="状态列表",
64+
description="状态列表")
65+
}
66+
)

apps/dataset/views/document.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ class Refresh(APIView):
262262
@action(methods=['PUT'], detail=False)
263263
@swagger_auto_schema(operation_summary="刷新文档向量库",
264264
operation_id="刷新文档向量库",
265+
request_body=DocumentApi.EmbeddingState.get_request_body_api(),
265266
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
266267
responses=result.get_default_response(),
267268
tags=["知识库/文档"]
@@ -272,6 +273,7 @@ class Refresh(APIView):
272273
def put(self, request: Request, dataset_id: str, document_id: str):
273274
return result.success(
274275
DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh(
276+
request.data.get('state_list')
275277
))
276278

277279
class BatchRefresh(APIView):

apps/embedding/task/embedding.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,22 +56,28 @@ def embedding_by_paragraph_list(paragraph_id_list, model_id):
5656

5757

5858
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
59-
def embedding_by_document(document_id, model_id):
59+
def embedding_by_document(document_id, model_id, state_list=None):
6060
"""
6161
向量化文档
62+
@param state_list:
6263
@param document_id: 文档id
6364
@param model_id 向量模型
6465
:return: None
6566
"""
6667

68+
if state_list is None:
69+
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
70+
State.REVOKE.value,
71+
State.REVOKED.value, State.IGNORED.value]
72+
6773
def exception_handler(e):
6874
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
6975
State.FAILURE)
7076
max_kb_error.error(
7177
f'获取向量模型失败:{str(e)}{traceback.format_exc()}')
7278

7379
embedding_model = get_embedding_model(model_id, exception_handler)
74-
ListenerManagement.embedding_by_document(document_id, embedding_model)
80+
ListenerManagement.embedding_by_document(document_id, embedding_model, state_list)
7581

7682

7783
@celery_app.task(name='celery:embedding_by_document_list')

ui/src/api/document.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,12 @@ const delMulDocument: (
129129
const batchRefresh: (
130130
dataset_id: string,
131131
data: any,
132+
stateList: Array<string>,
132133
loading?: Ref<boolean>
133-
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
134+
) => Promise<Result<boolean>> = (dataset_id, data, stateList, loading) => {
134135
return put(
135136
`${prefix}/${dataset_id}/document/batch_refresh`,
136-
{ id_list: data },
137+
{ id_list: data, state_list: stateList },
137138
undefined,
138139
loading
139140
)
@@ -157,11 +158,12 @@ const getDocumentDetail: (dataset_id: string, document_id: string) => Promise<Re
157158
const putDocumentRefresh: (
158159
dataset_id: string,
159160
document_id: string,
161+
state_list: Array<string>,
160162
loading?: Ref<boolean>
161-
) => Promise<Result<any>> = (dataset_id, document_id, loading) => {
163+
) => Promise<Result<any>> = (dataset_id, document_id, state_list, loading) => {
162164
return put(
163165
`${prefix}/${dataset_id}/document/${document_id}/refresh`,
164-
undefined,
166+
{ state_list },
165167
undefined,
166168
loading
167169
)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
<template>
2+
<el-dialog v-model="dialogVisible" title="选择向量化内容" width="500" :before-close="close">
3+
<el-radio-group v-model="state">
4+
<el-radio value="error" size="large">向量化未成功的分段</el-radio>
5+
<el-radio value="all" size="large">全部分段</el-radio>
6+
</el-radio-group>
7+
<template #footer>
8+
<div class="dialog-footer">
9+
<el-button @click="close">取消</el-button>
10+
<el-button type="primary" @click="submit"> 提交 </el-button>
11+
</div>
12+
</template>
13+
</el-dialog>
14+
</template>
15+
<script setup lang="ts">
16+
import { ref } from 'vue'
17+
const dialogVisible = ref<boolean>(false)
18+
const state = ref<'all' | 'error'>('error')
19+
const stateMap = {
20+
all: ['0', '1', '2', '3', '4', '5', 'n'],
21+
error: ['0', '1', '3', '4', '5', 'n']
22+
}
23+
const submit_handle = ref<(stateList: Array<string>) => void>()
24+
const submit = () => {
25+
if (submit_handle.value) {
26+
submit_handle.value(stateMap[state.value])
27+
}
28+
close()
29+
}
30+
31+
const open = (handle: (stateList: Array<string>) => void) => {
32+
submit_handle.value = handle
33+
dialogVisible.value = true
34+
}
35+
const close = () => {
36+
submit_handle.value = undefined
37+
dialogVisible.value = false
38+
}
39+
defineExpose({ open, close })
40+
</script>
41+
<style lang="scss" scoped></style>

ui/src/views/document/index.vue

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@
422422
</el-text>
423423
<el-button class="ml-16" type="primary" link @click="clearSelection"> 清空 </el-button>
424424
</div>
425+
<EmbeddingContentDialog ref="embeddingContentDialogRef"></EmbeddingContentDialog>
425426
</LayoutContainer>
426427
</template>
427428
<script setup lang="ts">
@@ -439,6 +440,7 @@ import { MsgSuccess, MsgConfirm, MsgError } from '@/utils/message'
439440
import useStore from '@/stores'
440441
import StatusVlue from '@/views/document/component/Status.vue'
441442
import GenerateRelatedDialog from '@/components/generate-related-dialog/index.vue'
443+
import EmbeddingContentDialog from '@/views/document/component/EmbeddingContentDialog.vue'
442444
import { TaskType, State } from '@/utils/status'
443445
const router = useRouter()
444446
const route = useRoute()
@@ -469,7 +471,7 @@ onBeforeRouteLeave((to: any) => {
469471
})
470472
const beforePagination = computed(() => common.paginationConfig[storeKey])
471473
const beforeSearch = computed(() => common.search[storeKey])
472-
474+
const embeddingContentDialogRef = ref<InstanceType<typeof EmbeddingContentDialog>>()
473475
const SyncWebDialogRef = ref()
474476
const loading = ref(false)
475477
let interval: any
@@ -621,10 +623,14 @@ function syncDocument(row: any) {
621623
.catch(() => {})
622624
}
623625
}
626+
624627
function refreshDocument(row: any) {
625-
documentApi.putDocumentRefresh(row.dataset_id, row.id).then(() => {
626-
getList()
627-
})
628+
const embeddingDocument = (stateList: Array<string>) => {
629+
return documentApi.putDocumentRefresh(row.dataset_id, row.id, stateList).then(() => {
630+
getList()
631+
})
632+
}
633+
embeddingContentDialogRef.value?.open(embeddingDocument)
628634
}
629635
630636
function rowClickHandle(row: any, column: any) {
@@ -691,19 +697,16 @@ function deleteMulDocument() {
691697
}
692698
693699
function batchRefresh() {
694-
const arr: string[] = []
695-
multipleSelection.value.map((v) => {
696-
if (v) {
697-
arr.push(v.id)
698-
}
699-
})
700-
documentApi.batchRefresh(id, arr, loading).then(() => {
701-
MsgSuccess('批量向量化成功')
702-
multipleTableRef.value?.clearSelection()
703-
})
700+
const arr: string[] = multipleSelection.value.map((v) => v.id)
701+
const embeddingBatchDocument = (stateList: Array<string>) => {
702+
documentApi.batchRefresh(id, arr, stateList, loading).then(() => {
703+
MsgSuccess('批量向量化成功')
704+
multipleTableRef.value?.clearSelection()
705+
})
706+
}
707+
embeddingContentDialogRef.value?.open(embeddingBatchDocument)
704708
}
705709
706-
707710
function deleteDocument(row: any) {
708711
MsgConfirm(
709712
`是否删除文档:${row.name} ?`,

0 commit comments

Comments
 (0)