refactor: tool response to generator
This commit is contained in:
parent
364df36ac4
commit
563d81277b
@ -23,6 +23,8 @@ class PluginInvokeModelApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
|
||||
|
||||
class PluginInvokeToolApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
|
@ -1,14 +1,16 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class AgentToolEntity(BaseModel):
|
||||
"""
|
||||
Agent Tool Entity.
|
||||
"""
|
||||
provider_type: Literal["builtin", "api", "workflow"]
|
||||
provider_type: ToolProviderType
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
|
@ -0,0 +1,5 @@
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
|
||||
|
||||
class DifyPluginCallbackHandler(DifyAgentCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from os import getenv
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
@ -269,7 +270,7 @@ class ApiTool(Tool):
|
||||
except ValueError as e:
|
||||
return value
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke http request
|
||||
"""
|
||||
@ -283,4 +284,4 @@ class ApiTool(Tool):
|
||||
response = self.validate_and_parse_response(response)
|
||||
|
||||
# assemble invoke message
|
||||
return self.create_text_message(response)
|
||||
yield self.create_text_message(response)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
@ -86,7 +87,7 @@ class DatasetRetrieverTool(Tool):
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.DATASET_RETRIEVAL
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke dataset retriever tool
|
||||
"""
|
||||
@ -97,7 +98,7 @@ class DatasetRetrieverTool(Tool):
|
||||
# invoke dataset retriever tool
|
||||
result = self.retrival_tool._run(query=query)
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
yield self.create_text_message(text=result)
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
"""
|
||||
|
@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
@ -190,7 +191,7 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
return result
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
|
||||
# update tool_parameters
|
||||
if self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
@ -203,9 +204,6 @@ class Tool(BaseModel, ABC):
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
|
||||
return result
|
||||
|
||||
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
@ -221,7 +219,7 @@ class Tool(BaseModel, ABC):
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
|
||||
pass
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
@ -34,7 +35,7 @@ class WorkflowTool(Tool):
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke the tool
|
||||
"""
|
||||
@ -46,6 +47,7 @@ class WorkflowTool(Tool):
|
||||
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
@ -64,16 +66,12 @@ class WorkflowTool(Tool):
|
||||
if data.get('error'):
|
||||
raise Exception(data.get('error'))
|
||||
|
||||
result = []
|
||||
|
||||
outputs = data.get('outputs', {})
|
||||
outputs, files = self._extract_files(outputs)
|
||||
for file in files:
|
||||
result.append(self.create_file_var_message(file))
|
||||
yield self.create_file_var_message(file)
|
||||
|
||||
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
||||
|
||||
return result
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||
|
||||
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
|
||||
"""
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from mimetypes import guess_type
|
||||
@ -8,6 +9,7 @@ from yarl import URL
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
@ -64,16 +66,25 @@ class ToolEngine:
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
|
||||
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=response,
|
||||
messages = ToolEngine._invoke(tool, tool_parameters, user_id)
|
||||
invocation_meta_dict = {'meta': None}
|
||||
|
||||
def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]):
|
||||
for message in messages:
|
||||
if isinstance(message, ToolInvokeMeta):
|
||||
invocation_meta_dict['meta'] = message
|
||||
else:
|
||||
yield message
|
||||
|
||||
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=message_callback(invocation_meta_dict, messages),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=message.conversation_id
|
||||
)
|
||||
|
||||
# extract binary data from tool invoke message
|
||||
binary_files = ToolEngine._extract_tool_response_binary(response)
|
||||
binary_files = ToolEngine._extract_tool_response_binary(messages)
|
||||
# create message file
|
||||
message_files = ToolEngine._create_message_files(
|
||||
tool_messages=binary_files,
|
||||
@ -82,7 +93,9 @@ class ToolEngine:
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
plain_text = ToolEngine._convert_tool_response_to_str(response)
|
||||
plain_text = ToolEngine._convert_tool_response_to_str(messages)
|
||||
|
||||
meta = invocation_meta_dict['meta']
|
||||
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_end(
|
||||
@ -127,7 +140,7 @@ class ToolEngine:
|
||||
user_id: str, workflow_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
) -> list[ToolInvokeMessage]:
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
"""
|
||||
@ -155,9 +168,37 @@ class ToolEngine:
|
||||
workflow_tool_callback.on_tool_error(e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def plugin_invoke(tool: Tool, tool_parameters: dict, user_id: str,
|
||||
callback: DifyPluginCallbackHandler
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Plugin invokes the tool with the given arguments.
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
callback.on_tool_start(
|
||||
tool_name=tool.identity.name,
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
|
||||
# hit the callback handler
|
||||
callback.on_tool_end(
|
||||
tool_name=tool.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=response,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
callback.on_tool_error(e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
|
||||
-> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]:
|
||||
-> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
|
||||
"""
|
||||
Invoke the tool with the given arguments.
|
||||
"""
|
||||
@ -170,15 +211,14 @@ class ToolEngine:
|
||||
'tool_icon': tool.identity.icon
|
||||
})
|
||||
try:
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
yield from tool.invoke(user_id, tool_parameters)
|
||||
except Exception as e:
|
||||
meta.error = str(e)
|
||||
raise ToolEngineInvokeError(meta)
|
||||
finally:
|
||||
ended_at = datetime.now(timezone.utc)
|
||||
meta.time_cost = (ended_at - started_at).total_seconds()
|
||||
|
||||
return meta, response
|
||||
yield meta
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
|
||||
|
@ -18,6 +18,7 @@ from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
@ -26,6 +27,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ToolConfigurationManager,
|
||||
@ -78,37 +80,13 @@ class ToolManager:
|
||||
return tool
|
||||
|
||||
@classmethod
|
||||
def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool
|
||||
|
||||
:param provider_type: the type of the provider
|
||||
:param provider_name: the name of the provider
|
||||
:param tool_name: the name of the tool
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
if provider_type == 'builtin':
|
||||
return cls.get_builtin_tool(provider_id, tool_name)
|
||||
elif provider_type == 'api':
|
||||
if tenant_id is None:
|
||||
raise ValueError('tenant id is required for api provider')
|
||||
api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
return api_provider.get_tool(tool_name)
|
||||
elif provider_type == 'app':
|
||||
raise NotImplementedError('app provider not implemented')
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(cls, provider_type: str,
|
||||
def get_tool_runtime(cls, provider_type: ToolProviderType,
|
||||
provider_id: str,
|
||||
tool_name: str,
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
-> Union[BuiltinTool, ApiTool, WorkflowTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
@ -118,7 +96,7 @@ class ToolManager:
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
if provider_type == 'builtin':
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
|
||||
|
||||
# check if the builtin tool need credentials
|
||||
@ -155,7 +133,7 @@ class ToolManager:
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
|
||||
elif provider_type == 'api':
|
||||
elif provider_type == ToolProviderType.API:
|
||||
if tenant_id is None:
|
||||
raise ValueError('tenant id is required for api provider')
|
||||
|
||||
@ -171,7 +149,7 @@ class ToolManager:
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
elif provider_type == 'workflow':
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
workflow_provider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == provider_id
|
||||
@ -190,10 +168,10 @@ class ToolManager:
|
||||
'invoke_from': invoke_from,
|
||||
'tool_invoke_from': tool_invoke_from,
|
||||
})
|
||||
elif provider_type == 'app':
|
||||
elif provider_type == ToolProviderType.APP:
|
||||
raise NotImplementedError('app provider not implemented')
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type.value} not found')
|
||||
|
||||
@classmethod
|
||||
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
|
||||
@ -554,7 +532,7 @@ class ToolManager:
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
|
||||
def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
|
||||
"""
|
||||
get the tool icon
|
||||
|
||||
@ -563,14 +541,12 @@ class ToolManager:
|
||||
:param provider_id: the id of the provider
|
||||
:return:
|
||||
"""
|
||||
provider_type = provider_type
|
||||
provider_id = provider_id
|
||||
if provider_type == 'builtin':
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
return (current_app.config.get("CONSOLE_API_URL")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
+ provider_id
|
||||
+ "/icon")
|
||||
elif provider_type == 'api':
|
||||
elif provider_type == ToolProviderType.API:
|
||||
try:
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
@ -582,7 +558,7 @@ class ToolManager:
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
elif provider_type == 'workflow':
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == provider_id
|
||||
|
@ -9,6 +9,7 @@ from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolPr
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
@ -108,7 +109,7 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
tenant_id: str
|
||||
tool_runtime: Tool
|
||||
provider_name: str
|
||||
provider_type: str
|
||||
provider_type: ToolProviderType
|
||||
identity_id: str
|
||||
|
||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
@ -191,7 +192,7 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
"""
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type}.{self.provider_name}',
|
||||
provider=f'{self.provider_type.value}.{self.provider_name}',
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id
|
||||
@ -221,7 +222,7 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type}.{self.provider_name}',
|
||||
provider=f'{self.provider_type.value}.{self.provider_name}',
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id
|
||||
|
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
@ -9,20 +10,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
|
||||
def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str) -> list[ToolInvokeMessage]:
|
||||
conversation_id: str) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Transform tool message and handle file download
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in messages:
|
||||
if message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result.append(message)
|
||||
yield message
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result.append(message)
|
||||
yield message
|
||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# try to download image
|
||||
try:
|
||||
@ -35,20 +34,20 @@ class ToolFileMessageTransformer:
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
||||
|
||||
result.append(ToolInvokeMessage(
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
result.append(ToolInvokeMessage(
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
save_as=message.save_as,
|
||||
))
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
mimetype = message.meta.get('mime_type', 'octet/stream')
|
||||
@ -67,42 +66,40 @@ class ToolFileMessageTransformer:
|
||||
|
||||
# check if file is image
|
||||
if 'image' in mimetype:
|
||||
result.append(ToolInvokeMessage(
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
)
|
||||
else:
|
||||
result.append(ToolInvokeMessage(
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
||||
file_var: FileVar = message.meta.get('file_var')
|
||||
if file_var:
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
||||
if file_var.type == FileType.IMAGE:
|
||||
result.append(ToolInvokeMessage(
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
)
|
||||
else:
|
||||
result.append(ToolInvokeMessage(
|
||||
yield ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
)
|
||||
else:
|
||||
result.append(message)
|
||||
|
||||
return result
|
||||
yield message
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
||||
|
@ -3,12 +3,13 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: Literal['builtin', 'api', 'workflow']
|
||||
provider_type: ToolProviderType
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
|
@ -32,7 +32,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
'provider_type': node_data.provider_type,
|
||||
'provider_type': node_data.provider_type.value,
|
||||
'provider_id': node_data.provider_id
|
||||
}
|
||||
|
||||
|
@ -1,16 +1,49 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.account import Tenant
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class PluginInvokeService:
|
||||
@classmethod
|
||||
def invoke_tool(cls, user_id: str, tenant: Tenant,
|
||||
tool_provider: str, tool_name: str,
|
||||
def invoke_tool(cls, user_id: str, invoke_from: InvokeFrom, tenant: Tenant,
|
||||
tool_provider_type: ToolProviderType, tool_provider: str, tool_name: str,
|
||||
tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
|
||||
"""
|
||||
Invokes a tool with the given user ID and tool parameters.
|
||||
"""
|
||||
tool_runtime = ToolManager.get_tool_runtime(tool_provider_type, provider_id=tool_provider,
|
||||
tool_name=tool_name, tenant_id=tenant.id,
|
||||
invoke_from=invoke_from)
|
||||
|
||||
response = ToolEngine.plugin_invoke(tool_runtime,
|
||||
tool_parameters,
|
||||
user_id,
|
||||
callback=DifyPluginCallbackHandler())
|
||||
response = ToolFileMessageTransformer.transform_tool_invoke_messages(response)
|
||||
return ToolTransformService.transform_messages_to_dict(response)
|
||||
|
||||
@classmethod
|
||||
def invoke_model(cls, user_id: str, tenant: Tenant,
|
||||
model_provider: str, model_name: str, model_type: ModelType,
|
||||
model_parameters: dict[str, Any]) -> Union[dict, Generator[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invokes a model with the given user ID and model parameters.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def invoke_workflow_node(cls, user_id: str, tenant: Tenant,
|
||||
node_type: NodeType, node_data: dict[str, Any],
|
||||
inputs: dict[str, Any]) -> Generator[ToolInvokeMessage]:
|
||||
"""
|
||||
Invokes a workflow node with the given user ID and node parameters.
|
||||
"""
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from flask import current_app
|
||||
@ -9,6 +10,7 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderType,
|
||||
@ -24,8 +26,8 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolTransformService:
|
||||
@staticmethod
|
||||
def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
|
||||
@classmethod
|
||||
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
|
||||
"""
|
||||
get tool provider icon url
|
||||
"""
|
||||
@ -45,8 +47,8 @@ class ToolTransformService:
|
||||
|
||||
return ''
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(provider: Union[dict, UserToolProvider]):
|
||||
@classmethod
|
||||
def repack_provider(cls, provider: Union[dict, UserToolProvider]):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
@ -65,8 +67,9 @@ class ToolTransformService:
|
||||
icon=provider.icon
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def builtin_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: BuiltinToolProviderController,
|
||||
db_provider: Optional[BuiltinToolProvider],
|
||||
decrypt_credentials: bool = True,
|
||||
@ -126,8 +129,9 @@ class ToolTransformService:
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def api_provider_to_controller(
|
||||
cls,
|
||||
db_provider: ApiToolProvider,
|
||||
) -> ApiToolProviderController:
|
||||
"""
|
||||
@ -142,8 +146,9 @@ class ToolTransformService:
|
||||
|
||||
return controller
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def workflow_provider_to_controller(
|
||||
cls,
|
||||
db_provider: WorkflowToolProvider
|
||||
) -> WorkflowToolProviderController:
|
||||
"""
|
||||
@ -179,8 +184,9 @@ class ToolTransformService:
|
||||
labels=labels or []
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def api_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: ApiToolProviderController,
|
||||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True,
|
||||
@ -231,8 +237,9 @@ class ToolTransformService:
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def tool_to_user_tool(
|
||||
cls,
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
credentials: dict = None,
|
||||
tenant_id: str = None,
|
||||
@ -288,3 +295,8 @@ class ToolTransformService:
|
||||
parameters=tool.parameters,
|
||||
labels=labels
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def transform_messages_to_dict(cls, responses: Generator[ToolInvokeMessage, None, None]):
|
||||
for response in responses:
|
||||
yield response.model_dump()
|
Loading…
Reference in New Issue
Block a user