From c4fad66f2a04ea8cd7f67b6baf18b3316093ec77 Mon Sep 17 00:00:00 2001 From: Hash Brown Date: Mon, 2 Dec 2024 16:09:26 +0800 Subject: [PATCH] fix: `dialogue_count` incorrect in chatflow when there's... (#11175) --- .../app/apps/advanced_chat/app_generator.py | 8 +++++ api/core/app/apps/advanced_chat/app_runner.py | 10 ++---- .../advanced_chat/generate_task_pipeline.py | 4 ++- .../utils/get_thread_messages_length.py | 32 +++++++++++++++++++ 4 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 api/core/prompt/utils/get_thread_messages_length.py diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 4a5f43eedd..bd4fd9cd3b 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -23,6 +23,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager +from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from extensions.ext_database import db from factories import file_factory from models.account import Account @@ -33,6 +34,8 @@ logger = logging.getLogger(__name__) class AdvancedChatAppGenerator(MessageBasedAppGenerator): + _dialogue_count: int + def generate( self, app_model: App, @@ -211,6 +214,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): db.session.commit() db.session.refresh(conversation) + # get conversation dialogue count + self._dialogue_count = get_thread_messages_length(conversation.id) + # init queue manager queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, @@ -281,6 +287,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): queue_manager=queue_manager, conversation=conversation, message=message, + dialogue_count=self._dialogue_count, ) runner.run() @@ -334,6 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message=message, user=user, stream=stream, + dialogue_count=self._dialogue_count, ) try: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 65d744eddf..cf0c9d7593 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -39,12 +39,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): queue_manager: AppQueueManager, conversation: Conversation, message: Message, + dialogue_count: int, ) -> None: super().__init__(queue_manager) self.application_generate_entity = application_generate_entity self.conversation = conversation self.message = message + self._dialogue_count = dialogue_count def run(self) -> None: app_config = self.application_generate_entity.app_config @@ -122,19 +124,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): session.commit() - # Increment dialogue count. - self.conversation.dialogue_count += 1 - - conversation_dialogue_count = self.conversation.dialogue_count - db.session.commit() - # Create a variable pool. system_inputs = { SystemVariableKey.QUERY: query, SystemVariableKey.FILES: files, SystemVariableKey.CONVERSATION_ID: self.conversation.id, SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, + SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, SystemVariableKey.APP_ID: app_config.app_id, SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 0dd5813391..cd12690e28 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -88,6 +88,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc message: Message, user: Union[Account, EndUser], stream: bool, + dialogue_count: int, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. @@ -98,6 +99,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc :param message: message :param user: user :param stream: stream + :param dialogue_count: dialogue count """ super().__init__(application_generate_entity, queue_manager, user, stream) @@ -114,7 +116,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.CONVERSATION_ID: conversation.id, SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: conversation.dialogue_count, + SystemVariableKey.DIALOGUE_COUNT: dialogue_count, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, diff --git a/api/core/prompt/utils/get_thread_messages_length.py b/api/core/prompt/utils/get_thread_messages_length.py new file mode 100644 index 0000000000..f49466db6d --- /dev/null +++ b/api/core/prompt/utils/get_thread_messages_length.py @@ -0,0 +1,32 @@ +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from extensions.ext_database import db +from models.model import Message + + +def get_thread_messages_length(conversation_id: str) -> int: + """ + Get the number of thread messages based on the parent message id. + """ + # Fetch all messages related to the conversation + query = ( + db.session.query( + Message.id, + Message.parent_message_id, + Message.answer, + ) + .filter( + Message.conversation_id == conversation_id, + ) + .order_by(Message.created_at.desc()) + ) + + messages = query.all() + + # Extract thread messages + thread_messages = extract_thread_messages(messages) + + # Exclude the newly created message with an empty answer + if thread_messages and not thread_messages[0].answer: + thread_messages.pop(0) + + return len(thread_messages)