# dify/api/core/tools/utils/configuration.py
# 245 lines · 8.5 KiB · Python (captured from a web blame page: Raw | Normal View | History)
#
# 2024-08-30 18:11:38 +08:00
from collections.abc import Mapping
from copy import deepcopy
# 2024-05-15 12:25:04 +08:00
from typing import Any
from pydantic import BaseModel
# 2024-08-30 18:11:38 +08:00
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
# 2024-03-08 20:31:13 +08:00
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
# 2024-09-20 02:25:14 +08:00
from core.tools.__base.tool import Tool
# 2024-03-08 15:22:55 +08:00
from core.tools.entities.tool_entities import (
    # 2024-03-08 20:31:13 +08:00
ToolParameter,
    # 2024-07-09 15:37:56 +08:00
ToolProviderType,
    # 2024-03-08 15:22:55 +08:00
)
# 2024-08-30 21:25:58 +08:00
class ProviderConfigEncrypter(BaseModel):
    """
    Encrypt, decrypt and mask the secret fields of a provider's credentials.

    Only the fields whose entry in ``config`` has type
    ``BasicProviderConfig.Type.SECRET_INPUT`` are touched; every other value is
    passed through unchanged. All public methods work on a deep copy of the
    input dict (``decrypt`` may instead return a cached dict).
    """

    tenant_id: str
    # field name -> field definition; used to decide which fields are secret
    config: Mapping[str, BasicProviderConfig]
    provider_type: str
    provider_identity: str

    def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
        """
        deep copy data
        """
        return deepcopy(data)

    def _secret_field_names(self) -> list[str]:
        """Return the names of configured fields of type SECRET_INPUT, in config order."""
        return [
            field_name
            for field_name, field in self.config.items()
            if field.type == BasicProviderConfig.Type.SECRET_INPUT
        ]

    @staticmethod
    def _mask_value(value: str) -> str:
        """
        Mask a secret value.

        Values longer than 6 characters keep their first and last two
        characters; shorter values are replaced entirely by ``*``.
        """
        if len(value) > 6:
            return value[:2] + "*" * (len(value) - 4) + value[-2:]
        return "*" * len(value)

    def _get_cache(self) -> ToolProviderCredentialsCache:
        """Build the credentials-cache handle for this tenant and provider."""
        return ToolProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=f"{self.provider_type}.{self.provider_identity}",
            cache_type=ToolProviderCredentialsCacheType.PROVIDER,
        )

    def encrypt(self, data: dict[str, str]) -> dict[str, str]:
        """
        encrypt tool credentials with tenant id

        :param data: credentials keyed by field name
        :return: a deep copy of ``data`` with every present secret field encrypted
        """
        data = self._deep_copy(data)
        # only secret fields that are actually present need to be encrypted
        for field_name in self._secret_field_names():
            if field_name in data:
                data[field_name] = encrypter.encrypt_token(self.tenant_id, data[field_name])
        return data

    def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
        """
        mask tool credentials

        :param data: credentials keyed by field name
        :return: a deep copy of ``data`` with every present secret field masked
        """
        data = self._deep_copy(data)
        # only secret fields that are actually present need to be masked
        for field_name in self._secret_field_names():
            if field_name in data:
                data[field_name] = self._mask_value(data[field_name])
        return data

    def decrypt(self, data: dict[str, str]) -> dict[str, str]:
        """
        decrypt tool credentials with tenant id

        If the credentials cache already holds a (truthy) value, that value is
        returned as-is and ``data`` is ignored. Otherwise a deep copy of
        ``data`` with its secret fields decrypted is stored in the cache and
        returned.

        :param data: credentials keyed by field name, secret fields encrypted
        :return: credentials with secret fields decrypted
        """
        cache = self._get_cache()
        cached_credentials = cache.get()
        if cached_credentials:
            return cached_credentials

        data = self._deep_copy(data)
        for field_name in self._secret_field_names():
            if field_name in data:
                try:
                    data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
                except Exception:
                    # Best effort: a value that cannot be decrypted is kept unchanged.
                    # NOTE(review): presumably such values were never encrypted — confirm.
                    # `except Exception` (not a bare except) lets KeyboardInterrupt,
                    # SystemExit and greenlet exits propagate.
                    pass

        cache.set(data)
        return data

    def delete_tool_credentials_cache(self):
        """Delete the cached decrypted credentials for this tenant and provider."""
        self._get_cache().delete()
# 2024-03-08 15:22:55 +08:00
# 2024-03-08 20:31:13 +08:00
class ToolParameterConfigurationManager(BaseModel):
    """
    Tool parameter configuration manager

    Masks, encrypts and decrypts the parameters of a tool runtime that are
    FORM parameters of type SECRET_INPUT. Public methods work on a deep copy
    of the given parameters (``decrypt_tool_parameters`` may instead return a
    cached dict).
    """

    tenant_id: str
    tool_runtime: Tool
    provider_name: str
    provider_type: ToolProviderType
    identity_id: str

    def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        deep copy parameters
        """
        return deepcopy(parameters)

    def _merge_parameters(self) -> list[ToolParameter]:
        """
        Merge the tool's declared parameters with its runtime parameters.

        A runtime parameter replaces the first current parameter that has the
        same name and form. A runtime parameter with no such match is appended
        only when its form is FORM.
        """
        tool_parameters = self.tool_runtime.parameters or []
        runtime_parameters = self.tool_runtime.get_runtime_parameters() or []

        # shallow copy so the tool's own parameter list is not modified
        current_parameters = tool_parameters.copy()
        for runtime_parameter in runtime_parameters:
            for index, parameter in enumerate(current_parameters):
                if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
                    current_parameters[index] = runtime_parameter
                    break
            else:
                # no existing parameter was overridden
                if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
                    current_parameters.append(runtime_parameter)
        return current_parameters

    def _secret_parameter_names(self) -> list[str]:
        """
        Return the names of merged FORM parameters of type SECRET_INPUT.

        A list (not a set) is returned, so a name that appears more than once
        in the merged parameters is processed once per occurrence.
        """
        return [
            parameter.name
            for parameter in self._merge_parameters()
            if parameter.form == ToolParameter.ToolParameterForm.FORM
            and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
        ]

    @staticmethod
    def _mask_value(value: str) -> str:
        """
        Mask a secret value.

        Values longer than 6 characters keep their first and last two
        characters; shorter values are replaced entirely by ``*``.
        """
        if len(value) > 6:
            return value[:2] + "*" * (len(value) - 4) + value[-2:]
        return "*" * len(value)

    def _get_cache(self) -> ToolParameterCache:
        """Build the parameter-cache handle for this tenant, provider, tool and identity."""
        return ToolParameterCache(
            tenant_id=self.tenant_id,
            provider=f"{self.provider_type.value}.{self.provider_name}",
            tool_name=self.tool_runtime.identity.name,
            cache_type=ToolParameterCacheType.PARAMETER,
            identity_id=self.identity_id,
        )

    def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        mask tool parameters

        :param parameters: parameter values keyed by parameter name
        :return: a deep copy of ``parameters`` with every present secret parameter masked
        """
        parameters = self._deep_copy(parameters)
        for name in self._secret_parameter_names():
            if name in parameters:
                parameters[name] = self._mask_value(parameters[name])
        return parameters

    def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        encrypt tool parameters with tenant id

        :param parameters: parameter values keyed by parameter name
        :return: a deep copy of ``parameters`` with every present secret parameter encrypted
        """
        parameters = self._deep_copy(parameters)
        for name in self._secret_parameter_names():
            if name in parameters:
                parameters[name] = encrypter.encrypt_token(self.tenant_id, parameters[name])
        return parameters

    def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
        """
        decrypt tool parameters with tenant id

        If the parameter cache already holds a (truthy) value, that value is
        returned as-is and ``parameters`` is ignored. Otherwise a deep copy of
        ``parameters`` with its secret parameters decrypted is returned, and it
        is cached when at least one secret parameter was present.

        :param parameters: parameter values keyed by name, secret values encrypted
        :return: parameters with secret values decrypted
        """
        cache = self._get_cache()
        cached_parameters = cache.get()
        if cached_parameters:
            return cached_parameters

        # Work on a copy, as documented (and as mask/encrypt already do), so the
        # caller's dict of encrypted values is not mutated in place.
        parameters = self._deep_copy(parameters)
        has_secret_input = False
        for name in self._secret_parameter_names():
            if name in parameters:
                # set before the attempt, so a failed decryption still triggers caching
                has_secret_input = True
                try:
                    parameters[name] = encrypter.decrypt_token(self.tenant_id, parameters[name])
                except Exception:
                    # Best effort: a value that cannot be decrypted is kept unchanged.
                    # NOTE(review): presumably such values were never encrypted — confirm.
                    pass

        if has_secret_input:
            cache.set(parameters)

        return parameters

    def delete_tool_parameters_cache(self):
        """Delete the cached decrypted parameters for this tenant, provider, tool and identity."""
        self._get_cache().delete()