fix: dialogue_count
incorrect in chatflow when there's... (#11175)
This commit is contained in:
parent
02572e8cca
commit
c4fad66f2a
@ -23,6 +23,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity,
|
|||||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
@ -33,6 +34,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
|
_dialogue_count: int
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
@ -211,6 +214,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
db.session.refresh(conversation)
|
db.session.refresh(conversation)
|
||||||
|
|
||||||
|
# get conversation dialogue count
|
||||||
|
self._dialogue_count = get_thread_messages_length(conversation.id)
|
||||||
|
|
||||||
# init queue manager
|
# init queue manager
|
||||||
queue_manager = MessageBasedAppQueueManager(
|
queue_manager = MessageBasedAppQueueManager(
|
||||||
task_id=application_generate_entity.task_id,
|
task_id=application_generate_entity.task_id,
|
||||||
@ -281,6 +287,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
message=message,
|
message=message,
|
||||||
|
dialogue_count=self._dialogue_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
runner.run()
|
runner.run()
|
||||||
@ -334,6 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
message=message,
|
message=message,
|
||||||
user=user,
|
user=user,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
dialogue_count=self._dialogue_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -39,12 +39,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
|
dialogue_count: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(queue_manager)
|
super().__init__(queue_manager)
|
||||||
|
|
||||||
self.application_generate_entity = application_generate_entity
|
self.application_generate_entity = application_generate_entity
|
||||||
self.conversation = conversation
|
self.conversation = conversation
|
||||||
self.message = message
|
self.message = message
|
||||||
|
self._dialogue_count = dialogue_count
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
app_config = self.application_generate_entity.app_config
|
app_config = self.application_generate_entity.app_config
|
||||||
@ -122,19 +124,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
# Increment dialogue count.
|
|
||||||
self.conversation.dialogue_count += 1
|
|
||||||
|
|
||||||
conversation_dialogue_count = self.conversation.dialogue_count
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Create a variable pool.
|
# Create a variable pool.
|
||||||
system_inputs = {
|
system_inputs = {
|
||||||
SystemVariableKey.QUERY: query,
|
SystemVariableKey.QUERY: query,
|
||||||
SystemVariableKey.FILES: files,
|
SystemVariableKey.FILES: files,
|
||||||
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
|
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
|
||||||
SystemVariableKey.USER_ID: user_id,
|
SystemVariableKey.USER_ID: user_id,
|
||||||
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
|
||||||
SystemVariableKey.APP_ID: app_config.app_id,
|
SystemVariableKey.APP_ID: app_config.app_id,
|
||||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||||
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
|
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
|
||||||
|
@ -88,6 +88,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
message: Message,
|
message: Message,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
stream: bool,
|
stream: bool,
|
||||||
|
dialogue_count: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize AdvancedChatAppGenerateTaskPipeline.
|
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||||
@ -98,6 +99,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
:param message: message
|
:param message: message
|
||||||
:param user: user
|
:param user: user
|
||||||
:param stream: stream
|
:param stream: stream
|
||||||
|
:param dialogue_count: dialogue count
|
||||||
"""
|
"""
|
||||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||||
|
|
||||||
@ -114,7 +116,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
SystemVariableKey.FILES: application_generate_entity.files,
|
SystemVariableKey.FILES: application_generate_entity.files,
|
||||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||||
SystemVariableKey.USER_ID: user_id,
|
SystemVariableKey.USER_ID: user_id,
|
||||||
SystemVariableKey.DIALOGUE_COUNT: conversation.dialogue_count,
|
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||||
|
32
api/core/prompt/utils/get_thread_messages_length.py
Normal file
32
api/core/prompt/utils/get_thread_messages_length.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import Message
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_messages_length(conversation_id: str) -> int:
|
||||||
|
"""
|
||||||
|
Get the number of thread messages based on the parent message id.
|
||||||
|
"""
|
||||||
|
# Fetch all messages related to the conversation
|
||||||
|
query = (
|
||||||
|
db.session.query(
|
||||||
|
Message.id,
|
||||||
|
Message.parent_message_id,
|
||||||
|
Message.answer,
|
||||||
|
)
|
||||||
|
.filter(
|
||||||
|
Message.conversation_id == conversation_id,
|
||||||
|
)
|
||||||
|
.order_by(Message.created_at.desc())
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = query.all()
|
||||||
|
|
||||||
|
# Extract thread messages
|
||||||
|
thread_messages = extract_thread_messages(messages)
|
||||||
|
|
||||||
|
# Exclude the newly created message with an empty answer
|
||||||
|
if thread_messages and not thread_messages[0].answer:
|
||||||
|
thread_messages.pop(0)
|
||||||
|
|
||||||
|
return len(thread_messages)
|
Loading…
Reference in New Issue
Block a user