refactor: rename agent to agent strategy
This commit is contained in:
parent
c2983ecbb7
commit
3c628d0c26
@ -6,26 +6,26 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderIdentity
|
||||
|
||||
|
||||
class AgentProviderIdentity(ToolProviderIdentity):
|
||||
class AgentStrategyProviderIdentity(ToolProviderIdentity):
|
||||
pass
|
||||
|
||||
|
||||
class AgentParameter(ToolParameter):
|
||||
class AgentStrategyParameter(ToolParameter):
|
||||
pass
|
||||
|
||||
|
||||
class AgentProviderEntity(BaseModel):
|
||||
identity: AgentProviderIdentity
|
||||
class AgentStrategyProviderEntity(BaseModel):
|
||||
identity: AgentStrategyProviderIdentity
|
||||
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
|
||||
|
||||
|
||||
class AgentIdentity(ToolIdentity):
|
||||
class AgentStrategyIdentity(ToolIdentity):
|
||||
pass
|
||||
|
||||
|
||||
class AgentStrategyEntity(BaseModel):
|
||||
identity: AgentIdentity
|
||||
parameters: list[AgentParameter] = Field(default_factory=list)
|
||||
identity: AgentStrategyIdentity
|
||||
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
@ -34,9 +34,9 @@ class AgentStrategyEntity(BaseModel):
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentParameter]:
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
|
||||
return v or []
|
||||
|
||||
|
||||
class AgentProviderEntityWithPlugin(AgentProviderEntity):
|
||||
class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
|
||||
strategies: list[AgentStrategyEntity] = Field(default_factory=list)
|
||||
|
@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Generator, Optional, Sequence
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentParameter
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
|
||||
|
||||
class BaseAgentStrategy(ABC):
|
||||
@ -23,7 +23,7 @@ class BaseAgentStrategy(ABC):
|
||||
"""
|
||||
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentParameter]:
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
"""
|
||||
Get the parameters for the agent strategy.
|
||||
"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Any, Generator, Optional, Sequence
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentParameter, AgentStrategyEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter, AgentStrategyEntity
|
||||
from core.agent.strategy.base import BaseAgentStrategy
|
||||
from core.plugin.manager.agent import PluginAgentManager
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
@ -21,7 +21,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
self.declaration = declaration
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentParameter]:
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
def _invoke(
|
||||
|
@ -43,7 +43,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
# invoke model
|
||||
response = model_instance.invoke_llm(
|
||||
prompt_messages=payload.prompt_messages,
|
||||
model_parameters=payload.model_parameters,
|
||||
model_parameters=payload.completion_params,
|
||||
tools=payload.tools,
|
||||
stop=payload.stop,
|
||||
stream=payload.stream or True,
|
||||
|
@ -6,6 +6,7 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyProviderEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
@ -59,6 +60,7 @@ class PluginCategory(enum.StrEnum):
|
||||
Tool = "tool"
|
||||
Model = "model"
|
||||
Extension = "extension"
|
||||
AgentStrategy = "agent_strategy"
|
||||
|
||||
|
||||
class PluginDeclaration(BaseModel):
|
||||
@ -82,6 +84,7 @@ class PluginDeclaration(BaseModel):
|
||||
tool: Optional[ToolProviderEntity] = None
|
||||
model: Optional[ProviderEntity] = None
|
||||
endpoint: Optional[EndpointProviderDeclaration] = None
|
||||
agent_strategy: Optional[AgentStrategyProviderEntity] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@ -91,6 +94,8 @@ class PluginDeclaration(BaseModel):
|
||||
values["category"] = PluginCategory.Tool
|
||||
elif values.get("model"):
|
||||
values["category"] = PluginCategory.Model
|
||||
elif values.get("agent_strategy"):
|
||||
values["category"] = PluginCategory.AgentStrategy
|
||||
else:
|
||||
values["category"] = PluginCategory.Extension
|
||||
return values
|
||||
|
@ -53,7 +53,7 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
||||
|
||||
model_type: ModelType = ModelType.LLM
|
||||
mode: str
|
||||
model_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
prompt_messages: list[PromptMessage] = Field(default_factory=list)
|
||||
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
|
||||
stop: Optional[list[str]] = Field(default_factory=list)
|
||||
|
@ -10,7 +10,7 @@ from core.plugin.manager.base import BasePluginManager
|
||||
|
||||
|
||||
class PluginAgentManager(BasePluginManager):
|
||||
def fetch_agent_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
|
||||
def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
|
||||
"""
|
||||
Fetch agent providers for the given tenant.
|
||||
"""
|
||||
@ -26,7 +26,7 @@ class PluginAgentManager(BasePluginManager):
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/agents",
|
||||
f"plugin/{tenant_id}/management/agent_strategies",
|
||||
list[PluginAgentProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
@ -41,7 +41,7 @@ class PluginAgentManager(BasePluginManager):
|
||||
|
||||
return response
|
||||
|
||||
def fetch_agent_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
|
||||
def fetch_agent_strategy_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
"""
|
||||
@ -55,7 +55,7 @@ class PluginAgentManager(BasePluginManager):
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/agent",
|
||||
f"plugin/{tenant_id}/management/agent_strategy",
|
||||
PluginAgentProviderEntity,
|
||||
params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
@ -96,9 +96,9 @@ class PluginAgentManager(BasePluginManager):
|
||||
"app_id": app_id,
|
||||
"message_id": message_id,
|
||||
"data": {
|
||||
"provider": agent_provider_id.provider_name,
|
||||
"strategy": agent_strategy,
|
||||
"agent_params": agent_params,
|
||||
"agent_strategy_provider": agent_provider_id.provider_name,
|
||||
"agent_strategy": agent_strategy,
|
||||
"agent_strategy_params": agent_params,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
|
@ -1,7 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Sequence, cast
|
||||
|
||||
from core.agent.plugin_entities import AgentParameter
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@ -90,7 +90,7 @@ class AgentNode(ToolNode):
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
agent_parameters: Sequence[AgentParameter],
|
||||
agent_parameters: Sequence[AgentStrategyParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
|
@ -246,7 +246,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text + "\n"
|
||||
text += message.message.text
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
|
||||
)
|
||||
|
@ -7,7 +7,7 @@ def get_plugin_agent_strategy(
|
||||
) -> PluginAgentStrategy:
|
||||
# TODO: use contexts to cache the agent provider
|
||||
manager = PluginAgentManager()
|
||||
agent_provider = manager.fetch_agent_provider(tenant_id, agent_strategy_provider_name)
|
||||
agent_provider = manager.fetch_agent_strategy_provider(tenant_id, agent_strategy_provider_name)
|
||||
for agent_strategy in agent_provider.declaration.strategies:
|
||||
if agent_strategy.identity.name == agent_strategy_name:
|
||||
return PluginAgentStrategy(tenant_id, plugin_unique_identifier, agent_strategy)
|
||||
|
@ -160,7 +160,7 @@ class AgentService:
|
||||
List agent providers
|
||||
"""
|
||||
manager = PluginAgentManager()
|
||||
return manager.fetch_agent_providers(tenant_id)
|
||||
return manager.fetch_agent_strategy_providers(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def get_agent_provider(cls, user_id: str, tenant_id: str, provider_name: str):
|
||||
@ -168,4 +168,4 @@ class AgentService:
|
||||
Get agent provider
|
||||
"""
|
||||
manager = PluginAgentManager()
|
||||
return manager.fetch_agent_provider(tenant_id, provider_name)
|
||||
return manager.fetch_agent_strategy_provider(tenant_id, provider_name)
|
||||
|
Loading…
Reference in New Issue
Block a user