2024-01-23 19:58:23 +08:00
|
|
|
from abc import ABC, abstractmethod
|
2024-08-30 14:23:14 +08:00
|
|
|
from typing import Any
|
2024-01-23 19:58:23 +08:00
|
|
|
|
2024-08-30 18:11:38 +08:00
|
|
|
from pydantic import BaseModel, ConfigDict, Field
|
2024-02-06 13:21:13 +08:00
|
|
|
|
2024-08-30 14:23:14 +08:00
|
|
|
from core.entities.provider_entities import ProviderConfig
|
2024-02-06 13:21:13 +08:00
|
|
|
from core.tools.entities.tool_entities import (
|
|
|
|
ToolProviderIdentity,
|
|
|
|
ToolProviderType,
|
|
|
|
)
|
2024-08-30 14:23:14 +08:00
|
|
|
from core.tools.errors import ToolProviderCredentialValidationError
|
2024-02-01 18:11:57 +08:00
|
|
|
from core.tools.tool.tool import Tool
|
2024-01-23 19:58:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
class ToolProviderController(BaseModel, ABC):
    """
    Abstract base class for tool provider controllers.

    A controller carries the provider identity, the tools it exposes and the
    schema of the credentials it accepts, and can validate a credentials dict
    against that schema. Subclasses must implement ``get_tool``.
    """

    # identity of the provider; ``identity.name`` is used in validation error messages
    identity: ToolProviderIdentity
    # tools exposed by this provider (empty list by default)
    tools: list[Tool] = Field(default_factory=list)
    # credential name -> schema entry (type / required / default / options)
    credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)

    # re-run pydantic validation whenever a field is assigned after construction
    model_config = ConfigDict(validate_assignment=True)

    def get_credentials_schema(self) -> dict[str, ProviderConfig]:
        """
        returns the credentials schema of the provider

        :return: a shallow copy of the credentials schema; callers may add or
            remove entries without affecting this provider, but the
            ProviderConfig values themselves are shared
        """
        return self.credentials_schema.copy()

    @abstractmethod
    def get_tool(self, tool_name: str) -> Tool:
        """
        returns a tool that the provider can provide

        :param tool_name: the name of the tool to look up
        :return: tool
        """
        pass

    @property
    def provider_type(self) -> ToolProviderType:
        """
        returns the type of the provider

        :return: type of the provider; ``ToolProviderType.BUILT_IN`` unless a
            subclass overrides this property
        """
        return ToolProviderType.BUILT_IN

    def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
        """
        validate the format of the credentials of the provider and set the default value if needed

        Every key in ``credentials`` must be declared in ``credentials_schema``.
        SECRET_INPUT / TEXT_INPUT values must be strings. SELECT values must be
        strings matching one of the schema's option values. Declared credentials
        that were not supplied are rejected if required. Otherwise their schema
        default (if any) is written into ``credentials``; the default is
        converted to ``str`` for SECRET_INPUT / TEXT_INPUT / SELECT.

        :param credentials: the credentials of the tool; modified in place when defaults are applied
        :raises ToolProviderCredentialValidationError: on an unknown credential, a
            wrongly-typed value, an invalid select option, or a missing required credential
        """
        credentials_schema = self.credentials_schema
        # defensive: the field defaults to an empty dict, so this is normally never None
        if credentials_schema is None:
            return

        # working copy: entries are removed once the matching credential has been
        # checked, so whatever remains afterwards was not supplied by the caller
        credentials_need_to_validate: dict[str, ProviderConfig] = dict(credentials_schema)

        for credential_name in credentials:
            if credential_name not in credentials_need_to_validate:
                raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')

            # check type -- compare the schema's declared ``type``, not the ProviderConfig object itself
            credential_schema = credentials_need_to_validate[credential_name]
            value = credentials[credential_name]
            if credential_schema.type in (ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT):
                if not isinstance(value, str):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')

            elif credential_schema.type == ProviderConfig.Type.SELECT:
                if not isinstance(value, str):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')

                options = credential_schema.options
                if not isinstance(options, list):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')

                if value not in [x.value for x in options]:
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')

            credentials_need_to_validate.pop(credential_name)

        # remaining entries were declared in the schema but not supplied
        for credential_name, credential_schema in credentials_need_to_validate.items():
            if credential_schema.required:
                raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')

            # the credential is not set currently, set the default value if needed
            if credential_schema.default is not None:
                default_value = credential_schema.default
                # parse default value into the correct type
                if credential_schema.type in (
                    ProviderConfig.Type.SECRET_INPUT,
                    ProviderConfig.Type.TEXT_INPUT,
                    ProviderConfig.Type.SELECT,
                ):
                    default_value = str(default_value)

                credentials[credential_name] = default_value
|
2024-05-27 22:01:11 +08:00
|
|
|
|