1212
1313from django .db .models import QuerySet
1414from langchain .schema import HumanMessage , SystemMessage
15- from langchain_core .messages import BaseMessage
15+ from langchain_core .messages import BaseMessage , AIMessage
1616
1717from application .flow .i_step_node import NodeResult , INode
1818from application .flow .step_node .ai_chat_step_node .i_chat_node import IChatNode
@@ -72,6 +72,22 @@ def get_default_model_params_setting(model_id):
7272 return model_params_setting
7373
7474
75+ def get_node_message (chat_record , runtime_node_id ):
76+ node_details = chat_record .get_node_details_runtime_node_id (runtime_node_id )
77+ if node_details is None :
78+ return []
79+ return [HumanMessage (node_details .get ('question' )), AIMessage (node_details .get ('answer' ))]
80+
81+
82+ def get_workflow_message (chat_record ):
83+ return [chat_record .get_human_message (), chat_record .get_ai_message ()]
84+
85+
86+ def get_message (chat_record , dialogue_type , runtime_node_id ):
87+ return get_node_message (chat_record , runtime_node_id ) if dialogue_type == 'NODE' else get_workflow_message (
88+ chat_record )
89+
90+
7591class BaseChatNode (IChatNode ):
7692 def save_context (self , details , workflow_manage ):
7793 self .context ['answer' ] = details .get ('answer' )
@@ -80,12 +96,17 @@ def save_context(self, details, workflow_manage):
8096
8197 def execute (self , model_id , system , prompt , dialogue_number , history_chat_record , stream , chat_id , chat_record_id ,
8298 model_params_setting = None ,
99+ dialogue_type = None ,
83100 ** kwargs ) -> NodeResult :
101+ if dialogue_type is None :
102+ dialogue_type = 'WORKFLOW'
103+
84104 if model_params_setting is None :
85105 model_params_setting = get_default_model_params_setting (model_id )
86106 chat_model = get_model_instance_by_model_user_id (model_id , self .flow_params_serializer .data .get ('user_id' ),
87107 ** model_params_setting )
88- history_message = self .get_history_message (history_chat_record , dialogue_number )
108+ history_message = self .get_history_message (history_chat_record , dialogue_number , dialogue_type ,
109+ self .runtime_node_id )
89110 self .context ['history_message' ] = history_message
90111 question = self .generate_prompt_question (prompt )
91112 self .context ['question' ] = question .content
@@ -103,10 +124,10 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
103124 _write_context = write_context )
104125
105126 @staticmethod
106- def get_history_message (history_chat_record , dialogue_number ):
127+ def get_history_message (history_chat_record , dialogue_number , dialogue_type , runtime_node_id ):
107128 start_index = len (history_chat_record ) - dialogue_number
108129 history_message = reduce (lambda x , y : [* x , * y ], [
109- [ history_chat_record [index ]. get_human_message (), history_chat_record [ index ]. get_ai_message ()]
130+ get_message ( history_chat_record [index ], dialogue_type , runtime_node_id )
110131 for index in
111132 range (start_index if start_index > 0 else 0 , len (history_chat_record ))], [])
112133 return history_message
0 commit comments