refactor: tool models

This commit is contained in:
Yeuoly 2024-08-30 15:55:10 +08:00
parent 1fa3b9cfd8
commit cf4e9f317e
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
6 changed files with 60 additions and 48 deletions

View File

@ -1,5 +1,5 @@
import os
from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Mapping
from typing import Any, Optional, TextIO, Union
from pydantic import BaseModel
@ -55,7 +55,7 @@ class DifyAgentCallbackHandler(BaseModel):
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
tool_outputs: Sequence[ToolInvokeMessage],
tool_outputs: Iterable[ToolInvokeMessage] | str,
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None

View File

@ -1,9 +1,9 @@
import json
from collections.abc import Generator, Mapping
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import datetime, timezone
from mimetypes import guess_type
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
from yarl import URL
@ -40,7 +40,7 @@ class ToolEngine:
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler,
trace_manager: Optional[TraceQueueManager] = None
) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
@ -67,9 +67,9 @@ class ToolEngine:
)
messages = ToolEngine._invoke(tool, tool_parameters, user_id)
invocation_meta_dict = {'meta': None}
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]):
def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]):
for message in messages:
if isinstance(message, ToolInvokeMeta):
invocation_meta_dict['meta'] = message
@ -136,7 +136,7 @@ class ToolEngine:
return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod
def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any],
def workflow_invoke(tool: Tool, tool_parameters: dict[str, Any],
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
@ -156,6 +156,7 @@ class ToolEngine:
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters)
# hit the callback handler
@ -204,6 +205,9 @@ class ToolEngine:
"""
Invoke the tool with the given arguments.
"""
if not tool.runtime:
raise ValueError("missing runtime in tool")
started_at = datetime.now(timezone.utc)
meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={
'tool_name': tool.identity.name,
@ -223,42 +227,42 @@ class ToolEngine:
yield meta
@staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
def _convert_tool_response_to_str(tool_response: Generator[ToolInvokeMessage, None, None]) -> str:
"""
Handle tool response
"""
result = ''
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message
result += cast(ToolInvokeMessage.TextMessage, response.message).text
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please tell user to check it."
result += f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
elif response.type == ToolInvokeMessage.MessageType.JSON:
result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}."
result += f"tool response: {json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)}."
else:
result += f"tool response: {response.message}."
return result
@staticmethod
def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
def _extract_tool_response_binary(tool_response: Generator[ToolInvokeMessage, None, None]) -> Generator[ToolInvokeMessageBinary, None, None]:
"""
Extract tool response binary
"""
result = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
mimetype = None
if not response.meta:
raise ValueError("missing meta data")
if response.meta.get('mime_type'):
mimetype = response.meta.get('mime_type')
else:
try:
url = URL(response.message)
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
extension = url.suffix
guess_type_result, _ = guess_type(f'a{extension}')
if guess_type_result:
@ -269,35 +273,36 @@ class ToolEngine:
if not mimetype:
mimetype = 'image/jpeg'
result.append(ToolInvokeMessageBinary(
yield ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'image/jpeg'),
url=response.message,
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
save_as=response.save_as,
))
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
result.append(ToolInvokeMessageBinary(
if not response.meta:
raise ValueError("missing meta data")
yield ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'),
url=response.message,
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
save_as=response.save_as,
))
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
if response.meta and 'mime_type' in response.meta:
result.append(ToolInvokeMessageBinary(
yield ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
url=response.message,
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
save_as=response.save_as,
))
return result
)
@staticmethod
def _create_message_files(
tool_messages: list[ToolInvokeMessageBinary],
tool_messages: Iterable[ToolInvokeMessageBinary],
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str
) -> list[tuple[Any, str]]:
) -> list[tuple[MessageFile, str]]:
"""
Create message file

View File

@ -1,4 +1,4 @@
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Generator, Sequence
from os import path
from typing import Any, cast
@ -100,7 +100,7 @@ class ToolNode(BaseNode):
variable_pool: VariablePool,
node_data: ToolNodeData,
for_log: bool = False,
) -> Mapping[str, Any]:
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@ -110,7 +110,7 @@ class ToolNode(BaseNode):
node_data (ToolNodeData): The data associated with the tool node.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
dict[str, Any]: A dictionary containing the generated parameters.
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}

5
api/models/base.py Normal file
View File

@ -0,0 +1,5 @@
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
pass

View File

@ -14,6 +14,7 @@ from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from extensions.ext_database import db
from libs.helper import generate_string
from models.base import Base
from .account import Account, Tenant
from .types import StringUUID
@ -211,7 +212,7 @@ class App(db.Model):
return tags if tags else []
class AppModelConfig(db.Model):
class AppModelConfig(Base):
__tablename__ = 'app_model_configs'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='app_model_config_pkey'),
@ -550,6 +551,9 @@ class Conversation(db.Model):
else:
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == self.app_model_config_id).first()
if not app_model_config:
raise ValueError("app config not found")
model_config = app_model_config.to_dict()
@ -640,7 +644,7 @@ class Conversation(db.Model):
return self.override_model_configs is not None
class Message(db.Model):
class Message(Base):
__tablename__ = 'messages'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='message_pkey'),
@ -932,7 +936,7 @@ class MessageFeedback(db.Model):
return account
class MessageFile(db.Model):
class MessageFile(Base):
__tablename__ = 'message_files'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='message_file_pkey'),
@ -940,15 +944,15 @@ class MessageFile(db.Model):
db.Index('message_file_created_by_idx', 'created_by')
)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
transfer_method = db.Column(db.String(255), nullable=False)
url = db.Column(db.Text, nullable=True)
belongs_to = db.Column(db.String(255), nullable=True)
upload_file_id = db.Column(StringUUID, nullable=True)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
url: Mapped[str] = mapped_column(db.Text, nullable=True)
belongs_to: Mapped[str] = mapped_column(db.String(255), nullable=True)
upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@ -1,12 +1,13 @@
import json
from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from extensions.ext_database import db
from models.base import Base
from .model import Account, App, Tenant
from .types import StringUUID
@ -277,9 +278,6 @@ class ToolConversationVariables(db.Model):
@property
def variables(self) -> dict:
return json.loads(self.variables_str)
class Base(DeclarativeBase):
pass
class ToolFile(Base):
"""