dify/api/core/tools/tool/tool.py

347 lines
11 KiB
Python
Raw Normal View History

2024-02-01 18:11:57 +08:00
from abc import ABC, abstractmethod
2024-08-29 14:09:47 +08:00
from collections.abc import Generator
from copy import deepcopy
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Union
2024-08-29 14:06:10 +08:00
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_core.core_schema import ValidationInfo
2024-05-27 22:01:11 +08:00
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
2024-05-27 22:01:11 +08:00
ToolInvokeFrom,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
ToolRuntimeImageVariable,
ToolRuntimeVariable,
ToolRuntimeVariablePool,
)
2024-02-01 18:11:57 +08:00
from core.tools.tool_file_manager import ToolFileManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
2024-02-01 18:11:57 +08:00
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class Tool(BaseModel, ABC):
2024-08-29 14:06:10 +08:00
# Identity metadata of the tool; required, since no default is declared.
identity: ToolIdentity
# Parameter declarations of the tool; defaults to a fresh empty list per instance.
parameters: list[ToolParameter] = Field(default_factory=list)
# Optional description of the tool; None when not provided.
description: Optional[ToolDescription] = None
# Team-authorization flag; False unless set explicitly.
is_team_authorization: bool = False
# pydantic configs
# NOTE(review): an empty protected_namespaces tuple presumably turns off pydantic's
# reserved-prefix check for field names -- confirm against the pydantic ConfigDict docs.
model_config = ConfigDict(protected_namespaces=())
@field_validator('parameters', mode='before')
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
    """
    normalise the raw ``parameters`` input before field validation runs

    :param v: the raw value supplied for ``parameters``
    :param validation_info: validation context passed by pydantic (not used here)
    :return: ``v`` when it is truthy, otherwise a new empty list
    """
    if v:
        return v
    return []
class Runtime(BaseModel):
    """
    Meta data of a tool call processing

    Every field is optional; ``runtime_parameters`` defaults to an empty dict.
    """

    tenant_id: Optional[str] = None
    tool_id: Optional[str] = None
    invoke_from: Optional[InvokeFrom] = None
    tool_invoke_from: Optional[ToolInvokeFrom] = None
    credentials: Optional[dict[str, Any]] = None
    runtime_parameters: dict[str, Any] = Field(default_factory=dict)

    def __init__(self, **data: Any):
        """
        build the runtime meta data from keyword arguments

        :param data: field values for the model
        """
        super().__init__(**data)
        # Swap a falsy runtime_parameters value for a fresh empty dict.
        if not self.runtime_parameters:
            self.runtime_parameters = {}
# Runtime meta data of the current tool call; None unless supplied
# (fork_tool_runtime builds the forked tool with it set).
runtime: Optional[Runtime] = None
# Variable pool used by the variable helper methods; None until load_variables assigns it.
variables: Optional[ToolRuntimeVariablePool] = None
def __init__(self, **data: Any):
    """
    initialise the tool by forwarding every keyword argument to the parent initialiser

    :param data: field values for the model
    """
    super().__init__(**data)
class VARIABLE_KEY(Enum):
    """
    well-known variable names used by the variable helper methods
    """
    # Looked up by name in get_default_image_variable and used as a name
    # prefix in list_default_image_variables.
    IMAGE = 'image'
2024-05-27 22:01:11 +08:00
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
    """
    fork a new tool with meta data

    Only ``identity``, ``parameters`` and ``description`` are copied from this
    tool; the new tool is built with the same class as ``self``.

    :param runtime: keyword arguments used to build the ``Tool.Runtime`` of the new tool
    :return: the new tool
    """
    identity = self.identity.model_copy() if self.identity else None
    # Shallow copy: the new list holds the same ToolParameter objects.
    parameters = self.parameters.copy() if self.parameters else None
    description = self.description.model_copy() if self.description else None

    return self.__class__(
        identity=identity,
        parameters=parameters,
        description=description,
        runtime=Tool.Runtime(**runtime),
    )
2024-07-29 16:40:04 +08:00
@abstractmethod
def tool_provider_type(self) -> ToolProviderType:
    """
    get the tool provider type

    Declared abstract: this base definition has no body beyond the docstring,
    so subclasses supply the implementation.

    :return: the tool provider type
    """
2024-07-29 16:40:04 +08:00
def load_variables(self, variables: ToolRuntimeVariablePool):
    """
    attach a variable pool to this tool

    The pool is stored as-is on ``self.variables``; nothing is loaded here.

    :param variables: the variable pool to store on the tool
    """
    self.variables = variables
def set_image_variable(self, variable_name: str, image_key: str) -> None:
    """
    set an image variable

    Does nothing when no variable pool is attached to the tool.

    :param variable_name: the name of the variable
    :param image_key: the value handed to the pool's ``set_file`` for this variable
    """
    pool = self.variables
    if pool:
        # The tool's own name is passed along with the variable.
        pool.set_file(self.identity.name, variable_name, image_key)
def set_text_variable(self, variable_name: str, text: str) -> None:
    """
    set a text variable

    Does nothing when no variable pool is attached to the tool.

    :param variable_name: the name of the variable
    :param text: the text handed to the pool's ``set_text`` for this variable
    """
    pool = self.variables
    if pool:
        # The tool's own name is passed along with the variable.
        pool.set_text(self.identity.name, variable_name, text)
2024-07-29 16:40:04 +08:00
def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
    """
    get a variable

    :param name: the name of the variable; an Enum member is replaced by its ``value``
    :return: the first variable in the pool with that name, or None when there is
        no pool or no match
    """
    pool = self.variables
    if not pool:
        return None

    wanted = name.value if isinstance(name, Enum) else name
    return next((item for item in pool.pool if item.name == wanted), None)
def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
    """
    get the default image variable

    :return: the variable looked up under ``VARIABLE_KEY.IMAGE``, or None when
        no variable pool is attached
    """
    return self.get_variable(self.VARIABLE_KEY.IMAGE) if self.variables else None
2024-07-29 16:40:04 +08:00
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
    """
    get a variable file

    :param name: the name of the variable
    :return: the first element of the file lookup result, or None when the
        variable is missing, is not a ``ToolRuntimeImageVariable``, or the
        lookup returns nothing
    """
    found = self.get_variable(name)
    if not (found and isinstance(found, ToolRuntimeImageVariable)):
        return None

    # get file binary; the variable's value is used as the message file id
    binary = ToolFileManager.get_file_binary_by_message_file_id(found.value)
    return binary[0] if binary else None
2024-07-29 16:40:04 +08:00
def list_variables(self) -> list[ToolRuntimeVariable]:
    """
    list all variables

    :return: the pool's own variable list (not a copy), or a new empty list
        when no variable pool is attached
    """
    pool = self.variables
    return pool.pool if pool else []
2024-07-29 16:40:04 +08:00
def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
    """
    list all image variables

    :return: the variables whose name starts with the ``VARIABLE_KEY.IMAGE``
        value, in pool order; an empty list when no variable pool is attached
    """
    if not self.variables:
        return []

    return [
        item
        for item in self.variables.pool
        if item.name.startswith(self.VARIABLE_KEY.IMAGE.value)
    ]
2024-08-29 14:06:10 +08:00
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
    """
    invoke the tool

    Runtime parameters (when a runtime with non-empty ``runtime_parameters`` is
    attached) are merged over the supplied parameters, the merged parameters
    are cast through ``_transform_tool_parameters_type``, and the result is
    handed to ``_invoke``.

    :param user_id: the id of the user invoking the tool
    :param tool_parameters: the tool parameters; NOTE this dict is updated in
        place with the runtime parameters, so the caller sees the merged values
    :return: whatever ``_invoke`` returns for the transformed parameters
    """
    runtime_parameters = self.runtime.runtime_parameters if self.runtime else None
    if runtime_parameters:
        # Runtime parameters win over caller-supplied values of the same name.
        tool_parameters.update(runtime_parameters)

    # try parse tool parameters into the correct type
    return self._invoke(
        user_id=user_id,
        tool_parameters=self._transform_tool_parameters_type(tool_parameters),
    )
2024-08-29 14:06:10 +08:00
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
"""
Transform tool parameters type
"""
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
result = deepcopy(tool_parameters)
for parameter in self.parameters or []:
if parameter.name in tool_parameters:
2024-07-29 16:40:04 +08:00
result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
tool_parameters[parameter.name], parameter.type
)
return result
@abstractmethod
2024-07-09 15:37:56 +08:00
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
pass
2024-07-29 16:40:04 +08:00
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
    """
    validate the credentials

    This base implementation performs no checks and simply returns.

    :param credentials: the credentials
    :param parameters: the parameters
    """
    return None
def get_runtime_parameters(self) -> list[ToolParameter]:
    """
    get the runtime parameters

    interface for developers to dynamically change the parameters of a tool
    depending on the variables pool

    :return: the tool's own ``parameters`` list when it is non-empty, otherwise
        a new empty list
    """
    declared = self.parameters
    return declared if declared else []
2024-07-29 16:40:04 +08:00
2024-03-08 20:31:13 +08:00
def get_all_runtime_parameters(self) -> list[ToolParameter]:
    """
    get all runtime parameters

    Starts from a shallow copy of the tool's declared parameters and applies
    the parameters from ``get_runtime_parameters`` on top, matched by name.
    A matching entry is overridden in place (the parameter object itself is
    modified); an unmatched one is appended.

    :return: all runtime parameters
    """
    merged = (self.parameters or []).copy()
    overrides = (self.get_runtime_parameters() or []).copy()

    for override in overrides:
        # first entry with the same name, if any
        target = next((item for item in merged if item.name == override.name), None)
        if target is None:
            # add new parameter
            merged.append(override)
            continue

        # override parameter
        target.type = override.type
        target.form = override.form
        target.required = override.required
        target.default = override.default
        target.options = override.options
        target.llm_description = override.llm_description

    return merged
2024-07-29 16:40:04 +08:00
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
    """
    create an image message

    :param image: the url of the image
    :param save_as: passed through as the message's ``save_as`` value
    :return: the image message
    """
    # The image url travels as the text of a TextMessage payload.
    payload = ToolInvokeMessage.TextMessage(text=image)
    return ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.IMAGE,
        message=payload,
        save_as=save_as,
    )
2024-07-29 16:40:04 +08:00
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
    """
    create a file var message

    :param file_var: the file variable, carried under the ``file_var`` key of the message meta
    :return: the file var message, with no message payload and an empty ``save_as``
    """
    meta = {'file_var': file_var}
    return ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.FILE_VAR,
        message=None,
        meta=meta,
        save_as='',
    )
2024-07-29 16:40:04 +08:00
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
    """
    create a link message

    :param link: the url of the link
    :param save_as: passed through as the message's ``save_as`` value
    :return: the link message
    """
    # The link url travels as the text of a TextMessage payload.
    payload = ToolInvokeMessage.TextMessage(text=link)
    return ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.LINK,
        message=payload,
        save_as=save_as,
    )
2024-07-29 16:40:04 +08:00
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
    """
    create a text message

    :param text: the text
    :param save_as: passed through as the message's ``save_as`` value
    :return: the text message
    """
    payload = ToolInvokeMessage.TextMessage(text=text)
    return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=payload, save_as=save_as)
2024-07-29 16:40:04 +08:00
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = '') -> ToolInvokeMessage:
    """
    create a blob message

    :param blob: the blob
    :param meta: optional meta data, passed through unchanged (None by default)
    :param save_as: passed through as the message's ``save_as`` value
    :return: the blob message
    """
    payload = ToolInvokeMessage.BlobMessage(blob=blob)
    return ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.BLOB,
        message=payload,
        meta=meta,
        save_as=save_as,
    )
def create_json_message(self, object: dict) -> ToolInvokeMessage:
    """
    create a json message

    :param object: the dict to send as the json payload (the name shadows the
        builtin inside this method; kept because callers may pass it by keyword)
    :return: the json message
    """
    payload = ToolInvokeMessage.JsonMessage(json_object=object)
    return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=payload)