refactor: using DeclarativeBase as parent class of models, refactored tools

This commit is contained in:
Yeuoly 2024-09-29 17:00:58 +08:00
parent c8bc3892b3
commit e9e5c8806a
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
17 changed files with 225 additions and 120 deletions

View File

@ -6,7 +6,7 @@ from flask_restful import Resource, reqparse
from configs import dify_config from configs import dify_config
from libs.helper import StrLen, email, get_remote_ip from libs.helper import StrLen, email, get_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.model import DifySetup from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
from . import api from . import api
@ -69,7 +69,7 @@ def setup_required(view):
def get_setup_status(): def get_setup_status():
if dify_config.EDITION == "SELF_HOSTED": if dify_config.EDITION == "SELF_HOSTED":
return DifySetup.query.first() return db.session.query(DifySetup).first()
else: else:
return True return True

View File

@ -610,16 +610,17 @@ class ToolLabelsApi(Resource):
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# builtin tool provider # builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools") api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource( api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials" ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
) )
api.add_resource( api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema" ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
) )
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon") api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
# api tool provider # api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")

View File

@ -14,9 +14,9 @@ class PluginToolManager(BasePluginManager):
provider follows format: plugin_id/provider_name provider follows format: plugin_id/provider_name
""" """
if "/" in provider: if "/" in provider:
parts = provider.split("/", 1) parts = provider.split("/", -1)
if len(parts) == 2: if len(parts) >= 2:
return parts[0], parts[1] return "/".join(parts[:-1]), parts[-1]
raise ValueError(f"invalid provider format: {provider}") raise ValueError(f"invalid provider format: {provider}")
raise ValueError(f"invalid provider format: {provider}") raise ValueError(f"invalid provider format: {provider}")
@ -46,6 +46,10 @@ class PluginToolManager(BasePluginManager):
for provider in response: for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.tools:
tool.identity.provider = provider.declaration.identity.name
return response return response
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
@ -54,15 +58,26 @@ class PluginToolManager(BasePluginManager):
""" """
plugin_id, provider_name = self._split_provider(provider) plugin_id, provider_name = self._split_provider(provider)
def transformer(json_response: dict[str, Any]) -> dict:
for tool in json_response.get("data", {}).get("declaration", {}).get("tools", []):
tool["identity"]["provider"] = provider_name
return json_response
response = self._request_with_plugin_daemon_response( response = self._request_with_plugin_daemon_response(
"GET", "GET",
f"plugin/{tenant_id}/management/tool", f"plugin/{tenant_id}/management/tool",
PluginToolProviderEntity, PluginToolProviderEntity,
params={"provider": provider_name, "plugin_id": plugin_id}, params={"provider": provider_name, "plugin_id": plugin_id},
transformer=transformer,
) )
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in response.declaration.tools:
tool.identity.provider = response.declaration.identity.name
return response return response
def invoke( def invoke(

View File

@ -11,12 +11,10 @@ from core.tools.plugin_tool.tool import PluginTool
class PluginToolProviderController(BuiltinToolProviderController): class PluginToolProviderController(BuiltinToolProviderController):
entity: ToolProviderEntityWithPlugin entity: ToolProviderEntityWithPlugin
tenant_id: str tenant_id: str
plugin_id: str
def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str, plugin_id: str) -> None: def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str) -> None:
self.entity = entity self.entity = entity
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.plugin_id = plugin_id
@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
@ -35,7 +33,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
if not manager.validate_provider_credentials( if not manager.validate_provider_credentials(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
user_id=user_id, user_id=user_id,
plugin_id=self.plugin_id,
provider=self.entity.identity.name, provider=self.entity.identity.name,
credentials=credentials, credentials=credentials,
): ):
@ -54,7 +51,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
entity=tool_entity, entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id), runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
plugin_id=self.plugin_id,
) )
def get_tools(self) -> list[PluginTool]: def get_tools(self) -> list[PluginTool]:
@ -66,7 +62,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
entity=tool_entity, entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id), runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
plugin_id=self.plugin_id,
) )
for tool_entity in self.entity.tools for tool_entity in self.entity.tools
] ]

View File

@ -9,12 +9,10 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
class PluginTool(Tool): class PluginTool(Tool):
tenant_id: str tenant_id: str
plugin_id: str
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, plugin_id: str) -> None: def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str) -> None:
super().__init__(entity, runtime) super().__init__(entity, runtime)
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.plugin_id = plugin_id
@property @property
def tool_provider_type(self) -> ToolProviderType: def tool_provider_type(self) -> ToolProviderType:
@ -25,7 +23,6 @@ class PluginTool(Tool):
return manager.invoke( return manager.invoke(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
user_id=user_id, user_id=user_id,
plugin_id=self.plugin_id,
tool_provider=self.entity.identity.provider, tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name, tool_name=self.entity.identity.name,
credentials=self.runtime.credentials, credentials=self.runtime.credentials,
@ -37,5 +34,4 @@ class PluginTool(Tool):
entity=self.entity, entity=self.entity,
runtime=runtime, runtime=runtime,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
plugin_id=self.plugin_id,
) )

View File

@ -86,7 +86,6 @@ class ToolManager:
return PluginToolProviderController( return PluginToolProviderController(
entity=provider_entity.declaration, entity=provider_entity.declaration,
tenant_id=tenant_id, tenant_id=tenant_id,
plugin_id=provider_entity.plugin_id,
) )
@classmethod @classmethod
@ -158,12 +157,11 @@ class ToolManager:
# decrypt the credentials # decrypt the credentials
credentials = builtin_provider.credentials credentials = builtin_provider.credentials
controller = cls.get_builtin_provider(provider_id, tenant_id)
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=controller.get_credentials_schema(), config=provider_controller.get_credentials_schema(),
provider_type=controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=controller.entity.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
decrypted_credentials = tool_configuration.decrypt(credentials) decrypted_credentials = tool_configuration.decrypt(credentials)
@ -400,7 +398,6 @@ class ToolManager:
PluginToolProviderController( PluginToolProviderController(
entity=provider.declaration, entity=provider.declaration,
tenant_id=tenant_id, tenant_id=tenant_id,
plugin_id=provider.plugin_id,
) )
for provider in provider_entities for provider in provider_entities
] ]
@ -525,7 +522,7 @@ class ToolManager:
) )
if isinstance(provider, PluginToolProviderController): if isinstance(provider, PluginToolProviderController):
result_providers[f"plugin_provider.{user_provider.name}.{provider.plugin_id}"] = user_provider result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
else: else:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider result_providers[f"builtin_provider.{user_provider.name}"] = user_provider

View File

@ -31,19 +31,16 @@ def get_engine_url():
# from myapp import mymodel # from myapp import mymodel
# target_metadata = mymodel.Base.metadata # target_metadata = mymodel.Base.metadata
config.set_main_option('sqlalchemy.url', get_engine_url()) config.set_main_option('sqlalchemy.url', get_engine_url())
target_db = current_app.extensions['migrate'].db
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
# can be acquired: # can be acquired:
# my_important_option = config.get_main_option("my_important_option") # my_important_option = config.get_main_option("my_important_option")
# ... etc. # ... etc.
from models.base import Base
def get_metadata(): def get_metadata():
if hasattr(target_db, 'metadatas'): return Base.metadata
return target_db.metadatas[None]
return target_db.metadata
def include_object(object, name, type_, reflected, compare_to): def include_object(object, name, type_, reflected, compare_to):
if type_ == "foreign_key_constraint": if type_ == "foreign_key_constraint":

View File

@ -0,0 +1,39 @@
"""increase max length of builtin tool provider
Revision ID: ddcc8bbef391
Revises: d57ba9ebb251
Create Date: 2024-09-29 08:35:58.062698
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ddcc8bbef391'
down_revision = 'd57ba9ebb251'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.alter_column('provider',
existing_type=sa.VARCHAR(length=40),
type_=sa.String(length=256),
existing_nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.alter_column('provider',
existing_type=sa.String(length=256),
type_=sa.VARCHAR(length=40),
existing_nullable=False)
# ### end Alembic commands ###

View File

@ -1,5 +1,5 @@
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import declarative_base
from extensions.ext_database import metadata
class Base(DeclarativeBase): Base = declarative_base(metadata=metadata)
pass

View File

@ -2,12 +2,15 @@ import json
import re import re
import uuid import uuid
from enum import Enum from enum import Enum
from typing import Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from models.workflow import Workflow
from flask import request from flask import request
from flask_login import UserMixin from flask_login import UserMixin
from sqlalchemy import Float, func, text from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column
from configs import dify_config from configs import dify_config
from core.file.tool_file_parser import ToolFileParser from core.file.tool_file_parser import ToolFileParser
@ -20,7 +23,7 @@ from .account import Account, Tenant
from .types import StringUUID from .types import StringUUID
class DifySetup(db.Model): class DifySetup(Base):
__tablename__ = "dify_setups" __tablename__ = "dify_setups"
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
@ -55,7 +58,7 @@ class IconType(Enum):
EMOJI = "emoji" EMOJI = "emoji"
class App(db.Model): class App(Base):
__tablename__ = "apps" __tablename__ = "apps"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
@ -133,7 +136,8 @@ class App(db.Model):
return False return False
if not app_model_config.agent_mode: if not app_model_config.agent_mode:
return False return False
if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get(
if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get(
"strategy", "" "strategy", ""
) in {"function_call", "react"}: ) in {"function_call", "react"}:
self.mode = AppMode.AGENT_CHAT.value self.mode = AppMode.AGENT_CHAT.value
@ -250,7 +254,7 @@ class AppModelConfig(Base):
return app return app
@property @property
def model_dict(self) -> dict: def model_dict(self):
return json.loads(self.model) if self.model else None return json.loads(self.model) if self.model else None
@property @property
@ -284,6 +288,9 @@ class AppModelConfig(Base):
) )
if annotation_setting: if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail collection_binding_detail = annotation_setting.collection_binding_detail
if not collection_binding_detail:
raise ValueError("Collection binding detail not found")
return { return {
"id": annotation_setting.id, "id": annotation_setting.id,
"enabled": True, "enabled": True,
@ -314,7 +321,7 @@ class AppModelConfig(Base):
return json.loads(self.external_data_tools) if self.external_data_tools else [] return json.loads(self.external_data_tools) if self.external_data_tools else []
@property @property
def user_input_form_list(self) -> dict: def user_input_form_list(self):
return json.loads(self.user_input_form) if self.user_input_form else [] return json.loads(self.user_input_form) if self.user_input_form else []
@property @property
@ -458,7 +465,7 @@ class AppModelConfig(Base):
return new_app_model_config return new_app_model_config
class RecommendedApp(db.Model): class RecommendedApp(Base):
__tablename__ = "recommended_apps" __tablename__ = "recommended_apps"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
@ -486,7 +493,7 @@ class RecommendedApp(db.Model):
return app return app
class InstalledApp(db.Model): class InstalledApp(Base):
__tablename__ = "installed_apps" __tablename__ = "installed_apps"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="installed_app_pkey"), db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
@ -522,7 +529,7 @@ class Conversation(Base):
db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
app_model_config_id = db.Column(StringUUID, nullable=True) app_model_config_id = db.Column(StringUUID, nullable=True)
model_provider = db.Column(db.String(255), nullable=True) model_provider = db.Column(db.String(255), nullable=True)
@ -546,10 +553,8 @@ class Conversation(Base):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
messages: Mapped[list["Message"]] = relationship( messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
"Message", backref="conversation", lazy="select", passive_deletes="all" message_annotations = db.relationship(
)
message_annotations: Mapped[list["MessageAnnotation"]] = relationship(
"MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
) )
@ -578,7 +583,7 @@ class Conversation(Base):
) )
if not app_model_config: if not app_model_config:
raise ValueError("app config not found") return {}
model_config = app_model_config.to_dict() model_config = app_model_config.to_dict()
@ -692,12 +697,12 @@ class Conversation(Base):
class Message(Base): class Message(Base):
__tablename__ = "messages" __tablename__ = "messages"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="message_pkey"), PrimaryKeyConstraint("id", name="message_pkey"),
db.Index("message_app_id_idx", "app_id", "created_at"), Index("message_app_id_idx", "app_id", "created_at"),
db.Index("message_conversation_id_idx", "conversation_id"), Index("message_conversation_id_idx", "conversation_id"),
db.Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"), Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"),
db.Index("message_account_idx", "app_id", "from_source", "from_account_id"), Index("message_account_idx", "app_id", "from_source", "from_account_id"),
db.Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
@ -705,10 +710,10 @@ class Message(Base):
model_provider = db.Column(db.String(255), nullable=True) model_provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text) override_model_configs = db.Column(db.Text)
conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
inputs: Mapped[str] = mapped_column(db.JSON) inputs = db.Column(db.JSON)
query: Mapped[str] = mapped_column(db.Text, nullable=False) query = db.Column(db.Text, nullable=False)
message: Mapped[str] = mapped_column(db.JSON, nullable=False) message = db.Column(db.JSON, nullable=False)
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
@ -974,7 +979,7 @@ class Message(Base):
) )
class MessageFeedback(db.Model): class MessageFeedback(Base):
__tablename__ = "message_feedbacks" __tablename__ = "message_feedbacks"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@ -1009,15 +1014,15 @@ class MessageFile(Base):
db.Index("message_file_created_by_idx", "created_by"), db.Index("message_file_created_by_idx", "created_by"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(db.String(255), nullable=False) type = db.Column(db.String(255), nullable=False)
transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) transfer_method = db.Column(db.String(255), nullable=False)
url: Mapped[str] = mapped_column(db.Text, nullable=True) url = db.Column(db.Text, nullable=True)
belongs_to: Mapped[str] = mapped_column(db.String(255), nullable=True) belongs_to = db.Column(db.String(255), nullable=True)
upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=True) upload_file_id = db.Column(StringUUID, nullable=True)
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
@ -1032,7 +1037,7 @@ class MessageAnnotation(Base):
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True)
message_id = db.Column(StringUUID, nullable=True) message_id = db.Column(StringUUID, nullable=True)
question = db.Column(db.Text, nullable=True) question = db.Column(db.Text, nullable=True)
content = db.Column(db.Text, nullable=False) content = db.Column(db.Text, nullable=False)
@ -1052,7 +1057,7 @@ class MessageAnnotation(Base):
return account return account
class AppAnnotationHitHistory(db.Model): class AppAnnotationHitHistory(Base):
__tablename__ = "app_annotation_hit_histories" __tablename__ = "app_annotation_hit_histories"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@ -1090,7 +1095,7 @@ class AppAnnotationHitHistory(db.Model):
return account return account
class AppAnnotationSetting(db.Model): class AppAnnotationSetting(Base):
__tablename__ = "app_annotation_settings" __tablename__ = "app_annotation_settings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@ -1138,7 +1143,7 @@ class AppAnnotationSetting(db.Model):
return collection_binding_detail return collection_binding_detail
class OperationLog(db.Model): class OperationLog(Base):
__tablename__ = "operation_logs" __tablename__ = "operation_logs"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="operation_log_pkey"), db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
@ -1155,7 +1160,7 @@ class OperationLog(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class EndUser(UserMixin, db.Model): class EndUser(UserMixin, Base):
__tablename__ = "end_users" __tablename__ = "end_users"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="end_user_pkey"), db.PrimaryKeyConstraint("id", name="end_user_pkey"),
@ -1175,7 +1180,7 @@ class EndUser(UserMixin, db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class Site(db.Model): class Site(Base):
__tablename__ = "sites" __tablename__ = "sites"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="site_pkey"), db.PrimaryKeyConstraint("id", name="site_pkey"),
@ -1222,7 +1227,7 @@ class Site(db.Model):
return dify_config.APP_WEB_URL or request.url_root.rstrip("/") return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
class ApiToken(db.Model): class ApiToken(Base):
__tablename__ = "api_tokens" __tablename__ = "api_tokens"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="api_token_pkey"), db.PrimaryKeyConstraint("id", name="api_token_pkey"),
@ -1249,7 +1254,7 @@ class ApiToken(db.Model):
return result return result
class UploadFile(db.Model): class UploadFile(Base):
__tablename__ = "upload_files" __tablename__ = "upload_files"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="upload_file_pkey"), db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
@ -1273,7 +1278,7 @@ class UploadFile(db.Model):
hash = db.Column(db.String(255), nullable=True) hash = db.Column(db.String(255), nullable=True)
class ApiRequest(db.Model): class ApiRequest(Base):
__tablename__ = "api_requests" __tablename__ = "api_requests"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="api_request_pkey"), db.PrimaryKeyConstraint("id", name="api_request_pkey"),
@ -1290,7 +1295,7 @@ class ApiRequest(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class MessageChain(db.Model): class MessageChain(Base):
__tablename__ = "message_chains" __tablename__ = "message_chains"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="message_chain_pkey"), db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
@ -1395,7 +1400,7 @@ class MessageAgentThought(Base):
return {} return {}
@property @property
def tool_outputs_dict(self) -> dict: def tool_outputs_dict(self):
tools = self.tools tools = self.tools
try: try:
if self.observation: if self.observation:
@ -1417,7 +1422,7 @@ class MessageAgentThought(Base):
return dict.fromkeys(tools, self.observation) return dict.fromkeys(tools, self.observation)
class DatasetRetrieverResource(db.Model): class DatasetRetrieverResource(Base):
__tablename__ = "dataset_retriever_resources" __tablename__ = "dataset_retriever_resources"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
@ -1444,7 +1449,7 @@ class DatasetRetrieverResource(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class Tag(db.Model): class Tag(Base):
__tablename__ = "tags" __tablename__ = "tags"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tag_pkey"), db.PrimaryKeyConstraint("id", name="tag_pkey"),
@ -1462,7 +1467,7 @@ class Tag(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class TagBinding(db.Model): class TagBinding(Base):
__tablename__ = "tag_bindings" __tablename__ = "tag_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
@ -1478,7 +1483,7 @@ class TagBinding(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class TraceAppConfig(db.Model): class TraceAppConfig(Base):
__tablename__ = "trace_app_config" __tablename__ = "trace_app_config"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),

View File

@ -1,6 +1,7 @@
from enum import Enum from enum import Enum
from extensions.ext_database import db from extensions.ext_database import db
from models.base import Base
from .types import StringUUID from .types import StringUUID
@ -35,7 +36,7 @@ class ProviderQuotaType(Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class Provider(db.Model): class Provider(Base):
""" """
Provider model representing the API providers and their configurations. Provider model representing the API providers and their configurations.
""" """
@ -88,7 +89,7 @@ class Provider(db.Model):
return self.is_valid and self.token_is_set return self.is_valid and self.token_is_set
class ProviderModel(db.Model): class ProviderModel(Base):
""" """
Provider model representing the API provider_models and their configurations. Provider model representing the API provider_models and their configurations.
""" """
@ -113,7 +114,7 @@ class ProviderModel(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class TenantDefaultModel(db.Model): class TenantDefaultModel(Base):
__tablename__ = "tenant_default_models" __tablename__ = "tenant_default_models"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
@ -129,7 +130,7 @@ class TenantDefaultModel(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class TenantPreferredModelProvider(db.Model): class TenantPreferredModelProvider(Base):
__tablename__ = "tenant_preferred_model_providers" __tablename__ = "tenant_preferred_model_providers"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
@ -144,7 +145,7 @@ class TenantPreferredModelProvider(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class ProviderOrder(db.Model): class ProviderOrder(Base):
__tablename__ = "provider_orders" __tablename__ = "provider_orders"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="provider_order_pkey"), db.PrimaryKeyConstraint("id", name="provider_order_pkey"),
@ -169,7 +170,7 @@ class ProviderOrder(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class ProviderModelSetting(db.Model): class ProviderModelSetting(Base):
""" """
Provider model settings for record the model enabled status and load balancing status. Provider model settings for record the model enabled status and load balancing status.
""" """
@ -191,7 +192,7 @@ class ProviderModelSetting(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class LoadBalancingModelConfig(db.Model): class LoadBalancingModelConfig(Base):
""" """
Configurations for load balancing models. Configurations for load balancing models.
""" """

View File

@ -3,11 +3,12 @@ import json
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db from extensions.ext_database import db
from models.base import Base
from .types import StringUUID from .types import StringUUID
class DataSourceOauthBinding(db.Model): class DataSourceOauthBinding(Base):
__tablename__ = "data_source_oauth_bindings" __tablename__ = "data_source_oauth_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"), db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
@ -25,7 +26,7 @@ class DataSourceOauthBinding(db.Model):
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
class DataSourceApiKeyAuthBinding(db.Model): class DataSourceApiKeyAuthBinding(Base):
__tablename__ = "data_source_api_key_auth_bindings" __tablename__ = "data_source_api_key_auth_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),

View File

@ -3,9 +3,10 @@ from datetime import datetime, timezone
from celery import states from celery import states
from extensions.ext_database import db from extensions.ext_database import db
from models.base import Base
class CeleryTask(db.Model): class CeleryTask(Base):
"""Task result/status.""" """Task result/status."""
__tablename__ = "celery_taskmeta" __tablename__ = "celery_taskmeta"
@ -29,7 +30,7 @@ class CeleryTask(db.Model):
queue = db.Column(db.String(155), nullable=True) queue = db.Column(db.String(155), nullable=True)
class CeleryTaskSet(db.Model): class CeleryTaskSet(Base):
"""TaskSet result.""" """TaskSet result."""
__tablename__ = "celery_tasksetmeta" __tablename__ = "celery_tasksetmeta"

View File

@ -2,6 +2,7 @@ import json
from enum import Enum from enum import Enum
from extensions.ext_database import db from extensions.ext_database import db
from models.base import Base
from .types import StringUUID from .types import StringUUID
@ -17,7 +18,7 @@ class ToolProviderName(Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class ToolProvider(db.Model): class ToolProvider(Base):
__tablename__ = "tool_providers" __tablename__ = "tool_providers"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), db.PrimaryKeyConstraint("id", name="tool_provider_pkey"),

View File

@ -1,8 +1,11 @@
import json import json
from datetime import datetime from datetime import datetime
from deprecated import deprecated
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from extensions.ext_database import db from extensions.ext_database import db
@ -31,7 +34,7 @@ class BuiltinToolProvider(Base):
# who created this tool provider # who created this tool provider
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# name of the tool provider # name of the tool provider
provider: Mapped[str] = mapped_column(db.String(40), nullable=False) provider: Mapped[str] = mapped_column(db.String(256), nullable=False)
# credential of the tool provider # credential of the tool provider
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
@ -182,7 +185,7 @@ class WorkflowToolProvider(Base):
return db.session.query(App).filter(App.id == self.app_id).first() return db.session.query(App).filter(App.id == self.app_id).first()
class ToolModelInvoke(db.Model): class ToolModelInvoke(Base):
""" """
store the invoke logs from tool invoke store the invoke logs from tool invoke
""" """
@ -219,7 +222,7 @@ class ToolModelInvoke(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class ToolConversationVariables(db.Model): class ToolConversationVariables(Base):
""" """
store the conversation variables from tool invoke store the conversation variables from tool invoke
""" """
@ -275,3 +278,46 @@ class ToolFile(Base):
mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False) mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
# original url # original url
original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True) original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
@deprecated
class DeprecatedPublishedAppTool(Base):
"""
The table stores the apps published as a tool for each person.
"""
__tablename__ = "tool_published_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
# id of the tool provider
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# id of the app
app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
# who published this tool
user_id = db.Column(StringUUID, nullable=False)
# description of the tool, stored in i18n format, for human
description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM
llm_description = db.Column(db.Text, nullable=False)
# query description, query will be seem as a parameter of the tool,
# to describe this parameter to llm, we need this field
query_description = db.Column(db.Text, nullable=False)
# query name, the name of the query parameter
query_name = db.Column(db.String(40), nullable=False)
# name of the tool provider
tool_name = db.Column(db.String(40), nullable=False)
# author
author = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
@property
def description_i18n(self) -> I18nObject:
return I18nObject(**json.loads(self.description))
@property
def app(self) -> App:
return db.session.query(App).filter(App.id == self.app_id).first()

View File

@ -1,10 +1,11 @@
from extensions.ext_database import db from extensions.ext_database import db
from models.base import Base
from .model import Message from .model import Message
from .types import StringUUID from .types import StringUUID
class SavedMessage(db.Model): class SavedMessage(Base):
__tablename__ = "saved_messages" __tablename__ = "saved_messages"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="saved_message_pkey"), db.PrimaryKeyConstraint("id", name="saved_message_pkey"),
@ -23,7 +24,7 @@ class SavedMessage(db.Model):
return db.session.query(Message).filter(Message.id == self.message_id).first() return db.session.query(Message).filter(Message.id == self.message_id).first()
class PinnedConversation(db.Model): class PinnedConversation(Base):
__tablename__ = "pinned_conversations" __tablename__ = "pinned_conversations"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),

View File

@ -2,10 +2,13 @@ import json
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Optional, Union from typing import TYPE_CHECKING, Any, Union
from sqlalchemy import func if TYPE_CHECKING:
from sqlalchemy.orm import Mapped from models.model import AppMode
from sqlalchemy import Index, PrimaryKeyConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
import contexts import contexts
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
@ -13,6 +16,7 @@ from core.app.segments import SecretVariable, Variable, factory
from core.helper import encrypter from core.helper import encrypter
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper from libs import helper
from models.base import Base
from .account import Account from .account import Account
from .types import StringUUID from .types import StringUUID
@ -75,7 +79,7 @@ class WorkflowType(Enum):
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
class Workflow(db.Model): class Workflow(Base):
""" """
Workflow, for `Workflow App` and `Chat App workflow mode`. Workflow, for `Workflow App` and `Chat App workflow mode`.
@ -345,7 +349,7 @@ class WorkflowRunStatus(Enum):
raise ValueError(f"invalid workflow run status value {value}") raise ValueError(f"invalid workflow run status value {value}")
class WorkflowRun(db.Model): class WorkflowRun(Base):
""" """
Workflow Run Workflow Run
@ -436,7 +440,7 @@ class WorkflowRun(db.Model):
return json.loads(self.outputs) if self.outputs else None return json.loads(self.outputs) if self.outputs else None
@property @property
def message(self) -> Optional["Message"]: def message(self):
from models.model import Message from models.model import Message
return ( return (
@ -542,7 +546,7 @@ class WorkflowNodeExecutionStatus(Enum):
raise ValueError(f"invalid workflow node execution status value {value}") raise ValueError(f"invalid workflow node execution status value {value}")
class WorkflowNodeExecution(db.Model): class WorkflowNodeExecution(Base):
""" """
Workflow Node Execution Workflow Node Execution
@ -708,7 +712,7 @@ class WorkflowAppLogCreatedFrom(Enum):
raise ValueError(f"invalid workflow app log created from value {value}") raise ValueError(f"invalid workflow app log created from value {value}")
class WorkflowAppLog(db.Model): class WorkflowAppLog(Base):
""" """
Workflow App execution log, excluding workflow debugging records. Workflow App execution log, excluding workflow debugging records.
@ -770,15 +774,20 @@ class WorkflowAppLog(db.Model):
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
class ConversationVariable(db.Model): class ConversationVariable(Base):
__tablename__ = "workflow_conversation_variables" __tablename__ = "workflow_conversation_variables"
__table_args__ = (
PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"),
Index("workflow__conversation_variables_app_id_idx", "app_id"),
Index("workflow__conversation_variables_created_at_idx", "created_at"),
)
id: Mapped[str] = db.Column(StringUUID, primary_key=True) id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True)
app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
data = db.Column(db.Text, nullable=False) data = mapped_column(db.Text, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column( updated_at = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
) )