refactor: rename agent to agent strategy

This commit is contained in:
Yeuoly 2024-12-12 18:27:31 +08:00
parent c2983ecbb7
commit 3c628d0c26
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
11 changed files with 33 additions and 28 deletions

View File

@ -6,26 +6,26 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderIdentity
class AgentProviderIdentity(ToolProviderIdentity):
class AgentStrategyProviderIdentity(ToolProviderIdentity):
pass
class AgentParameter(ToolParameter):
class AgentStrategyParameter(ToolParameter):
pass
class AgentProviderEntity(BaseModel):
identity: AgentProviderIdentity
class AgentStrategyProviderEntity(BaseModel):
identity: AgentStrategyProviderIdentity
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
class AgentIdentity(ToolIdentity):
class AgentStrategyIdentity(ToolIdentity):
pass
class AgentStrategyEntity(BaseModel):
identity: AgentIdentity
parameters: list[AgentParameter] = Field(default_factory=list)
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None
@ -34,9 +34,9 @@ class AgentStrategyEntity(BaseModel):
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentParameter]:
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
return v or []
class AgentProviderEntityWithPlugin(AgentProviderEntity):
class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
strategies: list[AgentStrategyEntity] = Field(default_factory=list)

View File

@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from typing import Any, Generator, Optional, Sequence
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentParameter
from core.agent.plugin_entities import AgentStrategyParameter
class BaseAgentStrategy(ABC):
@ -23,7 +23,7 @@ class BaseAgentStrategy(ABC):
"""
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
def get_parameters(self) -> Sequence[AgentParameter]:
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
"""
Get the parameters for the agent strategy.
"""

View File

@ -1,7 +1,7 @@
from typing import Any, Generator, Optional, Sequence
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentParameter, AgentStrategyEntity
from core.agent.plugin_entities import AgentStrategyParameter, AgentStrategyEntity
from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.manager.agent import PluginAgentManager
from core.tools.plugin_tool.tool import PluginTool
@ -21,7 +21,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
self.plugin_unique_identifier = plugin_unique_identifier
self.declaration = declaration
def get_parameters(self) -> Sequence[AgentParameter]:
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters
def _invoke(

View File

@ -43,7 +43,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
# invoke model
response = model_instance.invoke_llm(
prompt_messages=payload.prompt_messages,
model_parameters=payload.model_parameters,
model_parameters=payload.completion_params,
tools=payload.tools,
stop=payload.stop,
stream=payload.stream or True,

View File

@ -6,6 +6,7 @@ from typing import Any, Optional
from pydantic import BaseModel, Field, model_validator
from core.agent.plugin_entities import AgentStrategyProviderEntity
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.plugin.entities.base import BasePluginEntity
from core.plugin.entities.endpoint import EndpointProviderDeclaration
@ -59,6 +60,7 @@ class PluginCategory(enum.StrEnum):
Tool = "tool"
Model = "model"
Extension = "extension"
AgentStrategy = "agent_strategy"
class PluginDeclaration(BaseModel):
@ -82,6 +84,7 @@ class PluginDeclaration(BaseModel):
tool: Optional[ToolProviderEntity] = None
model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None
agent_strategy: Optional[AgentStrategyProviderEntity] = None
@model_validator(mode="before")
@classmethod
@ -91,6 +94,8 @@ class PluginDeclaration(BaseModel):
values["category"] = PluginCategory.Tool
elif values.get("model"):
values["category"] = PluginCategory.Model
elif values.get("agent_strategy"):
values["category"] = PluginCategory.AgentStrategy
else:
values["category"] = PluginCategory.Extension
return values

View File

@ -53,7 +53,7 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
model_type: ModelType = ModelType.LLM
mode: str
model_parameters: dict[str, Any] = Field(default_factory=dict)
completion_params: dict[str, Any] = Field(default_factory=dict)
prompt_messages: list[PromptMessage] = Field(default_factory=list)
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
stop: Optional[list[str]] = Field(default_factory=list)

View File

@ -10,7 +10,7 @@ from core.plugin.manager.base import BasePluginManager
class PluginAgentManager(BasePluginManager):
def fetch_agent_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
"""
Fetch agent providers for the given tenant.
"""
@ -26,7 +26,7 @@ class PluginAgentManager(BasePluginManager):
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/agents",
f"plugin/{tenant_id}/management/agent_strategies",
list[PluginAgentProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
@ -41,7 +41,7 @@ class PluginAgentManager(BasePluginManager):
return response
def fetch_agent_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
def fetch_agent_strategy_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
"""
Fetch tool provider for the given tenant and plugin.
"""
@ -55,7 +55,7 @@ class PluginAgentManager(BasePluginManager):
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/agent",
f"plugin/{tenant_id}/management/agent_strategy",
PluginAgentProviderEntity,
params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id},
transformer=transformer,
@ -96,9 +96,9 @@ class PluginAgentManager(BasePluginManager):
"app_id": app_id,
"message_id": message_id,
"data": {
"provider": agent_provider_id.provider_name,
"strategy": agent_strategy,
"agent_params": agent_params,
"agent_strategy_provider": agent_provider_id.provider_name,
"agent_strategy": agent_strategy,
"agent_strategy_params": agent_params,
},
},
headers={

View File

@ -1,7 +1,7 @@
from collections.abc import Generator
from typing import Any, Sequence, cast
from core.agent.plugin_entities import AgentParameter
from core.agent.plugin_entities import AgentStrategyParameter
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
@ -90,7 +90,7 @@ class AgentNode(ToolNode):
def _generate_parameters(
self,
*,
agent_parameters: Sequence[AgentParameter],
agent_parameters: Sequence[AgentStrategyParameter],
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,

View File

@ -246,7 +246,7 @@ class ToolNode(BaseNode[ToolNodeData]):
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text + "\n"
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)

View File

@ -7,7 +7,7 @@ def get_plugin_agent_strategy(
) -> PluginAgentStrategy:
# TODO: use contexts to cache the agent provider
manager = PluginAgentManager()
agent_provider = manager.fetch_agent_provider(tenant_id, agent_strategy_provider_name)
agent_provider = manager.fetch_agent_strategy_provider(tenant_id, agent_strategy_provider_name)
for agent_strategy in agent_provider.declaration.strategies:
if agent_strategy.identity.name == agent_strategy_name:
return PluginAgentStrategy(tenant_id, plugin_unique_identifier, agent_strategy)

View File

@ -160,7 +160,7 @@ class AgentService:
List agent providers
"""
manager = PluginAgentManager()
return manager.fetch_agent_providers(tenant_id)
return manager.fetch_agent_strategy_providers(tenant_id)
@classmethod
def get_agent_provider(cls, user_id: str, tenant_id: str, provider_name: str):
@ -168,4 +168,4 @@ class AgentService:
Get agent provider
"""
manager = PluginAgentManager()
return manager.fetch_agent_provider(tenant_id, provider_name)
return manager.fetch_agent_strategy_provider(tenant_id, provider_name)