From 3c628d0c267bf83f9262fdf9d78256cecbe9bcde Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 12 Dec 2024 18:27:31 +0800 Subject: [PATCH] refactor: rename agent to agent strategy --- api/core/agent/plugin_entities.py | 18 +++++++++--------- api/core/agent/strategy/base.py | 4 ++-- api/core/agent/strategy/plugin.py | 4 ++-- api/core/plugin/backwards_invocation/model.py | 2 +- api/core/plugin/entities/plugin.py | 5 +++++ api/core/plugin/entities/request.py | 2 +- api/core/plugin/manager/agent.py | 14 +++++++------- api/core/workflow/nodes/agent/agent_node.py | 4 ++-- api/core/workflow/nodes/tool/tool_node.py | 2 +- api/factories/agent_factory.py | 2 +- api/services/agent_service.py | 4 ++-- 11 files changed, 33 insertions(+), 28 deletions(-) diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index 820bd8af8e..d24c5e8336 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -6,26 +6,26 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderIdentity -class AgentProviderIdentity(ToolProviderIdentity): +class AgentStrategyProviderIdentity(ToolProviderIdentity): pass -class AgentParameter(ToolParameter): +class AgentStrategyParameter(ToolParameter): pass -class AgentProviderEntity(BaseModel): - identity: AgentProviderIdentity +class AgentStrategyProviderEntity(BaseModel): + identity: AgentStrategyProviderIdentity plugin_id: Optional[str] = Field(None, description="The id of the plugin") -class AgentIdentity(ToolIdentity): +class AgentStrategyIdentity(ToolIdentity): pass class AgentStrategyEntity(BaseModel): - identity: AgentIdentity - parameters: list[AgentParameter] = Field(default_factory=list) + identity: AgentStrategyIdentity + parameters: list[AgentStrategyParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The description of the agent strategy") output_schema: Optional[dict] = None @@ -34,9 +34,9 @@ class AgentStrategyEntity(BaseModel): @field_validator("parameters", mode="before") @classmethod - def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentParameter]: + def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]: return v or [] -class AgentProviderEntityWithPlugin(AgentProviderEntity): +class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity): strategies: list[AgentStrategyEntity] = Field(default_factory=list) diff --git a/api/core/agent/strategy/base.py b/api/core/agent/strategy/base.py index b53caa5684..a9b9326fc5 100644 --- a/api/core/agent/strategy/base.py +++ b/api/core/agent/strategy/base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Generator, Optional, Sequence from core.agent.entities import AgentInvokeMessage -from core.agent.plugin_entities import AgentParameter +from core.agent.plugin_entities import AgentStrategyParameter class BaseAgentStrategy(ABC): @@ -23,7 +23,7 @@ class BaseAgentStrategy(ABC): """ yield from self._invoke(params, user_id, conversation_id, app_id, message_id) - def get_parameters(self) -> Sequence[AgentParameter]: + def get_parameters(self) -> Sequence[AgentStrategyParameter]: """ Get the parameters for the agent strategy. """ diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index d70f0c34aa..785879ad01 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -1,7 +1,7 @@ from typing import Any, Generator, Optional, Sequence from core.agent.entities import AgentInvokeMessage -from core.agent.plugin_entities import AgentParameter, AgentStrategyEntity +from core.agent.plugin_entities import AgentStrategyParameter, AgentStrategyEntity from core.agent.strategy.base import BaseAgentStrategy from core.plugin.manager.agent import PluginAgentManager from core.tools.plugin_tool.tool import PluginTool @@ -21,7 +21,7 @@ class PluginAgentStrategy(BaseAgentStrategy): self.plugin_unique_identifier = plugin_unique_identifier self.declaration = declaration - def get_parameters(self) -> Sequence[AgentParameter]: + def get_parameters(self) -> Sequence[AgentStrategyParameter]: return self.declaration.parameters def _invoke( diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 8894f8eef5..a1c97c447e 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -43,7 +43,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): # invoke model response = model_instance.invoke_llm( prompt_messages=payload.prompt_messages, - model_parameters=payload.model_parameters, + model_parameters=payload.completion_params, tools=payload.tools, stop=payload.stop, stream=payload.stream or True, diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index f9c7870e77..642a9304c5 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -6,6 +6,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field, model_validator +from core.agent.plugin_entities import AgentStrategyProviderEntity from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity from core.plugin.entities.endpoint import EndpointProviderDeclaration @@ -59,6 +60,7 @@ class PluginCategory(enum.StrEnum): Tool = "tool" Model = "model" Extension = "extension" + AgentStrategy = "agent_strategy" class PluginDeclaration(BaseModel): @@ -82,6 +84,7 @@ class PluginDeclaration(BaseModel): tool: Optional[ToolProviderEntity] = None model: Optional[ProviderEntity] = None endpoint: Optional[EndpointProviderDeclaration] = None + agent_strategy: Optional[AgentStrategyProviderEntity] = None @model_validator(mode="before") @classmethod @@ -91,6 +94,8 @@ class PluginDeclaration(BaseModel): values["category"] = PluginCategory.Tool elif values.get("model"): values["category"] = PluginCategory.Model + elif values.get("agent_strategy"): + values["category"] = PluginCategory.AgentStrategy else: values["category"] = PluginCategory.Extension return values diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 9a0e569d4d..837dcf59c4 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -53,7 +53,7 @@ class RequestInvokeLLM(BaseRequestInvokeModel): model_type: ModelType = ModelType.LLM mode: str - model_parameters: dict[str, Any] = Field(default_factory=dict) + completion_params: dict[str, Any] = Field(default_factory=dict) prompt_messages: list[PromptMessage] = Field(default_factory=list) tools: Optional[list[PromptMessageTool]] = Field(default_factory=list) stop: Optional[list[str]] = Field(default_factory=list) diff --git a/api/core/plugin/manager/agent.py b/api/core/plugin/manager/agent.py index fd83d4411c..0a2fb0e2d6 100644 --- a/api/core/plugin/manager/agent.py +++ b/api/core/plugin/manager/agent.py @@ -10,7 +10,7 @@ from core.plugin.manager.base import BasePluginManager class PluginAgentManager(BasePluginManager): - def fetch_agent_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]: + def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]: """ Fetch agent providers for the given tenant. """ @@ -26,7 +26,7 @@ class PluginAgentManager(BasePluginManager): response = self._request_with_plugin_daemon_response( "GET", - f"plugin/{tenant_id}/management/agents", + f"plugin/{tenant_id}/management/agent_strategies", list[PluginAgentProviderEntity], params={"page": 1, "page_size": 256}, transformer=transformer, @@ -41,7 +41,7 @@ class PluginAgentManager(BasePluginManager): return response - def fetch_agent_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity: + def fetch_agent_strategy_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity: """ Fetch tool provider for the given tenant and plugin. """ @@ -55,7 +55,7 @@ class PluginAgentManager(BasePluginManager): response = self._request_with_plugin_daemon_response( "GET", - f"plugin/{tenant_id}/management/agent", + f"plugin/{tenant_id}/management/agent_strategy", PluginAgentProviderEntity, params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id}, transformer=transformer, @@ -96,9 +96,9 @@ class PluginAgentManager(BasePluginManager): "app_id": app_id, "message_id": message_id, "data": { - "provider": agent_provider_id.provider_name, - "strategy": agent_strategy, - "agent_params": agent_params, + "agent_strategy_provider": agent_provider_id.provider_name, + "agent_strategy": agent_strategy, + "agent_strategy_params": agent_params, }, }, headers={ diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index ea9d8c2906..597a5d905e 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,7 +1,7 @@ from collections.abc import Generator from typing import Any, Sequence, cast -from core.agent.plugin_entities import AgentParameter +from core.agent.plugin_entities import AgentStrategyParameter from core.plugin.manager.exc import PluginDaemonClientSideError from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -90,7 +90,7 @@ class AgentNode(ToolNode): def _generate_parameters( self, *, - agent_parameters: Sequence[AgentParameter], + agent_parameters: Sequence[AgentStrategyParameter], variable_pool: VariablePool, node_data: AgentNodeData, for_log: bool = False, diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 44eae0b4c6..d4c453a53f 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -246,7 +246,7 @@ class ToolNode(BaseNode[ToolNodeData]): ) elif message.type == ToolInvokeMessage.MessageType.TEXT: assert isinstance(message.message, ToolInvokeMessage.TextMessage) - text += message.message.text + "\n" + text += message.message.text yield RunStreamChunkEvent( chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] ) diff --git a/api/factories/agent_factory.py b/api/factories/agent_factory.py index 66266bdd9c..b22e5d6e5c 100644 --- a/api/factories/agent_factory.py +++ b/api/factories/agent_factory.py @@ -7,7 +7,7 @@ def get_plugin_agent_strategy( ) -> PluginAgentStrategy: # TODO: use contexts to cache the agent provider manager = PluginAgentManager() - agent_provider = manager.fetch_agent_provider(tenant_id, agent_strategy_provider_name) + agent_provider = manager.fetch_agent_strategy_provider(tenant_id, agent_strategy_provider_name) for agent_strategy in agent_provider.declaration.strategies: if agent_strategy.identity.name == agent_strategy_name: return PluginAgentStrategy(tenant_id, plugin_unique_identifier, agent_strategy) diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 78573bfaae..762760b168 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -160,7 +160,7 @@ class AgentService: List agent providers """ manager = PluginAgentManager() - return manager.fetch_agent_providers(tenant_id) + return manager.fetch_agent_strategy_providers(tenant_id) @classmethod def get_agent_provider(cls, user_id: str, tenant_id: str, provider_name: str): @@ -168,4 +168,4 @@ class AgentService: Get agent provider """ manager = PluginAgentManager() - return manager.fetch_agent_provider(tenant_id, provider_name) + return manager.fetch_agent_strategy_provider(tenant_id, provider_name)