refactor: credentials schemas to array

This commit is contained in:
Yeuoly 2024-09-30 17:39:13 +08:00
parent c9f80b46a1
commit 6dfc31a542
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
15 changed files with 55 additions and 47 deletions

View File

@ -159,3 +159,6 @@ class ProviderConfig(BasicProviderConfig):
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None
def to_basic_provider_config(self) -> BasicProviderConfig:
return BasicProviderConfig(type=self.type, name=self.name)

View File

@ -1,6 +1,3 @@
from collections.abc import Mapping
from typing import Any
from core.plugin.entities.request import RequestInvokeEncrypt
from core.tools.utils.configuration import ProviderConfigEncrypter
from models.account import Tenant
@ -11,7 +8,7 @@ class PluginEncrypter:
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
encrypter = ProviderConfigEncrypter(
tenant_id=tenant.id,
config=payload.data,
config=payload.config,
provider_type=payload.namespace,
provider_identity=payload.identity,
)

View File

@ -1,4 +1,3 @@
from collections.abc import Mapping
from datetime import datetime
from pydantic import BaseModel, Field
@ -12,7 +11,7 @@ class EndpointDeclaration(BaseModel):
declaration of an endpoint
"""
settings: Mapping[str, ProviderConfig] = Field(default_factory=Mapping)
settings: list[ProviderConfig] = Field(default_factory=list)
class EndpointEntity(BasePluginEntity):

View File

@ -1,4 +1,3 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
@ -181,4 +180,4 @@ class RequestInvokeEncrypt(BaseModel):
namespace: Literal["endpoint"]
identity: str
data: dict = Field(default_factory=dict)
config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)
config: list[BasicProviderConfig] = Field(default_factory=list)

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any
from core.entities.provider_entities import ProviderConfig
@ -16,13 +17,13 @@ class ToolProviderController(ABC):
def __init__(self, entity: ToolProviderEntity) -> None:
self.entity = entity
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
def get_credentials_schema(self) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
return self.entity.credentials_schema.copy()
return deepcopy(self.entity.credentials_schema)
@abstractmethod
def get_tool(self, tool_name: str) -> Tool:
@ -48,10 +49,13 @@ class ToolProviderController(ABC):
:param credentials: the credentials of the tool
"""
credentials_schema = self.entity.credentials_schema
credentials_schema = dict[str, ProviderConfig]()
if credentials_schema is None:
return
for credential in self.entity.credentials_schema:
credentials_schema[credential.name] = credential
credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]

View File

@ -34,10 +34,14 @@ class BuiltinToolProviderController(ToolProviderController):
for credential_name in provider_yaml["credentials_for_provider"]:
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
credentials_schema = []
for credential in provider_yaml.get("credentials_for_provider", {}):
credentials_schema.append(credential)
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=provider_yaml.get("credentials_for_provider", {}) or {},
credentials_schema=credentials_schema,
),
)
@ -84,14 +88,14 @@ class BuiltinToolProviderController(ToolProviderController):
self.tools = tools
return tools
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
def get_credentials_schema(self) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
if not self.entity.credentials_schema:
return {}
return []
return self.entity.credentials_schema.copy()

View File

@ -12,4 +12,3 @@ identity:
icon: icon.svg
tags:
- productivity
credentials_for_provider:

View File

@ -12,4 +12,3 @@ identity:
icon: icon.svg
tags:
- utilities
credentials_for_provider:

View File

@ -28,8 +28,8 @@ class ApiToolProviderController(ToolProviderController):
@classmethod
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
credentials_schema = {
"auth_type": ProviderConfig(
credentials_schema = [
ProviderConfig(
name="auth_type",
required=True,
type=ProviderConfig.Type.SELECT,
@ -40,24 +40,24 @@ class ApiToolProviderController(ToolProviderController):
default="none",
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
)
}
]
if auth_type == ApiProviderAuthType.API_KEY:
credentials_schema = {
**credentials_schema,
"api_key_header": ProviderConfig(
credentials_schema = [
*credentials_schema,
ProviderConfig(
name="api_key_header",
required=False,
default="api_key",
type=ProviderConfig.Type.TEXT_INPUT,
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
),
"api_key_value": ProviderConfig(
ProviderConfig(
name="api_key_value",
required=True,
type=ProviderConfig.Type.SECRET_INPUT,
help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
),
"api_key_header_prefix": ProviderConfig(
ProviderConfig(
name="api_key_header_prefix",
required=False,
default="basic",
@ -69,7 +69,7 @@ class ApiToolProviderController(ToolProviderController):
ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
],
),
}
]
elif auth_type == ApiProviderAuthType.NONE:
pass

View File

@ -2,7 +2,6 @@ from typing import Literal, Optional
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
@ -62,7 +61,3 @@ class ToolProviderApiEntity(BaseModel):
"tools": tools,
"labels": self.labels,
}
class ToolProviderCredentialsApiEntity(BaseModel):
credentials: dict[str, ProviderConfig]

View File

@ -312,7 +312,7 @@ class ToolEntity(BaseModel):
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
class ToolProviderEntityWithPlugin(ToolProviderEntity):

View File

@ -160,7 +160,7 @@ class ToolManager:
credentials = builtin_provider.credentials
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
@ -186,7 +186,7 @@ class ToolManager:
# decrypt the credentials
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=api_provider.get_credentials_schema(),
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name,
)
@ -643,7 +643,7 @@ class ToolManager:
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=controller.get_credentials_schema(),
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
)

View File

@ -1,4 +1,3 @@
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
@ -17,7 +16,7 @@ from core.tools.entities.tool_entities import (
class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: Mapping[str, BasicProviderConfig]
config: list[BasicProviderConfig]
provider_type: str
provider_identity: str
@ -36,7 +35,10 @@ class ProviderConfigEncrypter(BaseModel):
data = self._deep_copy(data)
# get fields need to be decrypted
fields = self.config
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
@ -54,7 +56,10 @@ class ProviderConfigEncrypter(BaseModel):
data = self._deep_copy(data)
# get fields need to be decrypted
fields = self.config
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
@ -83,7 +88,10 @@ class ProviderConfigEncrypter(BaseModel):
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = self.config
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:

View File

@ -35,7 +35,7 @@ class BuiltinToolManageService:
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
@ -78,7 +78,7 @@ class BuiltinToolManageService:
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return jsonable_encoder([v for _, v in (provider.get_credentials_schema() or {}).items()])
return jsonable_encoder(provider.get_credentials_schema())
@staticmethod
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
@ -102,7 +102,7 @@ class BuiltinToolManageService:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
@ -164,7 +164,7 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
@ -196,7 +196,7 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)

View File

@ -85,7 +85,8 @@ class ToolTransformService:
)
# get credentials schema
schema = provider_controller.get_credentials_schema()
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
for name, value in schema.items():
if result.masked_credentials:
result.masked_credentials[name] = ""
@ -103,7 +104,7 @@ class ToolTransformService:
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
tenant_id=db_provider.tenant_id,
config=provider_controller.get_credentials_schema(),
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
@ -208,7 +209,7 @@ class ToolTransformService:
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
tenant_id=db_provider.tenant_id,
config=provider_controller.get_credentials_schema(),
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)