feat: add agent strategy on node start (#12667)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
This commit is contained in:
parent
f0a3c14adb
commit
98b139c680
@ -241,6 +241,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||||||
predecessor_node_id=event.predecessor_node_id,
|
predecessor_node_id=event.predecessor_node_id,
|
||||||
in_iteration_id=event.in_iteration_id,
|
in_iteration_id=event.in_iteration_id,
|
||||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||||
|
agent_strategy=event.agent_strategy,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, NodeRunSucceededEvent):
|
elif isinstance(event, NodeRunSucceededEvent):
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.nodes.base import BaseNodeData
|
from core.workflow.nodes.base import BaseNodeData
|
||||||
@ -281,6 +281,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
|||||||
start_at: datetime
|
start_at: datetime
|
||||||
parallel_mode_run_id: Optional[str] = None
|
parallel_mode_run_id: Optional[str] = None
|
||||||
"""iteratoin run in parallel mode run id"""
|
"""iteratoin run in parallel mode run id"""
|
||||||
|
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||||
|
|
||||||
|
|
||||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||||
|
@ -6,6 +6,7 @@ from pydantic import BaseModel, ConfigDict
|
|||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
@ -248,6 +249,7 @@ class NodeStartStreamResponse(StreamResponse):
|
|||||||
parent_parallel_start_node_id: Optional[str] = None
|
parent_parallel_start_node_id: Optional[str] = None
|
||||||
iteration_id: Optional[str] = None
|
iteration_id: Optional[str] = None
|
||||||
parallel_run_id: Optional[str] = None
|
parallel_run_id: Optional[str] = None
|
||||||
|
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||||
workflow_run_id: str
|
workflow_run_id: str
|
||||||
|
@ -541,6 +541,7 @@ class WorkflowCycleManage:
|
|||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||||
iteration_id=event.in_iteration_id,
|
iteration_id=event.in_iteration_id,
|
||||||
parallel_run_id=event.parallel_mode_run_id,
|
parallel_run_id=event.parallel_mode_run_id,
|
||||||
|
agent_strategy=event.agent_strategy,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -49,3 +49,8 @@ class NodeRunResult(BaseModel):
|
|||||||
|
|
||||||
# single step node run retry
|
# single step node run retry
|
||||||
retry_index: int = 0
|
retry_index: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AgentNodeStrategyInit(BaseModel):
|
||||||
|
name: str
|
||||||
|
icon: str | None = None
|
||||||
|
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.nodes.base import BaseNodeData
|
from core.workflow.nodes.base import BaseNodeData
|
||||||
@ -66,8 +67,10 @@ class BaseNodeEvent(GraphEngineEvent):
|
|||||||
|
|
||||||
class NodeRunStartedEvent(BaseNodeEvent):
|
class NodeRunStartedEvent(BaseNodeEvent):
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: Optional[str] = None
|
||||||
parallel_mode_run_id: Optional[str] = None
|
|
||||||
"""predecessor node id"""
|
"""predecessor node id"""
|
||||||
|
parallel_mode_run_id: Optional[str] = None
|
||||||
|
"""iteration node parallel mode run id"""
|
||||||
|
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||||
|
|
||||||
|
|
||||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||||
|
@ -14,7 +14,7 @@ from flask import Flask, current_app
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey, NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
@ -40,6 +40,8 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
|||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
|
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||||
|
from core.workflow.nodes.agent.entities import AgentNodeData
|
||||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||||
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||||
from core.workflow.nodes.base import BaseNode
|
from core.workflow.nodes.base import BaseNode
|
||||||
@ -606,6 +608,14 @@ class GraphEngine:
|
|||||||
Run node
|
Run node
|
||||||
"""
|
"""
|
||||||
# trigger node run start event
|
# trigger node run start event
|
||||||
|
agent_strategy = (
|
||||||
|
AgentNodeStrategyInit(
|
||||||
|
name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name,
|
||||||
|
icon=cast(AgentNode, node_instance).agent_strategy_icon,
|
||||||
|
)
|
||||||
|
if node_instance.node_type == NodeType.AGENT
|
||||||
|
else None
|
||||||
|
)
|
||||||
yield NodeRunStartedEvent(
|
yield NodeRunStartedEvent(
|
||||||
id=node_instance.id,
|
id=node_instance.id,
|
||||||
node_id=node_instance.node_id,
|
node_id=node_instance.node_id,
|
||||||
@ -617,6 +627,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
agent_strategy=agent_strategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
@ -90,18 +90,11 @@ class AgentNode(ToolNode):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# convert tool messages
|
# convert tool messages
|
||||||
manager = PluginInstallationManager()
|
|
||||||
plugins = manager.list_plugins(self.tenant_id)
|
|
||||||
current_plugin = next(
|
|
||||||
plugin
|
|
||||||
for plugin in plugins
|
|
||||||
if f"{plugin.plugin_id}/{plugin.name}"
|
|
||||||
== cast(AgentNodeData, self.node_data).agent_strategy_provider_name
|
|
||||||
)
|
|
||||||
yield from self._transform_message(
|
yield from self._transform_message(
|
||||||
message_stream,
|
message_stream,
|
||||||
{
|
{
|
||||||
"icon": current_plugin.declaration.icon,
|
"icon": self.agent_strategy_icon,
|
||||||
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
|
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
|
||||||
},
|
},
|
||||||
parameters_for_log,
|
parameters_for_log,
|
||||||
@ -229,16 +222,33 @@ class AgentNode(ToolNode):
|
|||||||
result: dict[str, Any] = {}
|
result: dict[str, Any] = {}
|
||||||
for parameter_name in node_data.agent_parameters:
|
for parameter_name in node_data.agent_parameters:
|
||||||
input = node_data.agent_parameters[parameter_name]
|
input = node_data.agent_parameters[parameter_name]
|
||||||
if input.type == "mixed":
|
if input.type in ["mixed", "constant"]:
|
||||||
assert isinstance(input.value, str)
|
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
|
||||||
for selector in selectors:
|
for selector in selectors:
|
||||||
result[selector.variable] = selector.value_selector
|
result[selector.variable] = selector.value_selector
|
||||||
elif input.type == "variable":
|
elif input.type == "variable":
|
||||||
result[parameter_name] = input.value
|
result[parameter_name] = input.value
|
||||||
elif input.type == "constant":
|
|
||||||
pass
|
|
||||||
|
|
||||||
result = {node_id + "." + key: value for key, value in result.items()}
|
result = {node_id + "." + key: value for key, value in result.items()}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def agent_strategy_icon(self) -> str | None:
|
||||||
|
"""
|
||||||
|
Get agent strategy icon
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
manager = PluginInstallationManager()
|
||||||
|
plugins = manager.list_plugins(self.tenant_id)
|
||||||
|
try:
|
||||||
|
current_plugin = next(
|
||||||
|
plugin
|
||||||
|
for plugin in plugins
|
||||||
|
if f"{plugin.plugin_id}/{plugin.name}"
|
||||||
|
== cast(AgentNodeData, self.node_data).agent_strategy_provider_name
|
||||||
|
)
|
||||||
|
icon = current_plugin.declaration.icon
|
||||||
|
except StopIteration:
|
||||||
|
icon = None
|
||||||
|
return icon
|
||||||
|
Loading…
Reference in New Issue
Block a user