feat(api/workflow): Add Conversation.dialogue_count (#7275)

This commit is contained in:
-LAN- 2024-08-15 10:53:05 +08:00 committed by GitHub
parent 8f5d8397f9
commit 32dc963556
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 205 additions and 259 deletions

View File

@ -1,3 +1,7 @@
from contextvars import ContextVar
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar('tenant_id')
workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool')

View File

@ -8,6 +8,8 @@ from typing import Union
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@ -18,15 +20,20 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from models.workflow import ConversationVariable, Workflow
logger = logging.getLogger(__name__)
@ -209,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
# db.session.refresh(conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -221,15 +228,69 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id
)
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
conversation.dialogue_count += 1
conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation_id,
SystemVariable.USER_ID: user_id,
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'user': user,
'context': contextvars.copy_context()
'context': contextvars.copy_context(),
})
worker_thread.start()
@ -242,7 +303,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)
return AdvancedChatAppGenerateResponseConverter.convert(
@ -253,9 +314,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
user: Account,
context: contextvars.Context) -> None:
"""
Generate worker in a new thread.
@ -282,8 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user_id=application_generate_entity.user_id
)
else:
# get conversation and message
conversation = self._get_conversation(conversation_id)
# get message
message = self._get_message(message_id)
# chatbot app
@ -291,7 +349,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except GenerateTaskStoppedException:
@ -314,14 +371,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
finally:
db.session.close()
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False) \
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def _handle_advanced_chat_response(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
@ -341,7 +401,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)
try:

View File

@ -4,9 +4,6 @@ import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -19,13 +16,10 @@ from core.app.entities.app_invoke_entities import (
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, Workflow
from models import App, Message, Workflow
logger = logging.getLogger(__name__)
@ -39,7 +33,6 @@ class AdvancedChatAppRunner(AppRunner):
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
@ -63,15 +56,6 @@ class AdvancedChatAppRunner(AppRunner):
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# moderation
if self.handle_input_moderation(
@ -103,38 +87,6 @@ class AdvancedChatAppRunner(AppRunner):
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.commit()
# Convert database entities to variables
conversation_variables = [item.to_variable() for item in conversation_variables]
# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
@ -146,7 +98,6 @@ class AdvancedChatAppRunner(AppRunner):
invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
)
def single_iteration_run(
@ -155,7 +106,7 @@ class AdvancedChatAppRunner(AppRunner):
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')

View File

@ -4,6 +4,7 @@ import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -47,7 +48,8 @@ from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariable, Any]
_iteration_nested_relations: dict[str, list[str]]
@ -81,7 +84,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@ -103,11 +106,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._workflow = workflow
self._conversation = conversation
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
SystemVariable.USER_ID: user_id,
}
self._task_state = AdvancedChatTaskState(
@ -613,7 +617,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if route_chunk_node_id == 'sys':
# system variable
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
value = contexts.workflow_variable_pool.get().get(value_selector)
if value:
value = value.text
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:

View File

@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return introduction
def _get_conversation(self, conversation_id: str) -> Conversation:
def _get_conversation(self, conversation_id: str):
"""
Get conversation by conversation id
:param conversation_id: conversation id
@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
.first()
)
if not conversation:
raise ConversationNotExistsError()
return conversation
def _get_message(self, message_id: str) -> Message:

View File

@ -11,8 +11,8 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db

View File

@ -42,7 +42,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
@ -530,4 +531,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

View File

@ -2,7 +2,6 @@ from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -13,11 +12,9 @@ from .segments import (
from .types import SegmentType
from .variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
@ -32,7 +29,6 @@ __all__ = [
'FloatVariable',
'ObjectVariable',
'SecretVariable',
'FileVariable',
'StringVariable',
'ArrayAnyVariable',
'Variable',
@ -45,11 +41,9 @@ __all__ = [
'FloatSegment',
'ObjectSegment',
'ArrayAnySegment',
'FileSegment',
'StringSegment',
'ArrayStringVariable',
'ArrayNumberVariable',
'ArrayObjectVariable',
'ArrayFileVariable',
'ArraySegment',
]

View File

@ -2,12 +2,10 @@ from collections.abc import Mapping
from typing import Any
from configs import dify_config
from core.file.file_obj import FileVar
from .exc import VariableError
from .segments import (
ArrayAnySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -17,11 +15,9 @@ from .segments import (
)
from .types import SegmentType
from .variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f'invalid number value {value}')
case SegmentType.FILE:
result = FileVariable.model_validate(mapping)
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_FILE if isinstance(value, list):
mapping = dict(mapping)
mapping['value'] = [{'value': v} for v in value]
result = ArrayFileVariable.model_validate(mapping)
case _:
raise VariableError(f'not supported value type {value_type}')
if result.size > dify_config.MAX_VARIABLE_SIZE:
@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value)
if isinstance(value, list):
return ArrayAnySegment(value=value)
if isinstance(value, FileVar):
return FileSegment(value=value)
raise ValueError(f'not supported value {value}')

View File

@ -5,8 +5,6 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from core.file.file_obj import FileVar
from .types import SegmentType
@ -78,14 +76,7 @@ class IntegerSegment(Segment):
value: int
class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
# TODO: embed FileVar in this model.
value: FileVar
@property
def markdown(self) -> str:
return self.value.to_markdown()
class ObjectSegment(Segment):
@ -130,7 +121,3 @@ class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[FileSegment]

View File

@ -10,8 +10,6 @@ class SegmentType(str, Enum):
ARRAY_STRING = 'array[string]'
ARRAY_NUMBER = 'array[number]'
ARRAY_OBJECT = 'array[object]'
ARRAY_FILE = 'array[file]'
OBJECT = 'object'
FILE = 'file'
GROUP = 'group'

View File

@ -4,11 +4,9 @@ from core.helper import encrypter
from .segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
pass
class FileVariable(FileSegment, Variable):
pass
class ObjectVariable(ObjectSegment, Variable):
pass
@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
pass
class ArrayFileVariable(ArrayFileSegment, Variable):
pass
class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET

View File

@ -2,7 +2,7 @@ from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.enums import SystemVariable
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow

View File

@ -4,13 +4,14 @@ from typing import Any, Optional
from pydantic import BaseModel
from models.workflow import WorkflowNodeExecutionStatus
from models import WorkflowNodeExecutionStatus
class NodeType(Enum):
"""
Node Types.
"""
START = 'start'
END = 'end'
ANSWER = 'answer'
@ -44,33 +45,11 @@ class NodeType(Enum):
raise ValueError(f'invalid node type value {value}')
class SystemVariable(Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
@classmethod
def value_of(cls, value: str) -> 'SystemVariable':
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
class NodeRunMetadataKey(Enum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = 'total_tokens'
TOTAL_PRICE = 'total_price'
CURRENCY = 'currency'
@ -83,6 +62,7 @@ class NodeRunResult(BaseModel):
"""
Node Run Result.
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[Mapping[str, Any]] = None # node inputs

View File

@ -6,7 +6,7 @@ from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.enums import SystemVariable
VariableValue = Union[str, int, float, dict, list, FileVar]

View File

@ -0,0 +1,25 @@
from enum import Enum
class SystemVariable(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'
@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')

View File

@ -23,8 +23,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,

View File

@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence
from os import path
from typing import Any, cast
from core.app.segments import parser
from core.app.segments import ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
from models import WorkflowNodeExecutionStatus
class ToolNode(BaseNode):
@ -140,9 +141,9 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
# FIXME: ensure this is a ArrayVariable contains FileVariable.
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
return [file_var.value for file_var in variable.value] if variable else []
assert isinstance(variable, ArrayAnyVariable)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
"""

View File

@ -3,6 +3,7 @@ import time
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
import contexts
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
@ -97,7 +98,7 @@ class WorkflowEngineManager:
invoke_from: InvokeFrom,
callbacks: Sequence[WorkflowCallback],
call_depth: int = 0,
variable_pool: VariablePool,
variable_pool: VariablePool | None = None,
) -> None:
"""
:param workflow: Workflow instance
@ -128,6 +129,8 @@ class WorkflowEngineManager:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
# init workflow run state
if not variable_pool:
variable_pool = contexts.workflow_variable_pool.get()
workflow_run_state = WorkflowRunState(
workflow=workflow,
start_at=time.perf_counter(),

View File

@ -0,0 +1,33 @@
"""add conversations.dialogue_count
Revision ID: 8782057ff0dc
Revises: 63a83fcf12ba
Create Date: 2024-08-14 13:54:25.161324
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '8782057ff0dc'
down_revision = '63a83fcf12ba'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.drop_column('dialogue_count')
# ### end Alembic commands ###

View File

@ -1,10 +1,10 @@
from enum import Enum
from .model import AppMode
from .model import App, AppMode, Message
from .types import StringUUID
from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message']
class CreatedByRole(Enum):

View File

@ -7,6 +7,7 @@ from typing import Optional
from flask import request
from flask_login import UserMixin
from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column
from configs import dify_config
from core.file.tool_file_parser import ToolFileParser
@ -512,12 +513,12 @@ class Conversation(db.Model):
from_account_id = db.Column(StringUUID)
read_at = db.Column(db.DateTime)
read_account_id = db.Column(StringUUID)
dialogue_count: Mapped[int] = mapped_column(default=0)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select',
passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))

View File

@ -10,8 +10,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db

View File

@ -12,8 +12,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from extensions.ext_database import db

View File

@ -3,12 +3,9 @@ from uuid import uuid4
import pytest
from core.app.segments import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileSegment,
FileVariable,
FloatVariable,
IntegerVariable,
ObjectSegment,
@ -149,83 +146,6 @@ def test_array_object_variable():
assert isinstance(variable.value[1]['key2'], int)
def test_file_variable():
mapping = {
'id': str(uuid4()),
'value_type': 'file',
'name': 'test_file',
'description': 'Description of the variable.',
'value': {
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, FileVariable)
def test_array_file_variable():
mapping = {
'id': str(uuid4()),
'value_type': 'array[file]',
'name': 'test_array_file',
'description': 'Description of the variable.',
'value': [
{
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
{
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayFileVariable)
assert isinstance(variable.value[0], FileSegment)
assert isinstance(variable.value[1], FileSegment)
def test_variable_cannot_large_than_5_kb():
with pytest.raises(VariableError):
factory.build_variable_from_mapping(

View File

@ -1,7 +1,7 @@
from core.app.segments import SecretVariable, StringSegment, parser
from core.helper import encrypter
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
def test_segment_group_to_text():

View File

@ -1,8 +1,8 @@
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import UserFrom
from extensions.ext_database import db

View File

@ -1,8 +1,8 @@
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from extensions.ext_database import db

View File

@ -3,8 +3,8 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import ArrayStringVariable, StringVariable
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode