refactor: tool models
This commit is contained in:
parent
1fa3b9cfd8
commit
cf4e9f317e
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Any, Optional, TextIO, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -55,7 +55,7 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Mapping[str, Any],
|
||||
tool_outputs: Sequence[ToolInvokeMessage],
|
||||
tool_outputs: Iterable[ToolInvokeMessage] | str,
|
||||
message_id: Optional[str] = None,
|
||||
timer: Optional[Any] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
|
@ -1,9 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator, Iterable
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from mimetypes import guess_type
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
@ -40,7 +40,7 @@ class ToolEngine:
|
||||
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
|
||||
agent_tool_callback: DifyAgentCallbackHandler,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
|
||||
) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
|
||||
"""
|
||||
Agent invokes the tool with the given arguments.
|
||||
"""
|
||||
@ -67,9 +67,9 @@ class ToolEngine:
|
||||
)
|
||||
|
||||
messages = ToolEngine._invoke(tool, tool_parameters, user_id)
|
||||
invocation_meta_dict = {'meta': None}
|
||||
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
|
||||
|
||||
def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]):
|
||||
def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]):
|
||||
for message in messages:
|
||||
if isinstance(message, ToolInvokeMeta):
|
||||
invocation_meta_dict['meta'] = message
|
||||
@ -136,7 +136,7 @@ class ToolEngine:
|
||||
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
||||
|
||||
@staticmethod
|
||||
def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any],
|
||||
def workflow_invoke(tool: Tool, tool_parameters: dict[str, Any],
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
@ -156,6 +156,7 @@ class ToolEngine:
|
||||
|
||||
if tool.runtime and tool.runtime.runtime_parameters:
|
||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
||||
|
||||
response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters)
|
||||
|
||||
# hit the callback handler
|
||||
@ -204,6 +205,9 @@ class ToolEngine:
|
||||
"""
|
||||
Invoke the tool with the given arguments.
|
||||
"""
|
||||
if not tool.runtime:
|
||||
raise ValueError("missing runtime in tool")
|
||||
|
||||
started_at = datetime.now(timezone.utc)
|
||||
meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={
|
||||
'tool_name': tool.identity.name,
|
||||
@ -223,42 +227,42 @@ class ToolEngine:
|
||||
yield meta
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
|
||||
def _convert_tool_response_to_str(tool_response: Generator[ToolInvokeMessage, None, None]) -> str:
|
||||
"""
|
||||
Handle tool response
|
||||
"""
|
||||
result = ''
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += response.message
|
||||
result += cast(ToolInvokeMessage.TextMessage, response.message).text
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += f"result link: {response.message}. please tell user to check it."
|
||||
result += f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}. please tell user to check it."
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}."
|
||||
result += f"tool response: {json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)}."
|
||||
else:
|
||||
result += f"tool response: {response.message}."
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
|
||||
def _extract_tool_response_binary(tool_response: Generator[ToolInvokeMessage, None, None]) -> Generator[ToolInvokeMessageBinary, None, None]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
mimetype = None
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
if response.meta.get('mime_type'):
|
||||
mimetype = response.meta.get('mime_type')
|
||||
else:
|
||||
try:
|
||||
url = URL(response.message)
|
||||
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
|
||||
extension = url.suffix
|
||||
guess_type_result, _ = guess_type(f'a{extension}')
|
||||
if guess_type_result:
|
||||
@ -269,35 +273,36 @@ class ToolEngine:
|
||||
if not mimetype:
|
||||
mimetype = 'image/jpeg'
|
||||
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'image/jpeg'),
|
||||
url=response.message,
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
url=response.message,
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# check if there is a mime type in meta
|
||||
if response.meta and 'mime_type' in response.meta:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
|
||||
url=response.message,
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
|
||||
return result
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_message_files(
|
||||
tool_messages: list[ToolInvokeMessageBinary],
|
||||
tool_messages: Iterable[ToolInvokeMessageBinary],
|
||||
agent_message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
user_id: str
|
||||
) -> list[tuple[Any, str]]:
|
||||
) -> list[tuple[MessageFile, str]]:
|
||||
"""
|
||||
Create message file
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from collections.abc import Generator, Sequence
|
||||
from os import path
|
||||
from typing import Any, cast
|
||||
|
||||
@ -100,7 +100,7 @@ class ToolNode(BaseNode):
|
||||
variable_pool: VariablePool,
|
||||
node_data: ToolNodeData,
|
||||
for_log: bool = False,
|
||||
) -> Mapping[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
@ -110,7 +110,7 @@ class ToolNode(BaseNode):
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
dict[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
||||
|
5
api/models/base.py
Normal file
5
api/models/base.py
Normal file
@ -0,0 +1,5 @@
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
@ -14,6 +14,7 @@ from core.file.tool_file_parser import ToolFileParser
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import generate_string
|
||||
from models.base import Base
|
||||
|
||||
from .account import Account, Tenant
|
||||
from .types import StringUUID
|
||||
@ -211,7 +212,7 @@ class App(db.Model):
|
||||
return tags if tags else []
|
||||
|
||||
|
||||
class AppModelConfig(db.Model):
|
||||
class AppModelConfig(Base):
|
||||
__tablename__ = 'app_model_configs'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='app_model_config_pkey'),
|
||||
@ -550,6 +551,9 @@ class Conversation(db.Model):
|
||||
else:
|
||||
app_model_config = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == self.app_model_config_id).first()
|
||||
|
||||
if not app_model_config:
|
||||
raise ValueError("app config not found")
|
||||
|
||||
model_config = app_model_config.to_dict()
|
||||
|
||||
@ -640,7 +644,7 @@ class Conversation(db.Model):
|
||||
return self.override_model_configs is not None
|
||||
|
||||
|
||||
class Message(db.Model):
|
||||
class Message(Base):
|
||||
__tablename__ = 'messages'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='message_pkey'),
|
||||
@ -932,7 +936,7 @@ class MessageFeedback(db.Model):
|
||||
return account
|
||||
|
||||
|
||||
class MessageFile(db.Model):
|
||||
class MessageFile(Base):
|
||||
__tablename__ = 'message_files'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='message_file_pkey'),
|
||||
@ -940,15 +944,15 @@ class MessageFile(db.Model):
|
||||
db.Index('message_file_created_by_idx', 'created_by')
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
message_id = db.Column(StringUUID, nullable=False)
|
||||
type = db.Column(db.String(255), nullable=False)
|
||||
transfer_method = db.Column(db.String(255), nullable=False)
|
||||
url = db.Column(db.Text, nullable=True)
|
||||
belongs_to = db.Column(db.String(255), nullable=True)
|
||||
upload_file_id = db.Column(StringUUID, nullable=True)
|
||||
created_by_role = db.Column(db.String(255), nullable=False)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
url: Mapped[str] = mapped_column(db.Text, nullable=True)
|
||||
belongs_to: Mapped[str] = mapped_column(db.String(255), nullable=True)
|
||||
upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
|
||||
|
@ -1,12 +1,13 @@
|
||||
import json
|
||||
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
from extensions.ext_database import db
|
||||
from models.base import Base
|
||||
|
||||
from .model import Account, App, Tenant
|
||||
from .types import StringUUID
|
||||
@ -277,9 +278,6 @@ class ToolConversationVariables(db.Model):
|
||||
@property
|
||||
def variables(self) -> dict:
|
||||
return json.loads(self.variables_str)
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class ToolFile(Base):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user