diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 9b4a918ead..dabd4a0bd7 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -241,6 +241,7 @@ class WorkflowBasedAppRunner(AppRunner): predecessor_node_id=event.predecessor_node_id, in_iteration_id=event.in_iteration_id, parallel_mode_run_id=event.parallel_mode_run_id, + agent_strategy=event.agent_strategy, ) ) elif isinstance(event, NodeRunSucceededEvent): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index c0d3d9d88a..17b1797c4f 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -6,7 +6,7 @@ from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNodeData @@ -281,6 +281,7 @@ class QueueNodeStartedEvent(AppQueueEvent): start_at: datetime parallel_mode_run_id: Optional[str] = None """iteratoin run in parallel mode run id""" + agent_strategy: Optional[AgentNodeStrategyInit] = None class QueueNodeSucceededEvent(AppQueueEvent): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7182c36fe2..3c055232f8 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, ConfigDict from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.node_entities import AgentNodeStrategyInit from models.workflow import WorkflowNodeExecutionStatus @@ -248,6 +249,7 @@ class NodeStartStreamResponse(StreamResponse): parent_parallel_start_node_id: Optional[str] = None iteration_id: Optional[str] = None parallel_run_id: Optional[str] = None + agent_strategy: Optional[AgentNodeStrategyInit] = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 309d5f1422..8b4e700563 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -541,6 +541,7 @@ class WorkflowCycleManage: parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, parallel_run_id=event.parallel_mode_run_id, + agent_strategy=event.agent_strategy, ), ) diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index ed26889614..27c0e6702a 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -49,3 +49,8 @@ class NodeRunResult(BaseModel): # single step node run retry retry_index: int = 0 + + +class AgentNodeStrategyInit(BaseModel): + name: str + icon: str | None = None diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index f1c486d761..439f768c5e 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -4,6 +4,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field +from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNodeData @@ -66,8 +67,10 @@ class BaseNodeEvent(GraphEngineEvent): class NodeRunStartedEvent(BaseNodeEvent): predecessor_node_id: Optional[str] = None - parallel_mode_run_id: Optional[str] = None """predecessor node id""" + parallel_mode_run_id: Optional[str] = None + """iteration node parallel mode run id""" + agent_strategy: Optional[AgentNodeStrategyInit] = None class NodeRunStreamChunkEvent(BaseNodeEvent): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 463525b9f4..a05cc30cab 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -14,7 +14,7 @@ from flask import Flask, current_app from configs import dify_config from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( @@ -40,6 +40,8 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes import NodeType +from core.workflow.nodes.agent.agent_node import AgentNode +from core.workflow.nodes.agent.entities import AgentNodeData from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.base import BaseNode @@ -606,6 +608,14 @@ class GraphEngine: Run node """ # trigger node run start event + agent_strategy = ( + AgentNodeStrategyInit( + name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name, + icon=cast(AgentNode, node_instance).agent_strategy_icon, + ) + if node_instance.node_type == NodeType.AGENT + else None + ) yield NodeRunStartedEvent( id=node_instance.id, node_id=node_instance.node_id, @@ -617,6 +627,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + agent_strategy=agent_strategy, ) db.session.close() diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 8c29ce5176..db84624961 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -90,18 +90,11 @@ class AgentNode(ToolNode): try: # convert tool messages - manager = PluginInstallationManager() - plugins = manager.list_plugins(self.tenant_id) - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" - == cast(AgentNodeData, self.node_data).agent_strategy_provider_name - ) + yield from self._transform_message( message_stream, { - "icon": current_plugin.declaration.icon, + "icon": self.agent_strategy_icon, "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, }, parameters_for_log, @@ -229,16 +222,33 @@ class AgentNode(ToolNode): result: dict[str, Any] = {} for parameter_name in node_data.agent_parameters: input = node_data.agent_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() + if input.type in ["mixed", "constant"]: + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() for selector in selectors: result[selector.variable] = selector.value_selector elif input.type == "variable": result[parameter_name] = input.value - elif input.type == "constant": - pass result = {node_id + "." + key: value for key, value in result.items()} return result + + @property + def agent_strategy_icon(self) -> str | None: + """ + Get agent strategy icon + :return: + """ + manager = PluginInstallationManager() + plugins = manager.list_plugins(self.tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" + == cast(AgentNodeData, self.node_data).agent_strategy_provider_name + ) + icon = current_plugin.declaration.icon + except StopIteration: + icon = None + return icon