refactor: credentials schemas to array
This commit is contained in:
parent
c9f80b46a1
commit
6dfc31a542
@ -159,3 +159,6 @@ class ProviderConfig(BasicProviderConfig):
|
||||
help: Optional[I18nObject] = None
|
||||
url: Optional[str] = None
|
||||
placeholder: Optional[I18nObject] = None
|
||||
|
||||
def to_basic_provider_config(self) -> BasicProviderConfig:
|
||||
return BasicProviderConfig(type=self.type, name=self.name)
|
||||
|
@ -1,6 +1,3 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.request import RequestInvokeEncrypt
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from models.account import Tenant
|
||||
@ -11,7 +8,7 @@ class PluginEncrypter:
|
||||
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
|
||||
encrypter = ProviderConfigEncrypter(
|
||||
tenant_id=tenant.id,
|
||||
config=payload.data,
|
||||
config=payload.config,
|
||||
provider_type=payload.namespace,
|
||||
provider_identity=payload.identity,
|
||||
)
|
||||
|
@ -1,4 +1,3 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@ -12,7 +11,7 @@ class EndpointDeclaration(BaseModel):
|
||||
declaration of an endpoint
|
||||
"""
|
||||
|
||||
settings: Mapping[str, ProviderConfig] = Field(default_factory=Mapping)
|
||||
settings: list[ProviderConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class EndpointEntity(BasePluginEntity):
|
||||
|
@ -1,4 +1,3 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
@ -181,4 +180,4 @@ class RequestInvokeEncrypt(BaseModel):
|
||||
namespace: Literal["endpoint"]
|
||||
identity: str
|
||||
data: dict = Field(default_factory=dict)
|
||||
config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)
|
||||
config: list[BasicProviderConfig] = Field(default_factory=list)
|
||||
|
@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
@ -16,13 +17,13 @@ class ToolProviderController(ABC):
|
||||
def __init__(self, entity: ToolProviderEntity) -> None:
|
||||
self.entity = entity
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
return self.entity.credentials_schema.copy()
|
||||
return deepcopy(self.entity.credentials_schema)
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
@ -48,10 +49,13 @@ class ToolProviderController(ABC):
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = self.entity.credentials_schema
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
@ -34,10 +34,14 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
for credential_name in provider_yaml["credentials_for_provider"]:
|
||||
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
|
||||
|
||||
credentials_schema = []
|
||||
for credential in provider_yaml.get("credentials_for_provider", {}):
|
||||
credentials_schema.append(credential)
|
||||
|
||||
super().__init__(
|
||||
entity=ToolProviderEntity(
|
||||
identity=provider_yaml["identity"],
|
||||
credentials_schema=provider_yaml.get("credentials_for_provider", {}) or {},
|
||||
credentials_schema=credentials_schema,
|
||||
),
|
||||
)
|
||||
|
||||
@ -84,14 +88,14 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if not self.entity.credentials_schema:
|
||||
return {}
|
||||
return []
|
||||
|
||||
return self.entity.credentials_schema.copy()
|
||||
|
||||
|
@ -12,4 +12,3 @@ identity:
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
|
@ -12,4 +12,3 @@ identity:
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- utilities
|
||||
credentials_for_provider:
|
||||
|
@ -28,8 +28,8 @@ class ApiToolProviderController(ToolProviderController):
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
|
||||
credentials_schema = {
|
||||
"auth_type": ProviderConfig(
|
||||
credentials_schema = [
|
||||
ProviderConfig(
|
||||
name="auth_type",
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
@ -40,24 +40,24 @@ class ApiToolProviderController(ToolProviderController):
|
||||
default="none",
|
||||
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
|
||||
)
|
||||
}
|
||||
]
|
||||
if auth_type == ApiProviderAuthType.API_KEY:
|
||||
credentials_schema = {
|
||||
**credentials_schema,
|
||||
"api_key_header": ProviderConfig(
|
||||
credentials_schema = [
|
||||
*credentials_schema,
|
||||
ProviderConfig(
|
||||
name="api_key_header",
|
||||
required=False,
|
||||
default="api_key",
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
|
||||
),
|
||||
"api_key_value": ProviderConfig(
|
||||
ProviderConfig(
|
||||
name="api_key_value",
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SECRET_INPUT,
|
||||
help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
|
||||
),
|
||||
"api_key_header_prefix": ProviderConfig(
|
||||
ProviderConfig(
|
||||
name="api_key_header_prefix",
|
||||
required=False,
|
||||
default="basic",
|
||||
@ -69,7 +69,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
|
||||
],
|
||||
),
|
||||
}
|
||||
]
|
||||
elif auth_type == ApiProviderAuthType.NONE:
|
||||
pass
|
||||
|
||||
|
@ -2,7 +2,6 @@ from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -62,7 +61,3 @@ class ToolProviderApiEntity(BaseModel):
|
||||
"tools": tools,
|
||||
"labels": self.labels,
|
||||
}
|
||||
|
||||
|
||||
class ToolProviderCredentialsApiEntity(BaseModel):
|
||||
credentials: dict[str, ProviderConfig]
|
||||
|
@ -312,7 +312,7 @@ class ToolEntity(BaseModel):
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
|
@ -160,7 +160,7 @@ class ToolManager:
|
||||
credentials = builtin_provider.credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
@ -186,7 +186,7 @@ class ToolManager:
|
||||
# decrypt the credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=api_provider.get_credentials_schema(),
|
||||
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
|
||||
provider_type=api_provider.provider_type.value,
|
||||
provider_identity=api_provider.entity.identity.name,
|
||||
)
|
||||
@ -643,7 +643,7 @@ class ToolManager:
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_credentials_schema(),
|
||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
)
|
||||
|
@ -1,4 +1,3 @@
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
@ -17,7 +16,7 @@ from core.tools.entities.tool_entities import (
|
||||
|
||||
class ProviderConfigEncrypter(BaseModel):
|
||||
tenant_id: str
|
||||
config: Mapping[str, BasicProviderConfig]
|
||||
config: list[BasicProviderConfig]
|
||||
provider_type: str
|
||||
provider_identity: str
|
||||
|
||||
@ -36,7 +35,10 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = self.config
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
@ -54,7 +56,10 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = self.config
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
@ -83,7 +88,10 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
return cached_credentials
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = self.config
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
|
@ -35,7 +35,7 @@ class BuiltinToolManageService:
|
||||
|
||||
tool_provider_configurations = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
@ -78,7 +78,7 @@ class BuiltinToolManageService:
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
return jsonable_encoder([v for _, v in (provider.get_credentials_schema() or {}).items()])
|
||||
return jsonable_encoder(provider.get_credentials_schema())
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
|
||||
@ -102,7 +102,7 @@ class BuiltinToolManageService:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
@ -164,7 +164,7 @@ class BuiltinToolManageService:
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
@ -196,7 +196,7 @@ class BuiltinToolManageService:
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
@ -85,7 +85,8 @@ class ToolTransformService:
|
||||
)
|
||||
|
||||
# get credentials schema
|
||||
schema = provider_controller.get_credentials_schema()
|
||||
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
|
||||
|
||||
for name, value in schema.items():
|
||||
if result.masked_credentials:
|
||||
result.masked_credentials[name] = ""
|
||||
@ -103,7 +104,7 @@ class ToolTransformService:
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
@ -208,7 +209,7 @@ class ToolTransformService:
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user