From c8b0160ea9afb9c93061b33ba74cbe269a09debc Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 29 Aug 2024 14:06:10 +0800 Subject: [PATCH] fix: tool type --- api/core/tools/tool/tool.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index ae346759e2..52513c13f9 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -4,7 +4,7 @@ from copy import deepcopy from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Union -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic_core.core_schema import ValidationInfo from core.app.entities.app_invoke_entities import InvokeFrom @@ -27,8 +27,8 @@ if TYPE_CHECKING: class Tool(BaseModel, ABC): - identity: Optional[ToolIdentity] = None - parameters: Optional[list[ToolParameter]] = None + identity: ToolIdentity + parameters: list[ToolParameter] = Field(default_factory=list) description: Optional[ToolDescription] = None is_team_authorization: bool = False @@ -194,10 +194,8 @@ class Tool(BaseModel, ABC): return result - def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> Generator[ToolInvokeMessage]: - # update tool_parameters - # TODO: Fix type error. - if self.runtime.runtime_parameters: + def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]: + if self.runtime and self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters) # try parse tool parameters into the correct type @@ -210,7 +208,7 @@ class Tool(BaseModel, ABC): return result - def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]: + def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: """ Transform tool parameters type """ @@ -289,7 +287,7 @@ class Tool(BaseModel, ABC): :return: the image message """ return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=image, + message=ToolInvokeMessage.TextMessage(text=image), save_as=save_as) def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: @@ -308,7 +306,7 @@ class Tool(BaseModel, ABC): :return: the link message """ return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=link, + message=ToolInvokeMessage.TextMessage(text=link), save_as=save_as) def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: @@ -320,7 +318,7 @@ class Tool(BaseModel, ABC): """ return ToolInvokeMessage( type=ToolInvokeMessage.MessageType.TEXT, - message=text, + message=ToolInvokeMessage.TextMessage(text=text), save_as=save_as ) @@ -331,10 +329,18 @@ class Tool(BaseModel, ABC): :param blob: the blob :return: the blob message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, message=blob, meta=meta, save_as=save_as) + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=ToolInvokeMessage.BlobMessage(blob=blob), + meta=meta, + save_as=save_as + ) def create_json_message(self, object: dict) -> ToolInvokeMessage: """ create a json message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object) + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.JSON, + message=ToolInvokeMessage.JsonMessage(json_object=object) + )