Skip to content

Commit 209702c

Browse files
committed
feat: ai对话,问题优化,指定回复节点支持返回结果开关
1 parent 7b48599 commit 209702c

File tree

9 files changed

+128
-239
lines changed

9 files changed

+128
-239
lines changed

apps/application/flow/i_step_node.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from abc import abstractmethod
1111
from typing import Type, Dict, List
1212

13+
from django.core import cache
1314
from django.db.models import QuerySet
1415
from rest_framework import serializers
1516

@@ -18,7 +19,6 @@
1819
from common.constants.authentication_type import AuthenticationType
1920
from common.field.common import InstanceField
2021
from common.util.field_message import ErrMessage
21-
from django.core import cache
2222

2323
chat_cache = cache.caches['chat_cache']
2424

@@ -27,6 +27,9 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
2727
if step_variable is not None:
2828
for key in step_variable:
2929
node.context[key] = step_variable[key]
30+
if workflow.is_result() and 'answer' in step_variable:
31+
yield step_variable['answer']
32+
workflow.answer += step_variable['answer']
3033
if global_variable is not None:
3134
for key in global_variable:
3235
workflow.context[key] = global_variable[key]
@@ -70,18 +73,14 @@ def handler(self, chat_id,
7073

7174

7275
class NodeResult:
73-
def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
76+
def __init__(self, node_variable: Dict, workflow_variable: Dict,
77+
_write_context=write_context):
7478
self._write_context = _write_context
7579
self.node_variable = node_variable
7680
self.workflow_variable = workflow_variable
77-
self._to_response = _to_response
7881

7982
def write_context(self, node, workflow):
80-
self._write_context(self.node_variable, self.workflow_variable, node, workflow)
81-
82-
def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
83-
return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
84-
post_handler)
83+
return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
8584

8685
def is_assertion_result(self):
8786
return 'branch_id' in self.node_variable

apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ class ChatNodeSerializer(serializers.Serializer):
2222
# 多轮对话数量
2323
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
2424

25+
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
26+
2527

2628
class IChatNode(INode):
2729
type = 'ai-chat-node'

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 18 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,30 @@
66
@date:2024/6/4 14:30
77
@desc:
88
"""
9-
import time
109
from functools import reduce
1110
from typing import List, Dict
1211

1312
from langchain.schema import HumanMessage, SystemMessage
1413
from langchain_core.messages import BaseMessage
1514

16-
from application.flow import tools
1715
from application.flow.i_step_node import NodeResult, INode
1816
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
1917
from setting.models_provider.tools import get_model_instance_by_model_user_id
2018

2119

20+
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
21+
chat_model = node_variable.get('chat_model')
22+
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
23+
answer_tokens = chat_model.get_num_tokens(answer)
24+
node.context['message_tokens'] = message_tokens
25+
node.context['answer_tokens'] = answer_tokens
26+
node.context['answer'] = answer
27+
node.context['history_message'] = node_variable['history_message']
28+
node.context['question'] = node_variable['question']
29+
if workflow.is_result():
30+
workflow.answer += answer
31+
32+
2233
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
2334
"""
2435
写入上下文数据 (流式)
@@ -31,15 +42,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
3142
answer = ''
3243
for chunk in response:
3344
answer += chunk.content
34-
chat_model = node_variable.get('chat_model')
35-
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
36-
answer_tokens = chat_model.get_num_tokens(answer)
37-
node.context['message_tokens'] = message_tokens
38-
node.context['answer_tokens'] = answer_tokens
39-
node.context['answer'] = answer
40-
node.context['history_message'] = node_variable['history_message']
41-
node.context['question'] = node_variable['question']
42-
node.context['run_time'] = time.time() - node.context['start_time']
45+
yield answer
46+
_write_context(node_variable, workflow_variable, node, workflow, answer)
4347

4448

4549
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -51,71 +55,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
5155
@param workflow: 工作流管理器
5256
"""
5357
response = node_variable.get('result')
54-
chat_model = node_variable.get('chat_model')
5558
answer = response.content
56-
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
57-
answer_tokens = chat_model.get_num_tokens(answer)
58-
node.context['message_tokens'] = message_tokens
59-
node.context['answer_tokens'] = answer_tokens
60-
node.context['answer'] = answer
61-
node.context['history_message'] = node_variable['history_message']
62-
node.context['question'] = node_variable['question']
63-
64-
65-
def get_to_response_write_context(node_variable: Dict, node: INode):
66-
def _write_context(answer, status=200):
67-
chat_model = node_variable.get('chat_model')
68-
69-
if status == 200:
70-
answer_tokens = chat_model.get_num_tokens(answer)
71-
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
72-
else:
73-
answer_tokens = 0
74-
message_tokens = 0
75-
node.err_message = answer
76-
node.status = status
77-
node.context['message_tokens'] = message_tokens
78-
node.context['answer_tokens'] = answer_tokens
79-
node.context['answer'] = answer
80-
node.context['run_time'] = time.time() - node.context['start_time']
81-
82-
return _write_context
83-
84-
85-
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
86-
post_handler):
87-
"""
88-
将流式数据 转换为 流式响应
89-
@param chat_id: 会话id
90-
@param chat_record_id: 对话记录id
91-
@param node_variable: 节点数据
92-
@param workflow_variable: 工作流数据
93-
@param node: 节点
94-
@param workflow: 工作流管理器
95-
@param post_handler: 后置处理器 输出结果后执行
96-
@return: 流式响应
97-
"""
98-
response = node_variable.get('result')
99-
_write_context = get_to_response_write_context(node_variable, node)
100-
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
101-
102-
103-
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
104-
post_handler):
105-
"""
106-
将结果转换
107-
@param chat_id: 会话id
108-
@param chat_record_id: 对话记录id
109-
@param node_variable: 节点数据
110-
@param workflow_variable: 工作流数据
111-
@param node: 节点
112-
@param workflow: 工作流管理器
113-
@param post_handler: 后置处理器
114-
@return: 响应
115-
"""
116-
response = node_variable.get('result')
117-
_write_context = get_to_response_write_context(node_variable, node)
118-
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
59+
_write_context(node_variable, workflow_variable, node, workflow, answer)
11960

12061

12162
class BaseChatNode(IChatNode):
@@ -132,13 +73,12 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
13273
r = chat_model.stream(message_list)
13374
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
13475
'history_message': history_message, 'question': question.content}, {},
135-
_write_context=write_context_stream,
136-
_to_response=to_stream_response)
76+
_write_context=write_context_stream)
13777
else:
13878
r = chat_model.invoke(message_list)
13979
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
14080
'history_message': history_message, 'question': question.content}, {},
141-
_write_context=write_context, _to_response=to_response)
81+
_write_context=write_context)
14282

14383
@staticmethod
14484
def get_history_message(history_chat_record, dialogue_number):

apps/application/flow/step_node/direct_reply_node/i_reply_node.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class ReplyNodeParamsSerializer(serializers.Serializer):
2020
fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
2121
content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
2222
error_messages=ErrMessage.char("直接回答内容"))
23+
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
2324

2425
def is_valid(self, *, raise_exception=False):
2526
super().is_valid(raise_exception=True)

apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py

Lines changed: 3 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,69 +6,19 @@
66
@date:2024/6/11 17:25
77
@desc:
88
"""
9-
from typing import List, Dict
9+
from typing import List
1010

11-
from langchain_core.messages import AIMessage, AIMessageChunk
12-
13-
from application.flow import tools
14-
from application.flow.i_step_node import NodeResult, INode
11+
from application.flow.i_step_node import NodeResult
1512
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
1613

1714

18-
def get_to_response_write_context(node_variable: Dict, node: INode):
19-
def _write_context(answer, status=200):
20-
node.context['answer'] = answer
21-
22-
return _write_context
23-
24-
25-
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
26-
post_handler):
27-
"""
28-
将流式数据 转换为 流式响应
29-
@param chat_id: 会话id
30-
@param chat_record_id: 对话记录id
31-
@param node_variable: 节点数据
32-
@param workflow_variable: 工作流数据
33-
@param node: 节点
34-
@param workflow: 工作流管理器
35-
@param post_handler: 后置处理器 输出结果后执行
36-
@return: 流式响应
37-
"""
38-
response = node_variable.get('result')
39-
_write_context = get_to_response_write_context(node_variable, node)
40-
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
41-
42-
43-
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
44-
post_handler):
45-
"""
46-
将结果转换
47-
@param chat_id: 会话id
48-
@param chat_record_id: 对话记录id
49-
@param node_variable: 节点数据
50-
@param workflow_variable: 工作流数据
51-
@param node: 节点
52-
@param workflow: 工作流管理器
53-
@param post_handler: 后置处理器
54-
@return: 响应
55-
"""
56-
response = node_variable.get('result')
57-
_write_context = get_to_response_write_context(node_variable, node)
58-
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
59-
60-
6115
class BaseReplyNode(IReplyNode):
6216
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
6317
if reply_type == 'referencing':
6418
result = self.get_reference_content(fields)
6519
else:
6620
result = self.generate_reply_content(content)
67-
if stream:
68-
return NodeResult({'result': iter([AIMessageChunk(content=result)]), 'answer': result}, {},
69-
_to_response=to_stream_response)
70-
else:
71-
return NodeResult({'result': AIMessage(content=result), 'answer': result}, {}, _to_response=to_response)
21+
return NodeResult({'answer': result}, {})
7222

7323
def generate_reply_content(self, prompt):
7424
return self.workflow_manage.generate_prompt(prompt)

apps/application/flow/step_node/question_node/i_question_node.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ class QuestionNodeSerializer(serializers.Serializer):
2222
# 多轮对话数量
2323
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
2424

25+
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
26+
2527

2628
class IQuestionNode(INode):
2729
type = 'question-node'

0 commit comments

Comments
 (0)