diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 6ce860b877..95500ea9f3 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -21,6 +21,8 @@ class PluginInvokeModelApi(Resource): parser.add_argument('parameters', type=dict, required=True, location='json') args = parser.parse_args() + + class PluginInvokeToolApi(Resource): diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 5274224de5..189c925d82 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,14 +1,16 @@ from enum import Enum -from typing import Any, Literal, Optional, Union +from typing import Any, Optional, Union from pydantic import BaseModel +from core.tools.entities.tool_entities import ToolProviderType + class AgentToolEntity(BaseModel): """ Agent Tool Entity. """ - provider_type: Literal["builtin", "api", "workflow"] + provider_type: ToolProviderType provider_id: str tool_name: str tool_parameters: dict[str, Any] = {} diff --git a/api/core/callback_handler/plugin_tool_callback_handler.py b/api/core/callback_handler/plugin_tool_callback_handler.py new file mode 100644 index 0000000000..e9b9784014 --- /dev/null +++ b/api/core/callback_handler/plugin_tool_callback_handler.py @@ -0,0 +1,5 @@ +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler + + +class DifyPluginCallbackHandler(DifyAgentCallbackHandler): + """Callback Handler that prints to std out.""" \ No newline at end of file diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index c8b683f9ef..1d600d5efc 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -1,4 +1,5 @@ import json +from collections.abc import Generator from os import getenv from typing import Any from urllib.parse import urlencode @@ -269,7 +270,7 @@ class ApiTool(Tool): except ValueError as e: return value - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: """ invoke http request """ @@ -283,4 +284,4 @@ class ApiTool(Tool): response = self.validate_and_parse_response(response) # assemble invoke message - return self.create_text_message(response) + yield self.create_text_message(response) diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 1170e1b7a5..3e81d84c92 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -1,3 +1,4 @@ +from collections.abc import Generator from typing import Any from core.app.app_config.entities import DatasetRetrieveConfigEntity @@ -86,7 +87,7 @@ class DatasetRetrieverTool(Tool): def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.DATASET_RETRIEVAL - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: """ invoke dataset retriever tool """ @@ -97,7 +98,7 @@ class DatasetRetrieverTool(Tool): # invoke dataset retriever tool result = self.retrival_tool._run(query=query) - return self.create_text_message(text=result) + yield self.create_text_message(text=result) def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: """ diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 04c09c7f5b..291cac5ee3 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections.abc import Generator from copy import deepcopy from enum import Enum from typing import Any, Optional, Union @@ -190,7 +191,7 @@ class Tool(BaseModel, ABC): return result - def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: + def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]: # update tool_parameters if self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters) @@ -203,9 +204,6 @@ class Tool(BaseModel, ABC): tool_parameters=tool_parameters, ) - if not isinstance(result, list): - result = [result] - return result def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: @@ -221,7 +219,7 @@ class Tool(BaseModel, ABC): return result @abstractmethod - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: pass def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 071081303c..d0a9df6479 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -1,5 +1,6 @@ import json import logging +from collections.abc import Generator from copy import deepcopy from typing import Any, Union @@ -34,7 +35,7 @@ class WorkflowTool(Tool): def _invoke( self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + ) -> Generator[ToolInvokeMessage, None, None]: """ invoke the tool """ @@ -46,6 +47,7 @@ class WorkflowTool(Tool): from core.app.apps.workflow.app_generator import WorkflowAppGenerator generator = WorkflowAppGenerator() + result = generator.generate( app_model=app, workflow=workflow, @@ -64,16 +66,12 @@ class WorkflowTool(Tool): if data.get('error'): raise Exception(data.get('error')) - result = [] - outputs = data.get('outputs', {}) outputs, files = self._extract_files(outputs) for file in files: - result.append(self.create_file_var_message(file)) + yield self.create_file_var_message(file) - result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) - - return result + yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 7615368934..7d94eedc5f 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,4 +1,5 @@ import json +from collections.abc import Generator from copy import deepcopy from datetime import datetime, timezone from mimetypes import guess_type @@ -8,6 +9,7 @@ from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod from core.ops.ops_trace_manager import TraceQueueManager @@ -64,16 +66,25 @@ class ToolEngine: tool_inputs=tool_parameters ) - meta, response = ToolEngine._invoke(tool, tool_parameters, user_id) - response = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=response, - user_id=user_id, - tenant_id=tenant_id, + messages = ToolEngine._invoke(tool, tool_parameters, user_id) + invocation_meta_dict = {'meta': None} + + def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]): + for message in messages: + if isinstance(message, ToolInvokeMeta): + invocation_meta_dict['meta'] = message + else: + yield message + + messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=message_callback(invocation_meta_dict, messages), + user_id=user_id, + tenant_id=tenant_id, conversation_id=message.conversation_id ) # extract binary data from tool invoke message - binary_files = ToolEngine._extract_tool_response_binary(response) + binary_files = ToolEngine._extract_tool_response_binary(messages) # create message file message_files = ToolEngine._create_message_files( tool_messages=binary_files, @@ -82,7 +93,9 @@ class ToolEngine: user_id=user_id ) - plain_text = ToolEngine._convert_tool_response_to_str(response) + plain_text = ToolEngine._convert_tool_response_to_str(messages) + + meta = invocation_meta_dict['meta'] # hit the callback handler agent_tool_callback.on_tool_end( @@ -127,7 +140,7 @@ class ToolEngine: user_id: str, workflow_id: str, workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_call_depth: int, - ) -> list[ToolInvokeMessage]: + ) -> Generator[ToolInvokeMessage, None, None]: """ Workflow invokes the tool with the given arguments. """ @@ -154,10 +167,38 @@ class ToolEngine: except Exception as e: workflow_tool_callback.on_tool_error(e) raise e - + + @staticmethod + def plugin_invoke(tool: Tool, tool_parameters: dict, user_id: str, + callback: DifyPluginCallbackHandler + ) -> Generator[ToolInvokeMessage, None, None]: + """ + Plugin invokes the tool with the given arguments. + """ + try: + # hit the callback handler + callback.on_tool_start( + tool_name=tool.identity.name, + tool_inputs=tool_parameters + ) + + response = tool.invoke(user_id, tool_parameters) + + # hit the callback handler + callback.on_tool_end( + tool_name=tool.identity.name, + tool_inputs=tool_parameters, + tool_outputs=response, + ) + + return response + except Exception as e: + callback.on_tool_error(e) + raise e + @staticmethod def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \ - -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: + -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: """ Invoke the tool with the given arguments. """ @@ -170,16 +211,15 @@ class ToolEngine: 'tool_icon': tool.identity.icon }) try: - response = tool.invoke(user_id, tool_parameters) + yield from tool.invoke(user_id, tool_parameters) except Exception as e: meta.error = str(e) raise ToolEngineInvokeError(meta) finally: ended_at = datetime.now(timezone.utc) meta.time_cost = (ended_at - started_at).total_seconds() + yield meta - return meta, response - @staticmethod def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: """ diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index e30a905cbc..5822841db7 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -18,6 +18,7 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolInvokeFrom, ToolParameter, + ToolProviderType, ) from core.tools.errors import ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiToolProviderController @@ -26,6 +27,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool +from core.tools.tool.workflow_tool import WorkflowTool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ( ToolConfigurationManager, @@ -78,37 +80,13 @@ class ToolManager: return tool @classmethod - def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \ - -> Union[BuiltinTool, ApiTool]: - """ - get the tool - - :param provider_type: the type of the provider - :param provider_name: the name of the provider - :param tool_name: the name of the tool - - :return: the tool - """ - if provider_type == 'builtin': - return cls.get_builtin_tool(provider_id, tool_name) - elif provider_type == 'api': - if tenant_id is None: - raise ValueError('tenant id is required for api provider') - api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id) - return api_provider.get_tool(tool_name) - elif provider_type == 'app': - raise NotImplementedError('app provider not implemented') - else: - raise ToolProviderNotFoundError(f'provider type {provider_type} not found') - - @classmethod - def get_tool_runtime(cls, provider_type: str, + def get_tool_runtime(cls, provider_type: ToolProviderType, provider_id: str, tool_name: str, tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool]: + -> Union[BuiltinTool, ApiTool, WorkflowTool]: """ get the tool runtime @@ -118,7 +96,7 @@ class ToolManager: :return: the tool """ - if provider_type == 'builtin': + if provider_type == ToolProviderType.BUILT_IN: builtin_tool = cls.get_builtin_tool(provider_id, tool_name) # check if the builtin tool need credentials @@ -155,7 +133,7 @@ class ToolManager: 'tool_invoke_from': tool_invoke_from, }) - elif provider_type == 'api': + elif provider_type == ToolProviderType.API: if tenant_id is None: raise ValueError('tenant id is required for api provider') @@ -171,7 +149,7 @@ class ToolManager: 'invoke_from': invoke_from, 'tool_invoke_from': tool_invoke_from, }) - elif provider_type == 'workflow': + elif provider_type == ToolProviderType.WORKFLOW: workflow_provider = db.session.query(WorkflowToolProvider).filter( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id @@ -190,10 +168,10 @@ class ToolManager: 'invoke_from': invoke_from, 'tool_invoke_from': tool_invoke_from, }) - elif provider_type == 'app': + elif provider_type == ToolProviderType.APP: raise NotImplementedError('app provider not implemented') else: - raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + raise ToolProviderNotFoundError(f'provider type {provider_type.value} not found') @classmethod def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: @@ -554,7 +532,7 @@ class ToolManager: }) @classmethod - def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: + def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]: """ get the tool icon @@ -563,14 +541,12 @@ class ToolManager: :param provider_id: the id of the provider :return: """ - provider_type = provider_type - provider_id = provider_id - if provider_type == 'builtin': + if provider_type == ToolProviderType.BUILT_IN: return (current_app.config.get("CONSOLE_API_URL") + "/console/api/workspaces/current/tool-provider/builtin/" + provider_id + "/icon") - elif provider_type == 'api': + elif provider_type == ToolProviderType.API: try: provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( ApiToolProvider.tenant_id == tenant_id, @@ -582,7 +558,7 @@ class ToolManager: "background": "#252525", "content": "\ud83d\ude01" } - elif provider_type == 'workflow': + elif provider_type == ToolProviderType.WORKFLOW: provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index b213879e96..68b0cea24f 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -9,6 +9,7 @@ from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolPr from core.tools.entities.tool_entities import ( ToolParameter, ToolProviderCredentials, + ToolProviderType, ) from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.tool import Tool @@ -108,7 +109,7 @@ class ToolParameterConfigurationManager(BaseModel): tenant_id: str tool_runtime: Tool provider_name: str - provider_type: str + provider_type: ToolProviderType identity_id: str def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: @@ -191,7 +192,7 @@ class ToolParameterConfigurationManager(BaseModel): """ cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type}.{self.provider_name}', + provider=f'{self.provider_type.value}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, identity_id=self.identity_id @@ -221,7 +222,7 @@ class ToolParameterConfigurationManager(BaseModel): def delete_tool_parameters_cache(self): cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type}.{self.provider_name}', + provider=f'{self.provider_type.value}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, identity_id=self.identity_id diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index ef9e5b67ae..770abc683c 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,4 +1,5 @@ import logging +from collections.abc import Generator from mimetypes import guess_extension from core.file.file_obj import FileTransferMethod, FileType, FileVar @@ -9,20 +10,18 @@ logger = logging.getLogger(__name__) class ToolFileMessageTransformer: @classmethod - def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], + def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None], user_id: str, tenant_id: str, - conversation_id: str) -> list[ToolInvokeMessage]: + conversation_id: str) -> Generator[ToolInvokeMessage, None, None]: """ Transform tool message and handle file download """ - result = [] - for message in messages: if message.type == ToolInvokeMessage.MessageType.TEXT: - result.append(message) + yield message elif message.type == ToolInvokeMessage.MessageType.LINK: - result.append(message) + yield message elif message.type == ToolInvokeMessage.MessageType.IMAGE: # try to download image try: @@ -35,20 +34,20 @@ class ToolFileMessageTransformer: url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' - result.append(ToolInvokeMessage( + yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=url, save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, - )) + ) except Exception as e: logger.exception(e) - result.append(ToolInvokeMessage( + yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.TEXT, message=f"Failed to download image: {message.message}, you can try to download it yourself.", meta=message.meta.copy() if message.meta is not None else {}, save_as=message.save_as, - )) + ) elif message.type == ToolInvokeMessage.MessageType.BLOB: # get mime type and save blob to storage mimetype = message.meta.get('mime_type', 'octet/stream') @@ -67,43 +66,41 @@ class ToolFileMessageTransformer: # check if file is image if 'image' in mimetype: - result.append(ToolInvokeMessage( + yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=url, save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, - )) + ) else: - result.append(ToolInvokeMessage( + yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=url, save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, - )) + ) elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: file_var: FileVar = message.meta.get('file_var') if file_var: if file_var.transfer_method == FileTransferMethod.TOOL_FILE: url = cls.get_tool_file_url(file_var.related_id, file_var.extension) if file_var.type == FileType.IMAGE: - result.append(ToolInvokeMessage( + yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=url, save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, - )) + ) else: - result.append(ToolInvokeMessage( + yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=url, save_as=message.save_as, meta=message.meta.copy() if message.meta is not None else {}, - )) + ) else: - result.append(message) - - return result + yield message @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: - return f'/files/tools/{tool_file_id}{extension or ".bin"}' \ No newline at end of file + return f'/files/tools/{tool_file_id}{extension or ".bin"}' diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 2e4743c483..bc9bfde4db 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -3,12 +3,13 @@ from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo +from core.tools.entities.tool_entities import ToolProviderType from core.workflow.entities.base_node_data_entities import BaseNodeData class ToolEntity(BaseModel): provider_id: str - provider_type: Literal['builtin', 'api', 'workflow'] + provider_type: ToolProviderType provider_name: str # redundancy tool_name: str tool_label: str # redundancy diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index cddea03bf8..f77ccd9bd6 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -32,7 +32,7 @@ class ToolNode(BaseNode): # fetch tool icon tool_info = { - 'provider_type': node_data.provider_type, + 'provider_type': node_data.provider_type.value, 'provider_id': node_data.provider_id } diff --git a/api/services/plugin/plugin_invoke_service.py b/api/services/plugin/plugin_invoke_service.py index 131de1ec1a..317b42a6e1 100644 --- a/api/services/plugin/plugin_invoke_service.py +++ b/api/services/plugin/plugin_invoke_service.py @@ -1,16 +1,49 @@ from collections.abc import Generator -from typing import Any +from typing import Any, Union -from core.tools.entities.tool_entities import ToolInvokeMessage +from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler +from core.model_runtime.entities.model_entities import ModelType +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType +from core.tools.tool_engine import ToolEngine +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.entities.node_entities import NodeType from models.account import Tenant +from services.tools.tools_transform_service import ToolTransformService class PluginInvokeService: @classmethod - def invoke_tool(cls, user_id: str, tenant: Tenant, - tool_provider: str, tool_name: str, + def invoke_tool(cls, user_id: str, invoke_from: InvokeFrom, tenant: Tenant, + tool_provider_type: ToolProviderType, tool_provider: str, tool_name: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]: """ Invokes a tool with the given user ID and tool parameters. """ - \ No newline at end of file + tool_runtime = ToolManager.get_tool_runtime(tool_provider_type, provider_id=tool_provider, + tool_name=tool_name, tenant_id=tenant.id, + invoke_from=invoke_from) + + response = ToolEngine.plugin_invoke(tool_runtime, + tool_parameters, + user_id, + callback=DifyPluginCallbackHandler()) + response = ToolFileMessageTransformer.transform_tool_invoke_messages(response) + return ToolTransformService.transform_messages_to_dict(response) + + @classmethod + def invoke_model(cls, user_id: str, tenant: Tenant, + model_provider: str, model_name: str, model_type: ModelType, + model_parameters: dict[str, Any]) -> Union[dict, Generator[ToolInvokeMessage]]: + """ + Invokes a model with the given user ID and model parameters. + """ + + @classmethod + def invoke_workflow_node(cls, user_id: str, tenant: Tenant, + node_type: NodeType, node_data: dict[str, Any], + inputs: dict[str, Any]) -> Generator[ToolInvokeMessage]: + """ + Invokes a workflow node with the given user ID and node parameters. + """ \ No newline at end of file diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 5c77732468..08023a4a92 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,5 +1,6 @@ import json import logging +from collections.abc import Generator from typing import Optional, Union from flask import current_app @@ -9,6 +10,7 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolInvokeMessage, ToolParameter, ToolProviderCredentials, ToolProviderType, @@ -24,8 +26,8 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi logger = logging.getLogger(__name__) class ToolTransformService: - @staticmethod - def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]: + @classmethod + def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str) -> Union[str, dict]: """ get tool provider icon url """ @@ -45,8 +47,8 @@ class ToolTransformService: return '' - @staticmethod - def repack_provider(provider: Union[dict, UserToolProvider]): + @classmethod + def repack_provider(cls, provider: Union[dict, UserToolProvider]): """ repack provider @@ -65,8 +67,9 @@ class ToolTransformService: icon=provider.icon ) - @staticmethod + @classmethod def builtin_provider_to_user_provider( + cls, provider_controller: BuiltinToolProviderController, db_provider: Optional[BuiltinToolProvider], decrypt_credentials: bool = True, @@ -126,8 +129,9 @@ class ToolTransformService: return result - @staticmethod + @classmethod def api_provider_to_controller( + cls, db_provider: ApiToolProvider, ) -> ApiToolProviderController: """ @@ -142,8 +146,9 @@ class ToolTransformService: return controller - @staticmethod + @classmethod def workflow_provider_to_controller( + cls, db_provider: WorkflowToolProvider ) -> WorkflowToolProviderController: """ @@ -179,8 +184,9 @@ class ToolTransformService: labels=labels or [] ) - @staticmethod + @classmethod def api_provider_to_user_provider( + cls, provider_controller: ApiToolProviderController, db_provider: ApiToolProvider, decrypt_credentials: bool = True, @@ -231,8 +237,9 @@ class ToolTransformService: return result - @staticmethod + @classmethod def tool_to_user_tool( + cls, tool: Union[ApiToolBundle, WorkflowTool, Tool], credentials: dict = None, tenant_id: str = None, @@ -287,4 +294,9 @@ class ToolTransformService: ), parameters=tool.parameters, labels=labels - ) \ No newline at end of file + ) + + @classmethod + def transform_messages_to_dict(cls, responses: Generator[ToolInvokeMessage, None, None]): + for response in responses: + yield response.model_dump() \ No newline at end of file