feat(api/workflow): Add Conversation.dialogue_count
(#7275)
This commit is contained in:
parent
8f5d8397f9
commit
32dc963556
@ -1,3 +1,7 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
tenant_id: ContextVar[str] = ContextVar('tenant_id')
|
||||
|
||||
workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool')
|
||||
|
@ -8,6 +8,8 @@ from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import contexts
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
@ -18,15 +20,20 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
from models.workflow import ConversationVariable, Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -209,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# update conversation features
|
||||
conversation.override_model_configs = workflow.features
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
# db.session.refresh(conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
@ -221,15 +228,69 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
message_id=message.id
|
||||
)
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
# Create conversation variables if they don't exist.
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
# Convert database entities to variables.
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
session.commit()
|
||||
|
||||
# Increment dialogue count.
|
||||
conversation.dialogue_count += 1
|
||||
|
||||
conversation_id = conversation.id
|
||||
conversation_dialogue_count = conversation.dialogue_count
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariable.QUERY: query,
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION_ID: conversation_id,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
contexts.workflow_variable_pool.set(variable_pool)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
'user': user,
|
||||
'context': contextvars.copy_context()
|
||||
'context': contextvars.copy_context(),
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
@ -242,7 +303,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(
|
||||
@ -253,9 +314,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
user: Account,
|
||||
context: contextvars.Context) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@ -282,8 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user_id=application_generate_entity.user_id
|
||||
)
|
||||
else:
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
# get message
|
||||
message = self._get_message(message_id)
|
||||
|
||||
# chatbot app
|
||||
@ -291,7 +349,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
@ -314,14 +371,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
def _handle_advanced_chat_response(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False) \
|
||||
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
stream: bool = False,
|
||||
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
@ -341,7 +401,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -4,9 +4,6 @@ import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -19,13 +16,10 @@ from core.app.entities.app_invoke_entities import (
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||
from core.moderation.base import ModerationException
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import ConversationVariable, Workflow
|
||||
from models import App, Message, Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -39,7 +33,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""
|
||||
@ -63,15 +56,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
@ -103,38 +87,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
session.commit()
|
||||
# Convert database entities to variables
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariable.QUERY: query,
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
@ -146,7 +98,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
def single_iteration_run(
|
||||
@ -155,7 +106,7 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
app_record = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError('App not found')
|
||||
|
||||
|
@ -4,6 +4,7 @@ import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
import contexts
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -47,7 +48,8 @@ from core.file.file_obj import FileVar
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType, SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||
from events.message_event import message_was_created
|
||||
@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
# Deprecated
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
@ -81,7 +84,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||
@ -103,11 +106,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._workflow = workflow
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
# Deprecated
|
||||
self._workflow_system_variables = {
|
||||
SystemVariable.QUERY: message.query,
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
SystemVariable.USER_ID: user_id,
|
||||
}
|
||||
|
||||
self._task_state = AdvancedChatTaskState(
|
||||
@ -613,7 +617,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
if route_chunk_node_id == 'sys':
|
||||
# system variable
|
||||
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
|
||||
value = contexts.workflow_variable_pool.get().get(value_selector)
|
||||
if value:
|
||||
value = value.text
|
||||
elif route_chunk_node_id in self._iteration_nested_relations:
|
||||
# it's a iteration variable
|
||||
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
|
||||
|
@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
|
||||
return introduction
|
||||
|
||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
||||
def _get_conversation(self, conversation_id: str):
|
||||
"""
|
||||
Get conversation by conversation id
|
||||
:param conversation_id: conversation id
|
||||
@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_message(self, message_id: str) -> Message:
|
||||
|
@ -11,8 +11,8 @@ from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
|
@ -42,7 +42,8 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType, SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
@ -530,4 +531,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
||||
] for iteration_id in iteration_ids
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ from .segment_group import SegmentGroup
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
@ -13,11 +12,9 @@ from .segments import (
|
||||
from .types import SegmentType
|
||||
from .variables import (
|
||||
ArrayAnyVariable,
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
NoneVariable,
|
||||
@ -32,7 +29,6 @@ __all__ = [
|
||||
'FloatVariable',
|
||||
'ObjectVariable',
|
||||
'SecretVariable',
|
||||
'FileVariable',
|
||||
'StringVariable',
|
||||
'ArrayAnyVariable',
|
||||
'Variable',
|
||||
@ -45,11 +41,9 @@ __all__ = [
|
||||
'FloatSegment',
|
||||
'ObjectSegment',
|
||||
'ArrayAnySegment',
|
||||
'FileSegment',
|
||||
'StringSegment',
|
||||
'ArrayStringVariable',
|
||||
'ArrayNumberVariable',
|
||||
'ArrayObjectVariable',
|
||||
'ArrayFileVariable',
|
||||
'ArraySegment',
|
||||
]
|
||||
|
@ -2,12 +2,10 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
from .exc import VariableError
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
@ -17,11 +15,9 @@ from .segments import (
|
||||
)
|
||||
from .types import SegmentType
|
||||
from .variables import (
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectVariable,
|
||||
@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
result = FloatVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||
raise VariableError(f'invalid number value {value}')
|
||||
case SegmentType.FILE:
|
||||
result = FileVariable.model_validate(mapping)
|
||||
case SegmentType.OBJECT if isinstance(value, dict):
|
||||
result = ObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
||||
@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
result = ArrayNumberVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||
result = ArrayObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_FILE if isinstance(value, list):
|
||||
mapping = dict(mapping)
|
||||
mapping['value'] = [{'value': v} for v in value]
|
||||
result = ArrayFileVariable.model_validate(mapping)
|
||||
case _:
|
||||
raise VariableError(f'not supported value type {value_type}')
|
||||
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||
@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
|
||||
return ObjectSegment(value=value)
|
||||
if isinstance(value, list):
|
||||
return ArrayAnySegment(value=value)
|
||||
if isinstance(value, FileVar):
|
||||
return FileSegment(value=value)
|
||||
raise ValueError(f'not supported value {value}')
|
||||
|
@ -5,8 +5,6 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
from .types import SegmentType
|
||||
|
||||
|
||||
@ -78,14 +76,7 @@ class IntegerSegment(Segment):
|
||||
value: int
|
||||
|
||||
|
||||
class FileSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.FILE
|
||||
# TODO: embed FileVar in this model.
|
||||
value: FileVar
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return self.value.to_markdown()
|
||||
|
||||
|
||||
class ObjectSegment(Segment):
|
||||
@ -130,7 +121,3 @@ class ArrayObjectSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||
value: Sequence[Mapping[str, Any]]
|
||||
|
||||
|
||||
class ArrayFileSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_FILE
|
||||
value: Sequence[FileSegment]
|
||||
|
@ -10,8 +10,6 @@ class SegmentType(str, Enum):
|
||||
ARRAY_STRING = 'array[string]'
|
||||
ARRAY_NUMBER = 'array[number]'
|
||||
ARRAY_OBJECT = 'array[object]'
|
||||
ARRAY_FILE = 'array[file]'
|
||||
OBJECT = 'object'
|
||||
FILE = 'file'
|
||||
|
||||
GROUP = 'group'
|
||||
|
@ -4,11 +4,9 @@ from core.helper import encrypter
|
||||
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class FileVariable(FileSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectVariable(ObjectSegment, Variable):
|
||||
pass
|
||||
|
||||
@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayFileVariable(ArrayFileSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class SecretVariable(StringVariable):
|
||||
value_type: SegmentType = SegmentType.SECRET
|
||||
|
@ -2,7 +2,7 @@ from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.enums import SystemVariable
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
|
@ -4,13 +4,14 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
"""
|
||||
Node Types.
|
||||
"""
|
||||
|
||||
START = 'start'
|
||||
END = 'end'
|
||||
ANSWER = 'answer'
|
||||
@ -44,33 +45,11 @@ class NodeType(Enum):
|
||||
raise ValueError(f'invalid node type value {value}')
|
||||
|
||||
|
||||
class SystemVariable(Enum):
|
||||
"""
|
||||
System Variables.
|
||||
"""
|
||||
QUERY = 'query'
|
||||
FILES = 'files'
|
||||
CONVERSATION_ID = 'conversation_id'
|
||||
USER_ID = 'user_id'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'SystemVariable':
|
||||
"""
|
||||
Get value of given system variable.
|
||||
|
||||
:param value: system variable value
|
||||
:return: system variable
|
||||
"""
|
||||
for system_variable in cls:
|
||||
if system_variable.value == value:
|
||||
return system_variable
|
||||
raise ValueError(f'invalid system variable value {value}')
|
||||
|
||||
|
||||
class NodeRunMetadataKey(Enum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = 'total_tokens'
|
||||
TOTAL_PRICE = 'total_price'
|
||||
CURRENCY = 'currency'
|
||||
@ -83,6 +62,7 @@ class NodeRunResult(BaseModel):
|
||||
"""
|
||||
Node Run Result.
|
||||
"""
|
||||
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
|
@ -6,7 +6,7 @@ from typing_extensions import deprecated
|
||||
|
||||
from core.app.segments import Segment, Variable, factory
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.enums import SystemVariable
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
|
||||
|
25
api/core/workflow/enums.py
Normal file
25
api/core/workflow/enums.py
Normal file
@ -0,0 +1,25 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SystemVariable(str, Enum):
|
||||
"""
|
||||
System Variables.
|
||||
"""
|
||||
QUERY = 'query'
|
||||
FILES = 'files'
|
||||
CONVERSATION_ID = 'conversation_id'
|
||||
USER_ID = 'user_id'
|
||||
DIALOGUE_COUNT = 'dialogue_count'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
Get value of given system variable.
|
||||
|
||||
:param value: system variable value
|
||||
:return: system variable
|
||||
"""
|
||||
for system_variable in cls:
|
||||
if system_variable.value == value:
|
||||
return system_variable
|
||||
raise ValueError(f'invalid system variable value {value}')
|
@ -23,8 +23,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
|
@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence
|
||||
from os import path
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.segments import parser
|
||||
from core.app.segments import ArrayAnyVariable, parser
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
@ -140,9 +141,9 @@ class ToolNode(BaseNode):
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
# FIXME: ensure this is a ArrayVariable contains FileVariable.
|
||||
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
|
||||
return [file_var.value for file_var in variable.value] if variable else []
|
||||
assert isinstance(variable, ArrayAnyVariable)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
|
||||
"""
|
||||
|
@ -3,6 +3,7 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -97,7 +98,7 @@ class WorkflowEngineManager:
|
||||
invoke_from: InvokeFrom,
|
||||
callbacks: Sequence[WorkflowCallback],
|
||||
call_depth: int = 0,
|
||||
variable_pool: VariablePool,
|
||||
variable_pool: VariablePool | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param workflow: Workflow instance
|
||||
@ -128,6 +129,8 @@ class WorkflowEngineManager:
|
||||
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
|
||||
|
||||
# init workflow run state
|
||||
if not variable_pool:
|
||||
variable_pool = contexts.workflow_variable_pool.get()
|
||||
workflow_run_state = WorkflowRunState(
|
||||
workflow=workflow,
|
||||
start_at=time.perf_counter(),
|
||||
|
@ -0,0 +1,33 @@
|
||||
"""add conversations.dialogue_count
|
||||
|
||||
Revision ID: 8782057ff0dc
|
||||
Revises: 63a83fcf12ba
|
||||
Create Date: 2024-08-14 13:54:25.161324
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '8782057ff0dc'
|
||||
down_revision = '63a83fcf12ba'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('conversations', schema=None) as batch_op:
|
||||
batch_op.drop_column('dialogue_count')
|
||||
|
||||
# ### end Alembic commands ###
|
@ -1,10 +1,10 @@
|
||||
from enum import Enum
|
||||
|
||||
from .model import AppMode
|
||||
from .model import App, AppMode, Message
|
||||
from .types import StringUUID
|
||||
from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
|
||||
from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus
|
||||
|
||||
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
|
||||
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message']
|
||||
|
||||
|
||||
class CreatedByRole(Enum):
|
||||
|
@ -7,6 +7,7 @@ from typing import Optional
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import Float, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.tool_file_parser import ToolFileParser
|
||||
@ -512,12 +513,12 @@ class Conversation(db.Model):
|
||||
from_account_id = db.Column(StringUUID)
|
||||
read_at = db.Column(db.DateTime)
|
||||
read_account_id = db.Column(StringUUID)
|
||||
dialogue_count: Mapped[int] = mapped_column(default=0)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
|
||||
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select',
|
||||
passive_deletes="all")
|
||||
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
|
||||
|
||||
is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
|
||||
|
||||
|
@ -10,8 +10,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers import ModelProviderFactory
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from extensions.ext_database import db
|
||||
|
@ -12,8 +12,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from extensions.ext_database import db
|
||||
|
@ -3,12 +3,9 @@ from uuid import uuid4
|
||||
import pytest
|
||||
|
||||
from core.app.segments import (
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileSegment,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectSegment,
|
||||
@ -149,83 +146,6 @@ def test_array_object_variable():
|
||||
assert isinstance(variable.value[1]['key2'], int)
|
||||
|
||||
|
||||
def test_file_variable():
|
||||
mapping = {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'file',
|
||||
'name': 'test_file',
|
||||
'description': 'Description of the variable.',
|
||||
'value': {
|
||||
'id': str(uuid4()),
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, FileVariable)
|
||||
|
||||
|
||||
def test_array_file_variable():
|
||||
mapping = {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'array[file]',
|
||||
'name': 'test_array_file',
|
||||
'description': 'Description of the variable.',
|
||||
'value': [
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'tenant_id': 'tenant_id',
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'url': 'url',
|
||||
'related_id': 'related_id',
|
||||
'extra_config': {
|
||||
'image_config': {
|
||||
'width': 100,
|
||||
'height': 100,
|
||||
},
|
||||
},
|
||||
'filename': 'filename',
|
||||
'extension': 'extension',
|
||||
'mime_type': 'mime_type',
|
||||
},
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayFileVariable)
|
||||
assert isinstance(variable.value[0], FileSegment)
|
||||
assert isinstance(variable.value[1], FileSegment)
|
||||
|
||||
|
||||
def test_variable_cannot_large_than_5_kb():
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(
|
||||
|
@ -1,7 +1,7 @@
|
||||
from core.app.segments import SecretVariable, StringSegment, parser
|
||||
from core.helper import encrypter
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
|
||||
|
||||
def test_segment_group_to_text():
|
||||
|
@ -1,8 +1,8 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from extensions.ext_database import db
|
||||
|
@ -1,8 +1,8 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from extensions.ext_database import db
|
||||
|
@ -3,8 +3,8 @@ from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.segments import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user