fix: tool type

This commit is contained in:
Yeuoly 2024-08-29 14:06:10 +08:00
parent 531ffaec4f
commit c8b0160ea9
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

View File

@ -4,7 +4,7 @@ from copy import deepcopy
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -27,8 +27,8 @@ if TYPE_CHECKING:
class Tool(BaseModel, ABC): class Tool(BaseModel, ABC):
identity: Optional[ToolIdentity] = None identity: ToolIdentity
parameters: Optional[list[ToolParameter]] = None parameters: list[ToolParameter] = Field(default_factory=list)
description: Optional[ToolDescription] = None description: Optional[ToolDescription] = None
is_team_authorization: bool = False is_team_authorization: bool = False
@ -194,10 +194,8 @@ class Tool(BaseModel, ABC):
return result return result
def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> Generator[ToolInvokeMessage]: def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
# update tool_parameters if self.runtime and self.runtime.runtime_parameters:
# TODO: Fix type error.
if self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters) tool_parameters.update(self.runtime.runtime_parameters)
# try parse tool parameters into the correct type # try parse tool parameters into the correct type
@ -210,7 +208,7 @@ class Tool(BaseModel, ABC):
return result return result
def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]: def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
""" """
Transform tool parameters type Transform tool parameters type
""" """
@ -289,7 +287,7 @@ class Tool(BaseModel, ABC):
:return: the image message :return: the image message
""" """
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
message=image, message=ToolInvokeMessage.TextMessage(text=image),
save_as=save_as) save_as=save_as)
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
@ -308,7 +306,7 @@ class Tool(BaseModel, ABC):
:return: the link message :return: the link message
""" """
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
message=link, message=ToolInvokeMessage.TextMessage(text=link),
save_as=save_as) save_as=save_as)
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
@ -320,7 +318,7 @@ class Tool(BaseModel, ABC):
""" """
return ToolInvokeMessage( return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT, type=ToolInvokeMessage.MessageType.TEXT,
message=text, message=ToolInvokeMessage.TextMessage(text=text),
save_as=save_as save_as=save_as
) )
@ -331,10 +329,18 @@ class Tool(BaseModel, ABC):
:param blob: the blob :param blob: the blob
:return: the blob message :return: the blob message
""" """
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, message=blob, meta=meta, save_as=save_as) return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=blob),
meta=meta,
save_as=save_as
)
def create_json_message(self, object: dict) -> ToolInvokeMessage: def create_json_message(self, object: dict) -> ToolInvokeMessage:
""" """
create a json message create a json message
""" """
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object) return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON,
message=ToolInvokeMessage.JsonMessage(json_object=object)
)