feat(api/workflow): Add Conversation.dialogue_count
(#7275)
This commit is contained in:
parent
8f5d8397f9
commit
32dc963556
@ -1,3 +1,7 @@
|
|||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
|
||||||
tenant_id: ContextVar[str] = ContextVar('tenant_id')
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
|
||||||
|
tenant_id: ContextVar[str] = ContextVar('tenant_id')
|
||||||
|
|
||||||
|
workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool')
|
||||||
|
@ -8,6 +8,8 @@ from typing import Union
|
|||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
@ -18,15 +20,20 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
|
|||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import (
|
||||||
|
AdvancedChatAppGenerateEntity,
|
||||||
|
InvokeFrom,
|
||||||
|
)
|
||||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||||
from core.file.message_file_parser import MessageFileParser
|
from core.file.message_file_parser import MessageFileParser
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import App, Conversation, EndUser, Message
|
from models.model import App, Conversation, EndUser, Message
|
||||||
from models.workflow import Workflow
|
from models.workflow import ConversationVariable, Workflow
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -120,7 +127,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
stream=stream
|
stream=stream
|
||||||
)
|
)
|
||||||
|
|
||||||
def single_iteration_generate(self, app_model: App,
|
def single_iteration_generate(self, app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
@ -140,10 +147,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
"""
|
"""
|
||||||
if not node_id:
|
if not node_id:
|
||||||
raise ValueError('node_id is required')
|
raise ValueError('node_id is required')
|
||||||
|
|
||||||
if args.get('inputs') is None:
|
if args.get('inputs') is None:
|
||||||
raise ValueError('inputs is required')
|
raise ValueError('inputs is required')
|
||||||
|
|
||||||
extras = {
|
extras = {
|
||||||
"auto_generate_conversation_name": False
|
"auto_generate_conversation_name": False
|
||||||
}
|
}
|
||||||
@ -209,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
# update conversation features
|
# update conversation features
|
||||||
conversation.override_model_configs = workflow.features
|
conversation.override_model_configs = workflow.features
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
db.session.refresh(conversation)
|
# db.session.refresh(conversation)
|
||||||
|
|
||||||
# init queue manager
|
# init queue manager
|
||||||
queue_manager = MessageBasedAppQueueManager(
|
queue_manager = MessageBasedAppQueueManager(
|
||||||
@ -221,15 +228,69 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
message_id=message.id
|
message_id=message.id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Init conversation variables
|
||||||
|
stmt = select(ConversationVariable).where(
|
||||||
|
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
||||||
|
)
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
conversation_variables = session.scalars(stmt).all()
|
||||||
|
if not conversation_variables:
|
||||||
|
# Create conversation variables if they don't exist.
|
||||||
|
conversation_variables = [
|
||||||
|
ConversationVariable.from_variable(
|
||||||
|
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
||||||
|
)
|
||||||
|
for variable in workflow.conversation_variables
|
||||||
|
]
|
||||||
|
session.add_all(conversation_variables)
|
||||||
|
# Convert database entities to variables.
|
||||||
|
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Increment dialogue count.
|
||||||
|
conversation.dialogue_count += 1
|
||||||
|
|
||||||
|
conversation_id = conversation.id
|
||||||
|
conversation_dialogue_count = conversation.dialogue_count
|
||||||
|
db.session.commit()
|
||||||
|
db.session.refresh(conversation)
|
||||||
|
|
||||||
|
inputs = application_generate_entity.inputs
|
||||||
|
query = application_generate_entity.query
|
||||||
|
files = application_generate_entity.files
|
||||||
|
|
||||||
|
user_id = None
|
||||||
|
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||||
|
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||||
|
if end_user:
|
||||||
|
user_id = end_user.session_id
|
||||||
|
else:
|
||||||
|
user_id = application_generate_entity.user_id
|
||||||
|
|
||||||
|
# Create a variable pool.
|
||||||
|
system_inputs = {
|
||||||
|
SystemVariable.QUERY: query,
|
||||||
|
SystemVariable.FILES: files,
|
||||||
|
SystemVariable.CONVERSATION_ID: conversation_id,
|
||||||
|
SystemVariable.USER_ID: user_id,
|
||||||
|
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||||
|
}
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables=system_inputs,
|
||||||
|
user_inputs=inputs,
|
||||||
|
environment_variables=workflow.environment_variables,
|
||||||
|
conversation_variables=conversation_variables,
|
||||||
|
)
|
||||||
|
contexts.workflow_variable_pool.set(variable_pool)
|
||||||
|
|
||||||
# new thread
|
# new thread
|
||||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||||
'flask_app': current_app._get_current_object(),
|
'flask_app': current_app._get_current_object(),
|
||||||
'application_generate_entity': application_generate_entity,
|
'application_generate_entity': application_generate_entity,
|
||||||
'queue_manager': queue_manager,
|
'queue_manager': queue_manager,
|
||||||
'conversation_id': conversation.id,
|
|
||||||
'message_id': message.id,
|
'message_id': message.id,
|
||||||
'user': user,
|
'context': contextvars.copy_context(),
|
||||||
'context': contextvars.copy_context()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
@ -242,7 +303,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
message=message,
|
message=message,
|
||||||
user=user,
|
user=user,
|
||||||
stream=stream
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
return AdvancedChatAppGenerateResponseConverter.convert(
|
return AdvancedChatAppGenerateResponseConverter.convert(
|
||||||
@ -253,9 +314,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
def _generate_worker(self, flask_app: Flask,
|
def _generate_worker(self, flask_app: Flask,
|
||||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation_id: str,
|
|
||||||
message_id: str,
|
message_id: str,
|
||||||
user: Account,
|
|
||||||
context: contextvars.Context) -> None:
|
context: contextvars.Context) -> None:
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
@ -282,8 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
user_id=application_generate_entity.user_id
|
user_id=application_generate_entity.user_id
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# get conversation and message
|
# get message
|
||||||
conversation = self._get_conversation(conversation_id)
|
|
||||||
message = self._get_message(message_id)
|
message = self._get_message(message_id)
|
||||||
|
|
||||||
# chatbot app
|
# chatbot app
|
||||||
@ -291,7 +349,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
runner.run(
|
runner.run(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
conversation=conversation,
|
|
||||||
message=message
|
message=message
|
||||||
)
|
)
|
||||||
except GenerateTaskStoppedException:
|
except GenerateTaskStoppedException:
|
||||||
@ -314,14 +371,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
finally:
|
finally:
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
def _handle_advanced_chat_response(
|
||||||
workflow: Workflow,
|
self,
|
||||||
queue_manager: AppQueueManager,
|
*,
|
||||||
conversation: Conversation,
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
message: Message,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
queue_manager: AppQueueManager,
|
||||||
stream: bool = False) \
|
conversation: Conversation,
|
||||||
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
message: Message,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||||
"""
|
"""
|
||||||
Handle response.
|
Handle response.
|
||||||
:param application_generate_entity: application generate entity
|
:param application_generate_entity: application generate entity
|
||||||
@ -341,7 +401,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
message=message,
|
message=message,
|
||||||
user=user,
|
user=user,
|
||||||
stream=stream
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -4,9 +4,6 @@ import time
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
@ -19,13 +16,10 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||||
from core.moderation.base import ModerationException
|
from core.moderation.base import ModerationException
|
||||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, Conversation, EndUser, Message
|
from models import App, Message, Workflow
|
||||||
from models.workflow import ConversationVariable, Workflow
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -39,7 +33,6 @@ class AdvancedChatAppRunner(AppRunner):
|
|||||||
self,
|
self,
|
||||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation: Conversation,
|
|
||||||
message: Message,
|
message: Message,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -63,15 +56,6 @@ class AdvancedChatAppRunner(AppRunner):
|
|||||||
|
|
||||||
inputs = application_generate_entity.inputs
|
inputs = application_generate_entity.inputs
|
||||||
query = application_generate_entity.query
|
query = application_generate_entity.query
|
||||||
files = application_generate_entity.files
|
|
||||||
|
|
||||||
user_id = None
|
|
||||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
|
||||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
|
||||||
if end_user:
|
|
||||||
user_id = end_user.session_id
|
|
||||||
else:
|
|
||||||
user_id = application_generate_entity.user_id
|
|
||||||
|
|
||||||
# moderation
|
# moderation
|
||||||
if self.handle_input_moderation(
|
if self.handle_input_moderation(
|
||||||
@ -103,38 +87,6 @@ class AdvancedChatAppRunner(AppRunner):
|
|||||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||||
|
|
||||||
# Init conversation variables
|
|
||||||
stmt = select(ConversationVariable).where(
|
|
||||||
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
|
||||||
)
|
|
||||||
with Session(db.engine) as session:
|
|
||||||
conversation_variables = session.scalars(stmt).all()
|
|
||||||
if not conversation_variables:
|
|
||||||
conversation_variables = [
|
|
||||||
ConversationVariable.from_variable(
|
|
||||||
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
|
||||||
)
|
|
||||||
for variable in workflow.conversation_variables
|
|
||||||
]
|
|
||||||
session.add_all(conversation_variables)
|
|
||||||
session.commit()
|
|
||||||
# Convert database entities to variables
|
|
||||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
|
||||||
|
|
||||||
# Create a variable pool.
|
|
||||||
system_inputs = {
|
|
||||||
SystemVariable.QUERY: query,
|
|
||||||
SystemVariable.FILES: files,
|
|
||||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
|
||||||
SystemVariable.USER_ID: user_id,
|
|
||||||
}
|
|
||||||
variable_pool = VariablePool(
|
|
||||||
system_variables=system_inputs,
|
|
||||||
user_inputs=inputs,
|
|
||||||
environment_variables=workflow.environment_variables,
|
|
||||||
conversation_variables=conversation_variables,
|
|
||||||
)
|
|
||||||
|
|
||||||
# RUN WORKFLOW
|
# RUN WORKFLOW
|
||||||
workflow_engine_manager = WorkflowEngineManager()
|
workflow_engine_manager = WorkflowEngineManager()
|
||||||
workflow_engine_manager.run_workflow(
|
workflow_engine_manager.run_workflow(
|
||||||
@ -146,7 +98,6 @@ class AdvancedChatAppRunner(AppRunner):
|
|||||||
invoke_from=application_generate_entity.invoke_from,
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
callbacks=workflow_callbacks,
|
callbacks=workflow_callbacks,
|
||||||
call_depth=application_generate_entity.call_depth,
|
call_depth=application_generate_entity.call_depth,
|
||||||
variable_pool=variable_pool,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def single_iteration_run(
|
def single_iteration_run(
|
||||||
@ -155,7 +106,7 @@ class AdvancedChatAppRunner(AppRunner):
|
|||||||
"""
|
"""
|
||||||
Single iteration run
|
Single iteration run
|
||||||
"""
|
"""
|
||||||
app_record: App = db.session.query(App).filter(App.id == app_id).first()
|
app_record = db.session.query(App).filter(App.id == app_id).first()
|
||||||
if not app_record:
|
if not app_record:
|
||||||
raise ValueError('App not found')
|
raise ValueError('App not found')
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import time
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Optional, Union, cast
|
from typing import Any, Optional, Union, cast
|
||||||
|
|
||||||
|
import contexts
|
||||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
@ -47,7 +48,8 @@ from core.file.file_obj import FileVar
|
|||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.node_entities import NodeType, SystemVariable
|
from core.workflow.entities.node_entities import NodeType
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
|
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
_application_generate_entity: AdvancedChatAppGenerateEntity
|
||||||
_workflow: Workflow
|
_workflow: Workflow
|
||||||
_user: Union[Account, EndUser]
|
_user: Union[Account, EndUser]
|
||||||
|
# Deprecated
|
||||||
_workflow_system_variables: dict[SystemVariable, Any]
|
_workflow_system_variables: dict[SystemVariable, Any]
|
||||||
_iteration_nested_relations: dict[str, list[str]]
|
_iteration_nested_relations: dict[str, list[str]]
|
||||||
|
|
||||||
@ -81,7 +84,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
stream: bool
|
stream: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize AdvancedChatAppGenerateTaskPipeline.
|
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||||
@ -103,11 +106,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
self._workflow = workflow
|
self._workflow = workflow
|
||||||
self._conversation = conversation
|
self._conversation = conversation
|
||||||
self._message = message
|
self._message = message
|
||||||
|
# Deprecated
|
||||||
self._workflow_system_variables = {
|
self._workflow_system_variables = {
|
||||||
SystemVariable.QUERY: message.query,
|
SystemVariable.QUERY: message.query,
|
||||||
SystemVariable.FILES: application_generate_entity.files,
|
SystemVariable.FILES: application_generate_entity.files,
|
||||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||||
SystemVariable.USER_ID: user_id
|
SystemVariable.USER_ID: user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
self._task_state = AdvancedChatTaskState(
|
self._task_state = AdvancedChatTaskState(
|
||||||
@ -613,7 +617,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
|
|
||||||
if route_chunk_node_id == 'sys':
|
if route_chunk_node_id == 'sys':
|
||||||
# system variable
|
# system variable
|
||||||
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
|
value = contexts.workflow_variable_pool.get().get(value_selector)
|
||||||
|
if value:
|
||||||
|
value = value.text
|
||||||
elif route_chunk_node_id in self._iteration_nested_relations:
|
elif route_chunk_node_id in self._iteration_nested_relations:
|
||||||
# it's a iteration variable
|
# it's a iteration variable
|
||||||
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
|
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
|
||||||
|
@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||||||
|
|
||||||
return introduction
|
return introduction
|
||||||
|
|
||||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
def _get_conversation(self, conversation_id: str):
|
||||||
"""
|
"""
|
||||||
Get conversation by conversation id
|
Get conversation by conversation id
|
||||||
:param conversation_id: conversation id
|
:param conversation_id: conversation id
|
||||||
@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
raise ConversationNotExistsError()
|
||||||
|
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
def _get_message(self, message_id: str) -> Message:
|
def _get_message(self, message_id: str) -> Message:
|
||||||
|
@ -11,8 +11,8 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
WorkflowAppGenerateEntity,
|
WorkflowAppGenerateEntity,
|
||||||
)
|
)
|
||||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -42,7 +42,8 @@ from core.app.entities.task_entities import (
|
|||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.node_entities import NodeType, SystemVariable
|
from core.workflow.entities.node_entities import NodeType
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.nodes.end.end_node import EndNode
|
from core.workflow.nodes.end.end_node import EndNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
@ -519,7 +520,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
"""
|
"""
|
||||||
nodes = graph.get('nodes')
|
nodes = graph.get('nodes')
|
||||||
|
|
||||||
iteration_ids = [node.get('id') for node in nodes
|
iteration_ids = [node.get('id') for node in nodes
|
||||||
if node.get('data', {}).get('type') in [
|
if node.get('data', {}).get('type') in [
|
||||||
NodeType.ITERATION.value,
|
NodeType.ITERATION.value,
|
||||||
NodeType.LOOP.value,
|
NodeType.LOOP.value,
|
||||||
@ -530,4 +531,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
||||||
] for iteration_id in iteration_ids
|
] for iteration_id in iteration_ids
|
||||||
}
|
}
|
||||||
|
|
@ -2,7 +2,6 @@ from .segment_group import SegmentGroup
|
|||||||
from .segments import (
|
from .segments import (
|
||||||
ArrayAnySegment,
|
ArrayAnySegment,
|
||||||
ArraySegment,
|
ArraySegment,
|
||||||
FileSegment,
|
|
||||||
FloatSegment,
|
FloatSegment,
|
||||||
IntegerSegment,
|
IntegerSegment,
|
||||||
NoneSegment,
|
NoneSegment,
|
||||||
@ -13,11 +12,9 @@ from .segments import (
|
|||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
from .variables import (
|
from .variables import (
|
||||||
ArrayAnyVariable,
|
ArrayAnyVariable,
|
||||||
ArrayFileVariable,
|
|
||||||
ArrayNumberVariable,
|
ArrayNumberVariable,
|
||||||
ArrayObjectVariable,
|
ArrayObjectVariable,
|
||||||
ArrayStringVariable,
|
ArrayStringVariable,
|
||||||
FileVariable,
|
|
||||||
FloatVariable,
|
FloatVariable,
|
||||||
IntegerVariable,
|
IntegerVariable,
|
||||||
NoneVariable,
|
NoneVariable,
|
||||||
@ -32,7 +29,6 @@ __all__ = [
|
|||||||
'FloatVariable',
|
'FloatVariable',
|
||||||
'ObjectVariable',
|
'ObjectVariable',
|
||||||
'SecretVariable',
|
'SecretVariable',
|
||||||
'FileVariable',
|
|
||||||
'StringVariable',
|
'StringVariable',
|
||||||
'ArrayAnyVariable',
|
'ArrayAnyVariable',
|
||||||
'Variable',
|
'Variable',
|
||||||
@ -45,11 +41,9 @@ __all__ = [
|
|||||||
'FloatSegment',
|
'FloatSegment',
|
||||||
'ObjectSegment',
|
'ObjectSegment',
|
||||||
'ArrayAnySegment',
|
'ArrayAnySegment',
|
||||||
'FileSegment',
|
|
||||||
'StringSegment',
|
'StringSegment',
|
||||||
'ArrayStringVariable',
|
'ArrayStringVariable',
|
||||||
'ArrayNumberVariable',
|
'ArrayNumberVariable',
|
||||||
'ArrayObjectVariable',
|
'ArrayObjectVariable',
|
||||||
'ArrayFileVariable',
|
|
||||||
'ArraySegment',
|
'ArraySegment',
|
||||||
]
|
]
|
||||||
|
@ -2,12 +2,10 @@ from collections.abc import Mapping
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
|
|
||||||
from .exc import VariableError
|
from .exc import VariableError
|
||||||
from .segments import (
|
from .segments import (
|
||||||
ArrayAnySegment,
|
ArrayAnySegment,
|
||||||
FileSegment,
|
|
||||||
FloatSegment,
|
FloatSegment,
|
||||||
IntegerSegment,
|
IntegerSegment,
|
||||||
NoneSegment,
|
NoneSegment,
|
||||||
@ -17,11 +15,9 @@ from .segments import (
|
|||||||
)
|
)
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
from .variables import (
|
from .variables import (
|
||||||
ArrayFileVariable,
|
|
||||||
ArrayNumberVariable,
|
ArrayNumberVariable,
|
||||||
ArrayObjectVariable,
|
ArrayObjectVariable,
|
||||||
ArrayStringVariable,
|
ArrayStringVariable,
|
||||||
FileVariable,
|
|
||||||
FloatVariable,
|
FloatVariable,
|
||||||
IntegerVariable,
|
IntegerVariable,
|
||||||
ObjectVariable,
|
ObjectVariable,
|
||||||
@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
|||||||
result = FloatVariable.model_validate(mapping)
|
result = FloatVariable.model_validate(mapping)
|
||||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||||
raise VariableError(f'invalid number value {value}')
|
raise VariableError(f'invalid number value {value}')
|
||||||
case SegmentType.FILE:
|
|
||||||
result = FileVariable.model_validate(mapping)
|
|
||||||
case SegmentType.OBJECT if isinstance(value, dict):
|
case SegmentType.OBJECT if isinstance(value, dict):
|
||||||
result = ObjectVariable.model_validate(mapping)
|
result = ObjectVariable.model_validate(mapping)
|
||||||
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
||||||
@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
|||||||
result = ArrayNumberVariable.model_validate(mapping)
|
result = ArrayNumberVariable.model_validate(mapping)
|
||||||
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||||
result = ArrayObjectVariable.model_validate(mapping)
|
result = ArrayObjectVariable.model_validate(mapping)
|
||||||
case SegmentType.ARRAY_FILE if isinstance(value, list):
|
|
||||||
mapping = dict(mapping)
|
|
||||||
mapping['value'] = [{'value': v} for v in value]
|
|
||||||
result = ArrayFileVariable.model_validate(mapping)
|
|
||||||
case _:
|
case _:
|
||||||
raise VariableError(f'not supported value type {value_type}')
|
raise VariableError(f'not supported value type {value_type}')
|
||||||
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||||
@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
|
|||||||
return ObjectSegment(value=value)
|
return ObjectSegment(value=value)
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
return ArrayAnySegment(value=value)
|
return ArrayAnySegment(value=value)
|
||||||
if isinstance(value, FileVar):
|
|
||||||
return FileSegment(value=value)
|
|
||||||
raise ValueError(f'not supported value {value}')
|
raise ValueError(f'not supported value {value}')
|
||||||
|
@ -5,8 +5,6 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
|
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
@ -78,14 +76,7 @@ class IntegerSegment(Segment):
|
|||||||
value: int
|
value: int
|
||||||
|
|
||||||
|
|
||||||
class FileSegment(Segment):
|
|
||||||
value_type: SegmentType = SegmentType.FILE
|
|
||||||
# TODO: embed FileVar in this model.
|
|
||||||
value: FileVar
|
|
||||||
|
|
||||||
@property
|
|
||||||
def markdown(self) -> str:
|
|
||||||
return self.value.to_markdown()
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectSegment(Segment):
|
class ObjectSegment(Segment):
|
||||||
@ -130,7 +121,3 @@ class ArrayObjectSegment(ArraySegment):
|
|||||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||||
value: Sequence[Mapping[str, Any]]
|
value: Sequence[Mapping[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class ArrayFileSegment(ArraySegment):
|
|
||||||
value_type: SegmentType = SegmentType.ARRAY_FILE
|
|
||||||
value: Sequence[FileSegment]
|
|
||||||
|
@ -10,8 +10,6 @@ class SegmentType(str, Enum):
|
|||||||
ARRAY_STRING = 'array[string]'
|
ARRAY_STRING = 'array[string]'
|
||||||
ARRAY_NUMBER = 'array[number]'
|
ARRAY_NUMBER = 'array[number]'
|
||||||
ARRAY_OBJECT = 'array[object]'
|
ARRAY_OBJECT = 'array[object]'
|
||||||
ARRAY_FILE = 'array[file]'
|
|
||||||
OBJECT = 'object'
|
OBJECT = 'object'
|
||||||
FILE = 'file'
|
|
||||||
|
|
||||||
GROUP = 'group'
|
GROUP = 'group'
|
||||||
|
@ -4,11 +4,9 @@ from core.helper import encrypter
|
|||||||
|
|
||||||
from .segments import (
|
from .segments import (
|
||||||
ArrayAnySegment,
|
ArrayAnySegment,
|
||||||
ArrayFileSegment,
|
|
||||||
ArrayNumberSegment,
|
ArrayNumberSegment,
|
||||||
ArrayObjectSegment,
|
ArrayObjectSegment,
|
||||||
ArrayStringSegment,
|
ArrayStringSegment,
|
||||||
FileSegment,
|
|
||||||
FloatSegment,
|
FloatSegment,
|
||||||
IntegerSegment,
|
IntegerSegment,
|
||||||
NoneSegment,
|
NoneSegment,
|
||||||
@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FileVariable(FileSegment, Variable):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectVariable(ObjectSegment, Variable):
|
class ObjectVariable(ObjectSegment, Variable):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ArrayFileVariable(ArrayFileSegment, Variable):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SecretVariable(StringVariable):
|
class SecretVariable(StringVariable):
|
||||||
value_type: SegmentType = SegmentType.SECRET
|
value_type: SegmentType = SegmentType.SECRET
|
||||||
|
@ -2,7 +2,7 @@ from typing import Any, Union
|
|||||||
|
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
from core.workflow.enums import SystemVariable
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
|
|||||||
_workflow: Workflow
|
_workflow: Workflow
|
||||||
_user: Union[Account, EndUser]
|
_user: Union[Account, EndUser]
|
||||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||||
_workflow_system_variables: dict[SystemVariable, Any]
|
_workflow_system_variables: dict[SystemVariable, Any]
|
||||||
|
@ -4,13 +4,14 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
class NodeType(Enum):
|
class NodeType(Enum):
|
||||||
"""
|
"""
|
||||||
Node Types.
|
Node Types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
START = 'start'
|
START = 'start'
|
||||||
END = 'end'
|
END = 'end'
|
||||||
ANSWER = 'answer'
|
ANSWER = 'answer'
|
||||||
@ -44,33 +45,11 @@ class NodeType(Enum):
|
|||||||
raise ValueError(f'invalid node type value {value}')
|
raise ValueError(f'invalid node type value {value}')
|
||||||
|
|
||||||
|
|
||||||
class SystemVariable(Enum):
|
|
||||||
"""
|
|
||||||
System Variables.
|
|
||||||
"""
|
|
||||||
QUERY = 'query'
|
|
||||||
FILES = 'files'
|
|
||||||
CONVERSATION_ID = 'conversation_id'
|
|
||||||
USER_ID = 'user_id'
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> 'SystemVariable':
|
|
||||||
"""
|
|
||||||
Get value of given system variable.
|
|
||||||
|
|
||||||
:param value: system variable value
|
|
||||||
:return: system variable
|
|
||||||
"""
|
|
||||||
for system_variable in cls:
|
|
||||||
if system_variable.value == value:
|
|
||||||
return system_variable
|
|
||||||
raise ValueError(f'invalid system variable value {value}')
|
|
||||||
|
|
||||||
|
|
||||||
class NodeRunMetadataKey(Enum):
|
class NodeRunMetadataKey(Enum):
|
||||||
"""
|
"""
|
||||||
Node Run Metadata Key.
|
Node Run Metadata Key.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TOTAL_TOKENS = 'total_tokens'
|
TOTAL_TOKENS = 'total_tokens'
|
||||||
TOTAL_PRICE = 'total_price'
|
TOTAL_PRICE = 'total_price'
|
||||||
CURRENCY = 'currency'
|
CURRENCY = 'currency'
|
||||||
@ -83,6 +62,7 @@ class NodeRunResult(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Node Run Result.
|
Node Run Result.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||||
|
|
||||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||||
|
@ -6,7 +6,7 @@ from typing_extensions import deprecated
|
|||||||
|
|
||||||
from core.app.segments import Segment, Variable, factory
|
from core.app.segments import Segment, Variable, factory
|
||||||
from core.file.file_obj import FileVar
|
from core.file.file_obj import FileVar
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
from core.workflow.enums import SystemVariable
|
||||||
|
|
||||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||||
|
|
||||||
|
25
api/core/workflow/enums.py
Normal file
25
api/core/workflow/enums.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class SystemVariable(str, Enum):
|
||||||
|
"""
|
||||||
|
System Variables.
|
||||||
|
"""
|
||||||
|
QUERY = 'query'
|
||||||
|
FILES = 'files'
|
||||||
|
CONVERSATION_ID = 'conversation_id'
|
||||||
|
USER_ID = 'user_id'
|
||||||
|
DIALOGUE_COUNT = 'dialogue_count'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_of(cls, value: str):
|
||||||
|
"""
|
||||||
|
Get value of given system variable.
|
||||||
|
|
||||||
|
:param value: system variable value
|
||||||
|
:return: system variable
|
||||||
|
"""
|
||||||
|
for system_variable in cls:
|
||||||
|
if system_variable.value == value:
|
||||||
|
return system_variable
|
||||||
|
raise ValueError(f'invalid system variable value {value}')
|
@ -23,8 +23,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.llm.entities import (
|
from core.workflow.nodes.llm.entities import (
|
||||||
LLMNodeChatModelMessage,
|
LLMNodeChatModelMessage,
|
||||||
@ -201,8 +202,8 @@ class LLMNode(BaseNode):
|
|||||||
usage = LLMUsage.empty_usage()
|
usage = LLMUsage.empty_usage()
|
||||||
|
|
||||||
return full_text, usage
|
return full_text, usage
|
||||||
|
|
||||||
def _transform_chat_messages(self,
|
def _transform_chat_messages(self,
|
||||||
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||||
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||||
"""
|
"""
|
||||||
@ -249,13 +250,13 @@ class LLMNode(BaseNode):
|
|||||||
# check if it's a context structure
|
# check if it's a context structure
|
||||||
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
|
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
|
||||||
return d['content']
|
return d['content']
|
||||||
|
|
||||||
# else, parse the dict
|
# else, parse the dict
|
||||||
try:
|
try:
|
||||||
return json.dumps(d, ensure_ascii=False)
|
return json.dumps(d, ensure_ascii=False)
|
||||||
except Exception:
|
except Exception:
|
||||||
return str(d)
|
return str(d)
|
||||||
|
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
value = value
|
value = value
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
|
@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence
|
|||||||
from os import path
|
from os import path
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from core.app.segments import parser
|
from core.app.segments import ArrayAnyVariable, parser
|
||||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||||
from core.tools.tool_engine import ToolEngine
|
from core.tools.tool_engine import ToolEngine
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
class ToolNode(BaseNode):
|
class ToolNode(BaseNode):
|
||||||
@ -140,9 +141,9 @@ class ToolNode(BaseNode):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||||
# FIXME: ensure this is a ArrayVariable contains FileVariable.
|
|
||||||
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
|
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
|
||||||
return [file_var.value for file_var in variable.value] if variable else []
|
assert isinstance(variable, ArrayAnyVariable)
|
||||||
|
return list(variable.value) if variable else []
|
||||||
|
|
||||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
|
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
|
||||||
"""
|
"""
|
||||||
|
@ -3,6 +3,7 @@ import time
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
import contexts
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
@ -97,7 +98,7 @@ class WorkflowEngineManager:
|
|||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
callbacks: Sequence[WorkflowCallback],
|
callbacks: Sequence[WorkflowCallback],
|
||||||
call_depth: int = 0,
|
call_depth: int = 0,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
:param workflow: Workflow instance
|
:param workflow: Workflow instance
|
||||||
@ -128,6 +129,8 @@ class WorkflowEngineManager:
|
|||||||
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
|
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
|
||||||
|
|
||||||
# init workflow run state
|
# init workflow run state
|
||||||
|
if not variable_pool:
|
||||||
|
variable_pool = contexts.workflow_variable_pool.get()
|
||||||
workflow_run_state = WorkflowRunState(
|
workflow_run_state = WorkflowRunState(
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
start_at=time.perf_counter(),
|
start_at=time.perf_counter(),
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
"""add conversations.dialogue_count
|
||||||
|
|
||||||
|
Revision ID: 8782057ff0dc
|
||||||
|
Revises: 63a83fcf12ba
|
||||||
|
Create Date: 2024-08-14 13:54:25.161324
|
||||||
|
|
||||||
|
"""
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
import models as models
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '8782057ff0dc'
|
||||||
|
down_revision = '63a83fcf12ba'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('dialogue_count')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
@ -1,10 +1,10 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from .model import AppMode
|
from .model import App, AppMode, Message
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
|
from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
|
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message']
|
||||||
|
|
||||||
|
|
||||||
class CreatedByRole(Enum):
|
class CreatedByRole(Enum):
|
||||||
|
@ -7,6 +7,7 @@ from typing import Optional
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import UserMixin
|
from flask_login import UserMixin
|
||||||
from sqlalchemy import Float, func, text
|
from sqlalchemy import Float, func, text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.file.tool_file_parser import ToolFileParser
|
from core.file.tool_file_parser import ToolFileParser
|
||||||
@ -512,12 +513,12 @@ class Conversation(db.Model):
|
|||||||
from_account_id = db.Column(StringUUID)
|
from_account_id = db.Column(StringUUID)
|
||||||
read_at = db.Column(db.DateTime)
|
read_at = db.Column(db.DateTime)
|
||||||
read_account_id = db.Column(StringUUID)
|
read_account_id = db.Column(StringUUID)
|
||||||
|
dialogue_count: Mapped[int] = mapped_column(default=0)
|
||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||||
|
|
||||||
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
|
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
|
||||||
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select',
|
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
|
||||||
passive_deletes="all")
|
|
||||||
|
|
||||||
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
|
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
|
||||||
|
|
||||||
|
@ -10,8 +10,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
|
|||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.model_providers import ModelProviderFactory
|
from core.model_runtime.model_providers import ModelProviderFactory
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -236,4 +236,4 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
|
|||||||
|
|
||||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
assert 'sunny' in json.dumps(result.process_data)
|
assert 'sunny' in json.dumps(result.process_data)
|
||||||
assert 'what\'s the weather today?' in json.dumps(result.process_data)
|
assert 'what\'s the weather today?' in json.dumps(result.process_data)
|
||||||
|
@ -12,8 +12,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory
|
|||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -363,7 +363,7 @@ def test_extract_json_response():
|
|||||||
{
|
{
|
||||||
"location": "kawaii"
|
"location": "kawaii"
|
||||||
}
|
}
|
||||||
hello world.
|
hello world.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
assert result['location'] == 'kawaii'
|
assert result['location'] == 'kawaii'
|
||||||
@ -445,4 +445,4 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
|
|||||||
assert latest_role != prompt.get('role')
|
assert latest_role != prompt.get('role')
|
||||||
|
|
||||||
if prompt.get('role') in ['user', 'assistant']:
|
if prompt.get('role') in ['user', 'assistant']:
|
||||||
latest_role = prompt.get('role')
|
latest_role = prompt.get('role')
|
||||||
|
@ -3,12 +3,9 @@ from uuid import uuid4
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.app.segments import (
|
from core.app.segments import (
|
||||||
ArrayFileVariable,
|
|
||||||
ArrayNumberVariable,
|
ArrayNumberVariable,
|
||||||
ArrayObjectVariable,
|
ArrayObjectVariable,
|
||||||
ArrayStringVariable,
|
ArrayStringVariable,
|
||||||
FileSegment,
|
|
||||||
FileVariable,
|
|
||||||
FloatVariable,
|
FloatVariable,
|
||||||
IntegerVariable,
|
IntegerVariable,
|
||||||
ObjectSegment,
|
ObjectSegment,
|
||||||
@ -149,83 +146,6 @@ def test_array_object_variable():
|
|||||||
assert isinstance(variable.value[1]['key2'], int)
|
assert isinstance(variable.value[1]['key2'], int)
|
||||||
|
|
||||||
|
|
||||||
def test_file_variable():
|
|
||||||
mapping = {
|
|
||||||
'id': str(uuid4()),
|
|
||||||
'value_type': 'file',
|
|
||||||
'name': 'test_file',
|
|
||||||
'description': 'Description of the variable.',
|
|
||||||
'value': {
|
|
||||||
'id': str(uuid4()),
|
|
||||||
'tenant_id': 'tenant_id',
|
|
||||||
'type': 'image',
|
|
||||||
'transfer_method': 'local_file',
|
|
||||||
'url': 'url',
|
|
||||||
'related_id': 'related_id',
|
|
||||||
'extra_config': {
|
|
||||||
'image_config': {
|
|
||||||
'width': 100,
|
|
||||||
'height': 100,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
'filename': 'filename',
|
|
||||||
'extension': 'extension',
|
|
||||||
'mime_type': 'mime_type',
|
|
||||||
},
|
|
||||||
}
|
|
||||||
variable = factory.build_variable_from_mapping(mapping)
|
|
||||||
assert isinstance(variable, FileVariable)
|
|
||||||
|
|
||||||
|
|
||||||
def test_array_file_variable():
|
|
||||||
mapping = {
|
|
||||||
'id': str(uuid4()),
|
|
||||||
'value_type': 'array[file]',
|
|
||||||
'name': 'test_array_file',
|
|
||||||
'description': 'Description of the variable.',
|
|
||||||
'value': [
|
|
||||||
{
|
|
||||||
'id': str(uuid4()),
|
|
||||||
'tenant_id': 'tenant_id',
|
|
||||||
'type': 'image',
|
|
||||||
'transfer_method': 'local_file',
|
|
||||||
'url': 'url',
|
|
||||||
'related_id': 'related_id',
|
|
||||||
'extra_config': {
|
|
||||||
'image_config': {
|
|
||||||
'width': 100,
|
|
||||||
'height': 100,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
'filename': 'filename',
|
|
||||||
'extension': 'extension',
|
|
||||||
'mime_type': 'mime_type',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'id': str(uuid4()),
|
|
||||||
'tenant_id': 'tenant_id',
|
|
||||||
'type': 'image',
|
|
||||||
'transfer_method': 'local_file',
|
|
||||||
'url': 'url',
|
|
||||||
'related_id': 'related_id',
|
|
||||||
'extra_config': {
|
|
||||||
'image_config': {
|
|
||||||
'width': 100,
|
|
||||||
'height': 100,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
'filename': 'filename',
|
|
||||||
'extension': 'extension',
|
|
||||||
'mime_type': 'mime_type',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
variable = factory.build_variable_from_mapping(mapping)
|
|
||||||
assert isinstance(variable, ArrayFileVariable)
|
|
||||||
assert isinstance(variable.value[0], FileSegment)
|
|
||||||
assert isinstance(variable.value[1], FileSegment)
|
|
||||||
|
|
||||||
|
|
||||||
def test_variable_cannot_large_than_5_kb():
|
def test_variable_cannot_large_than_5_kb():
|
||||||
with pytest.raises(VariableError):
|
with pytest.raises(VariableError):
|
||||||
factory.build_variable_from_mapping(
|
factory.build_variable_from_mapping(
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from core.app.segments import SecretVariable, StringSegment, parser
|
from core.app.segments import SecretVariable, StringSegment, parser
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
|
|
||||||
|
|
||||||
def test_segment_group_to_text():
|
def test_segment_group_to_text():
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -3,8 +3,8 @@ from uuid import uuid4
|
|||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.segments import ArrayStringVariable, StringVariable
|
from core.app.segments import ArrayStringVariable, StringVariable
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user