diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index f785b2aed6..1e7b3f1526 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,5 +1,3 @@ -import time - from flask_restful import Resource from controllers.console.setup import setup_required @@ -10,6 +8,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation +from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation from core.plugin.encrypt import PluginEncrypter from core.plugin.entities.request import ( RequestInvokeApp, @@ -24,7 +23,7 @@ from core.plugin.entities.request import ( RequestInvokeTool, RequestInvokeTTS, ) -from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.entities.tool_entities import ToolProviderType from libs.helper import compact_generate_response from models.account import Tenant @@ -138,17 +137,16 @@ class PluginInvokeToolApi(Resource): @plugin_data(payload_type=RequestInvokeTool) def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool): def generator(): - for i in range(10): - time.sleep(0.1) - yield ( - ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=ToolInvokeMessage.TextMessage(text="helloworld"), - ) - .model_dump_json() - .encode() - + b"\n\n" - ) + return PluginToolBackwardsInvocation.convert_to_event_stream( + PluginToolBackwardsInvocation.invoke_tool( + tenant_id=tenant_model.id, + user_id=user_id, + tool_type=ToolProviderType.value_of(payload.tool_type), + provider=payload.provider, + tool_name=payload.tool, + tool_parameters=payload.tool_parameters, + ), + ) return compact_generate_response(generator()) diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py new file mode 100644 index 0000000000..1d62743f13 --- /dev/null +++ b/api/core/plugin/backwards_invocation/tool.py @@ -0,0 +1,45 @@ +from collections.abc import Generator +from typing import Any + +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType +from core.tools.tool_engine import ToolEngine +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer + + +class PluginToolBackwardsInvocation(BaseBackwardsInvocation): + """ + Backwards invocation for plugin tools. + """ + + @classmethod + def invoke_tool( + cls, + tenant_id: str, + user_id: str, + tool_type: ToolProviderType, + provider: str, + tool_name: str, + tool_parameters: dict[str, Any], + ) -> Generator[ToolInvokeMessage, None, None]: + """ + invoke tool + """ + # get tool runtime + try: + tool_runtime = ToolManager.get_tool_runtime_from_plugin( + tool_type, tenant_id, provider, tool_name, tool_parameters + ) + response = ToolEngine.generic_invoke( + tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 + ) + + response = ToolFileMessageTransformer.transform_tool_invoke_messages( + response, user_id=user_id, tenant_id=tenant_id + ) + + return response + except Exception as e: + raise e diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 19bf329674..ae94bc95f6 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -32,6 +32,11 @@ class RequestInvokeTool(BaseModel): Request to invoke a tool """ + tool_type: Literal["builtin", "workflow", "api"] + provider: str + tool: str + tool_parameters: dict + class BaseRequestInvokeModel(BaseModel): provider: str diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index e6ef1df79f..b98ee28fb4 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -378,6 +378,7 @@ class ToolInvokeFrom(Enum): WORKFLOW = "workflow" AGENT = "agent" + PLUGIN = "plugin" class ToolProviderID: diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 5396acc285..17be66035d 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -131,7 +131,7 @@ class ToolEngine: return error_response, [], ToolInvokeMeta.error_instance(error_response) @staticmethod - def workflow_invoke( + def generic_invoke( tool: Tool, tool_parameters: dict[str, Any], user_id: str, diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 54ae3a4117..4d8ee1399b 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -365,6 +365,40 @@ class ToolManager: tool_runtime.runtime.runtime_parameters.update(runtime_parameters) return tool_runtime + @classmethod + def get_tool_runtime_from_plugin( + cls, + tool_type: ToolProviderType, + tenant_id: str, + provider: str, + tool_name: str, + tool_parameters: dict[str, Any], + ) -> Tool: + """ + get tool runtime from plugin + """ + tool_entity = cls.get_tool_runtime( + provider_type=tool_type, + provider_id=provider, + tool_name=tool_name, + tenant_id=tenant_id, + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.PLUGIN, + ) + runtime_parameters = {} + parameters = tool_entity.get_merged_runtime_parameters() + for parameter in parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM: + # save tool parameter to tool entity memory + value = cls._init_runtime_parameter(parameter, tool_parameters) + runtime_parameters[parameter.name] = value + + if not tool_entity.runtime: + raise Exception("tool missing runtime") + + tool_entity.runtime.runtime_parameters.update(runtime_parameters) + return tool_entity + @classmethod def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]: """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c52052c3a0..c65826dc3c 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -66,7 +66,7 @@ class ToolNode(BaseNode): ) try: - message_stream = ToolEngine.workflow_invoke( + message_stream = ToolEngine.generic_invoke( tool=tool_runtime, tool_parameters=parameters, user_id=self.user_id,