diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 46b4ef5d87..b66e74aee0 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -6,7 +6,7 @@ from flask_restful import Resource, reqparse from configs import dify_config from libs.helper import StrLen, email, get_remote_ip from libs.password import valid_password -from models.model import DifySetup +from models.model import DifySetup, db from services.account_service import RegisterService, TenantService from . import api @@ -69,7 +69,7 @@ def setup_required(view): def get_setup_status(): if dify_config.EDITION == "SELF_HOSTED": - return DifySetup.query.first() + return db.session.query(DifySetup).first() else: return True diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 1b49103ced..14edc9ac13 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -610,16 +610,17 @@ class ToolLabelsApi(Resource): api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") # builtin tool provider -api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") -api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") -api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") +api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") +api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") +api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") api.add_resource( - ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" + ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) api.add_resource( - ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin//credentials_schema" + ToolBuiltinProviderCredentialsSchemaApi, + "/workspaces/current/tool-provider/builtin//credentials_schema", ) -api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") +api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") # api tool provider api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") diff --git a/api/core/plugin/manager/tool.py b/api/core/plugin/manager/tool.py index 5981bcb55e..50970243a1 100644 --- a/api/core/plugin/manager/tool.py +++ b/api/core/plugin/manager/tool.py @@ -14,9 +14,9 @@ class PluginToolManager(BasePluginManager): provider follows format: plugin_id/provider_name """ if "/" in provider: - parts = provider.split("/", 1) - if len(parts) == 2: - return parts[0], parts[1] + parts = provider.split("/", -1) + if len(parts) >= 2: + return "/".join(parts[:-1]), parts[-1] raise ValueError(f"invalid provider format: {provider}") raise ValueError(f"invalid provider format: {provider}") @@ -46,6 +46,10 @@ class PluginToolManager(BasePluginManager): for provider in response: provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + # override the provider name for each tool to plugin_id/provider_name + for tool in provider.declaration.tools: + tool.identity.provider = provider.declaration.identity.name + return response def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: @@ -54,15 +58,26 @@ class PluginToolManager(BasePluginManager): """ plugin_id, provider_name = self._split_provider(provider) + def transformer(json_response: dict[str, Any]) -> dict: + for tool in json_response.get("data", {}).get("declaration", {}).get("tools", []): + tool["identity"]["provider"] = provider_name + + return json_response + response = self._request_with_plugin_daemon_response( "GET", f"plugin/{tenant_id}/management/tool", PluginToolProviderEntity, params={"provider": provider_name, "plugin_id": plugin_id}, + transformer=transformer, ) response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" + # override the provider name for each tool to plugin_id/provider_name + for tool in response.declaration.tools: + tool.identity.provider = response.declaration.identity.name + return response def invoke( diff --git a/api/core/tools/plugin_tool/provider.py b/api/core/tools/plugin_tool/provider.py index 4982e74056..b6758df2bb 100644 --- a/api/core/tools/plugin_tool/provider.py +++ b/api/core/tools/plugin_tool/provider.py @@ -11,12 +11,10 @@ from core.tools.plugin_tool.tool import PluginTool class PluginToolProviderController(BuiltinToolProviderController): entity: ToolProviderEntityWithPlugin tenant_id: str - plugin_id: str - def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str, plugin_id: str) -> None: + def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str) -> None: self.entity = entity self.tenant_id = tenant_id - self.plugin_id = plugin_id @property def provider_type(self) -> ToolProviderType: @@ -35,7 +33,6 @@ class PluginToolProviderController(BuiltinToolProviderController): if not manager.validate_provider_credentials( tenant_id=self.tenant_id, user_id=user_id, - plugin_id=self.plugin_id, provider=self.entity.identity.name, credentials=credentials, ): @@ -54,7 +51,6 @@ class PluginToolProviderController(BuiltinToolProviderController): entity=tool_entity, runtime=ToolRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, - plugin_id=self.plugin_id, ) def get_tools(self) -> list[PluginTool]: @@ -66,7 +62,6 @@ class PluginToolProviderController(BuiltinToolProviderController): entity=tool_entity, runtime=ToolRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, - plugin_id=self.plugin_id, ) for tool_entity in self.entity.tools ] diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index 7a4f147cd0..7c6c4de3e0 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -9,12 +9,10 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too class PluginTool(Tool): tenant_id: str - plugin_id: str - def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, plugin_id: str) -> None: + def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str) -> None: super().__init__(entity, runtime) self.tenant_id = tenant_id - self.plugin_id = plugin_id @property def tool_provider_type(self) -> ToolProviderType: @@ -25,7 +23,6 @@ class PluginTool(Tool): return manager.invoke( tenant_id=self.tenant_id, user_id=user_id, - plugin_id=self.plugin_id, tool_provider=self.entity.identity.provider, tool_name=self.entity.identity.name, credentials=self.runtime.credentials, @@ -37,5 +34,4 @@ class PluginTool(Tool): entity=self.entity, runtime=runtime, tenant_id=self.tenant_id, - plugin_id=self.plugin_id, ) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0463f84817..fdd7e8385b 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -86,7 +86,6 @@ class ToolManager: return PluginToolProviderController( entity=provider_entity.declaration, tenant_id=tenant_id, - plugin_id=provider_entity.plugin_id, ) @classmethod @@ -158,12 +157,11 @@ class ToolManager: # decrypt the credentials credentials = builtin_provider.credentials - controller = cls.get_builtin_provider(provider_id, tenant_id) tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, - config=controller.get_credentials_schema(), - provider_type=controller.provider_type.value, - provider_identity=controller.entity.identity.name, + config=provider_controller.get_credentials_schema(), + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, ) decrypted_credentials = tool_configuration.decrypt(credentials) @@ -400,7 +398,6 @@ class ToolManager: PluginToolProviderController( entity=provider.declaration, tenant_id=tenant_id, - plugin_id=provider.plugin_id, ) for provider in provider_entities ] @@ -525,7 +522,7 @@ class ToolManager: ) if isinstance(provider, PluginToolProviderController): - result_providers[f"plugin_provider.{user_provider.name}.{provider.plugin_id}"] = user_provider + result_providers[f"plugin_provider.{user_provider.name}"] = user_provider else: result_providers[f"builtin_provider.{user_provider.name}"] = user_provider diff --git a/api/migrations/env.py b/api/migrations/env.py index ad3a122c04..a5d815dcfd 100644 --- a/api/migrations/env.py +++ b/api/migrations/env.py @@ -31,19 +31,16 @@ def get_engine_url(): # from myapp import mymodel # target_metadata = mymodel.Base.metadata config.set_main_option('sqlalchemy.url', get_engine_url()) -target_db = current_app.extensions['migrate'].db # other values from the config, defined by the needs of env.py, # can be acquired: # my_important_option = config.get_main_option("my_important_option") # ... etc. +from models.base import Base def get_metadata(): - if hasattr(target_db, 'metadatas'): - return target_db.metadatas[None] - return target_db.metadata - + return Base.metadata def include_object(object, name, type_, reflected, compare_to): if type_ == "foreign_key_constraint": diff --git a/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py b/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py new file mode 100644 index 0000000000..4b16fe7f31 --- /dev/null +++ b/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py @@ -0,0 +1,39 @@ +"""increase max length of builtin tool provider + +Revision ID: ddcc8bbef391 +Revises: d57ba9ebb251 +Create Date: 2024-09-29 08:35:58.062698 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'ddcc8bbef391' +down_revision = 'd57ba9ebb251' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.alter_column('provider', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=256), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.alter_column('provider', + existing_type=sa.String(length=256), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/base.py b/api/models/base.py index fa2b68a5d2..da4648efa6 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -1,5 +1,5 @@ -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import declarative_base +from extensions.ext_database import metadata -class Base(DeclarativeBase): - pass +Base = declarative_base(metadata=metadata) diff --git a/api/models/model.py b/api/models/model.py index 2fad0f5409..660c7e0a36 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -2,12 +2,15 @@ import json import re import uuid from enum import Enum -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from models.workflow import Workflow from flask import request from flask_login import UserMixin -from sqlalchemy import Float, func, text -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text +from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config from core.file.tool_file_parser import ToolFileParser @@ -20,7 +23,7 @@ from .account import Account, Tenant from .types import StringUUID -class DifySetup(db.Model): +class DifySetup(Base): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -55,7 +58,7 @@ class IconType(Enum): EMOJI = "emoji" -class App(db.Model): +class App(Base): __tablename__ = "apps" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) @@ -133,7 +136,8 @@ class App(db.Model): return False if not app_model_config.agent_mode: return False - if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get( + + if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get( "strategy", "" ) in {"function_call", "react"}: self.mode = AppMode.AGENT_CHAT.value @@ -250,7 +254,7 @@ class AppModelConfig(Base): return app @property - def model_dict(self) -> dict: + def model_dict(self): return json.loads(self.model) if self.model else None @property @@ -284,6 +288,9 @@ class AppModelConfig(Base): ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail + if not collection_binding_detail: + raise ValueError("Collection binding detail not found") + return { "id": annotation_setting.id, "enabled": True, @@ -314,7 +321,7 @@ class AppModelConfig(Base): return json.loads(self.external_data_tools) if self.external_data_tools else [] @property - def user_input_form_list(self) -> dict: + def user_input_form_list(self): return json.loads(self.user_input_form) if self.user_input_form else [] @property @@ -458,7 +465,7 @@ class AppModelConfig(Base): return new_app_model_config -class RecommendedApp(db.Model): +class RecommendedApp(Base): __tablename__ = "recommended_apps" __table_args__ = ( db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), @@ -486,7 +493,7 @@ class RecommendedApp(db.Model): return app -class InstalledApp(db.Model): +class InstalledApp(Base): __tablename__ = "installed_apps" __table_args__ = ( db.PrimaryKeyConstraint("id", name="installed_app_pkey"), @@ -522,7 +529,7 @@ class Conversation(Base): db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) @@ -546,10 +553,8 @@ class Conversation(Base): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - messages: Mapped[list["Message"]] = relationship( - "Message", backref="conversation", lazy="select", passive_deletes="all" - ) - message_annotations: Mapped[list["MessageAnnotation"]] = relationship( + messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") + message_annotations = db.relationship( "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" ) @@ -578,7 +583,7 @@ class Conversation(Base): ) if not app_model_config: - raise ValueError("app config not found") + return {} model_config = app_model_config.to_dict() @@ -692,12 +697,12 @@ class Conversation(Base): class Message(Base): __tablename__ = "messages" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_pkey"), - db.Index("message_app_id_idx", "app_id", "created_at"), - db.Index("message_conversation_id_idx", "conversation_id"), - db.Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"), - db.Index("message_account_idx", "app_id", "from_source", "from_account_id"), - db.Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), + PrimaryKeyConstraint("id", name="message_pkey"), + Index("message_app_id_idx", "app_id", "created_at"), + Index("message_conversation_id_idx", "conversation_id"), + Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"), + Index("message_account_idx", "app_id", "from_source", "from_account_id"), + Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) @@ -705,10 +710,10 @@ class Message(Base): model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) - conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) - inputs: Mapped[str] = mapped_column(db.JSON) - query: Mapped[str] = mapped_column(db.Text, nullable=False) - message: Mapped[str] = mapped_column(db.JSON, nullable=False) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) + inputs = db.Column(db.JSON) + query = db.Column(db.Text, nullable=False) + message = db.Column(db.JSON, nullable=False) message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) @@ -974,7 +979,7 @@ class Message(Base): ) -class MessageFeedback(db.Model): +class MessageFeedback(Base): __tablename__ = "message_feedbacks" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), @@ -1009,15 +1014,15 @@ class MessageFile(Base): db.Index("message_file_created_by_idx", "created_by"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(db.String(255), nullable=False) - transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) - url: Mapped[str] = mapped_column(db.Text, nullable=True) - belongs_to: Mapped[str] = mapped_column(db.String(255), nullable=True) - upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) - created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + message_id = db.Column(StringUUID, nullable=False) + type = db.Column(db.String(255), nullable=False) + transfer_method = db.Column(db.String(255), nullable=False) + url = db.Column(db.Text, nullable=True) + belongs_to = db.Column(db.String(255), nullable=True) + upload_file_id = db.Column(StringUUID, nullable=True) + created_by_role = db.Column(db.String(255), nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @@ -1032,7 +1037,7 @@ class MessageAnnotation(Base): id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) message_id = db.Column(StringUUID, nullable=True) question = db.Column(db.Text, nullable=True) content = db.Column(db.Text, nullable=False) @@ -1052,7 +1057,7 @@ class MessageAnnotation(Base): return account -class AppAnnotationHitHistory(db.Model): +class AppAnnotationHitHistory(Base): __tablename__ = "app_annotation_hit_histories" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), @@ -1090,7 +1095,7 @@ class AppAnnotationHitHistory(db.Model): return account -class AppAnnotationSetting(db.Model): +class AppAnnotationSetting(Base): __tablename__ = "app_annotation_settings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), @@ -1138,7 +1143,7 @@ class AppAnnotationSetting(db.Model): return collection_binding_detail -class OperationLog(db.Model): +class OperationLog(Base): __tablename__ = "operation_logs" __table_args__ = ( db.PrimaryKeyConstraint("id", name="operation_log_pkey"), @@ -1155,7 +1160,7 @@ class OperationLog(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class EndUser(UserMixin, db.Model): +class EndUser(UserMixin, Base): __tablename__ = "end_users" __table_args__ = ( db.PrimaryKeyConstraint("id", name="end_user_pkey"), @@ -1175,7 +1180,7 @@ class EndUser(UserMixin, db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class Site(db.Model): +class Site(Base): __tablename__ = "sites" __table_args__ = ( db.PrimaryKeyConstraint("id", name="site_pkey"), @@ -1222,7 +1227,7 @@ class Site(db.Model): return dify_config.APP_WEB_URL or request.url_root.rstrip("/") -class ApiToken(db.Model): +class ApiToken(Base): __tablename__ = "api_tokens" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_token_pkey"), @@ -1249,7 +1254,7 @@ class ApiToken(db.Model): return result -class UploadFile(db.Model): +class UploadFile(Base): __tablename__ = "upload_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="upload_file_pkey"), @@ -1273,7 +1278,7 @@ class UploadFile(db.Model): hash = db.Column(db.String(255), nullable=True) -class ApiRequest(db.Model): +class ApiRequest(Base): __tablename__ = "api_requests" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_request_pkey"), @@ -1290,7 +1295,7 @@ class ApiRequest(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class MessageChain(db.Model): +class MessageChain(Base): __tablename__ = "message_chains" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_chain_pkey"), @@ -1395,7 +1400,7 @@ class MessageAgentThought(Base): return {} @property - def tool_outputs_dict(self) -> dict: + def tool_outputs_dict(self): tools = self.tools try: if self.observation: @@ -1417,7 +1422,7 @@ class MessageAgentThought(Base): return dict.fromkeys(tools, self.observation) -class DatasetRetrieverResource(db.Model): +class DatasetRetrieverResource(Base): __tablename__ = "dataset_retriever_resources" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), @@ -1444,7 +1449,7 @@ class DatasetRetrieverResource(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class Tag(db.Model): +class Tag(Base): __tablename__ = "tags" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tag_pkey"), @@ -1462,7 +1467,7 @@ class Tag(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class TagBinding(db.Model): +class TagBinding(Base): __tablename__ = "tag_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), @@ -1478,7 +1483,7 @@ class TagBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class TraceAppConfig(db.Model): +class TraceAppConfig(Base): __tablename__ = "trace_app_config" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), diff --git a/api/models/provider.py b/api/models/provider.py index 644915e781..d3c6db9bab 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,6 +1,7 @@ from enum import Enum from extensions.ext_database import db +from models.base import Base from .types import StringUUID @@ -35,7 +36,7 @@ class ProviderQuotaType(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class Provider(db.Model): +class Provider(Base): """ Provider model representing the API providers and their configurations. """ @@ -88,7 +89,7 @@ class Provider(db.Model): return self.is_valid and self.token_is_set -class ProviderModel(db.Model): +class ProviderModel(Base): """ Provider model representing the API provider_models and their configurations. """ @@ -113,7 +114,7 @@ class ProviderModel(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class TenantDefaultModel(db.Model): +class TenantDefaultModel(Base): __tablename__ = "tenant_default_models" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), @@ -129,7 +130,7 @@ class TenantDefaultModel(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class TenantPreferredModelProvider(db.Model): +class TenantPreferredModelProvider(Base): __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), @@ -144,7 +145,7 @@ class TenantPreferredModelProvider(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class ProviderOrder(db.Model): +class ProviderOrder(Base): __tablename__ = "provider_orders" __table_args__ = ( db.PrimaryKeyConstraint("id", name="provider_order_pkey"), @@ -169,7 +170,7 @@ class ProviderOrder(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class ProviderModelSetting(db.Model): +class ProviderModelSetting(Base): """ Provider model settings for record the model enabled status and load balancing status. """ @@ -191,7 +192,7 @@ class ProviderModelSetting(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class LoadBalancingModelConfig(db.Model): +class LoadBalancingModelConfig(Base): """ Configurations for load balancing models. """ diff --git a/api/models/source.py b/api/models/source.py index 07695f06e6..efd94227d0 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -3,11 +3,12 @@ import json from sqlalchemy.dialects.postgresql import JSONB from extensions.ext_database import db +from models.base import Base from .types import StringUUID -class DataSourceOauthBinding(db.Model): +class DataSourceOauthBinding(Base): __tablename__ = "data_source_oauth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="source_binding_pkey"), @@ -25,7 +26,7 @@ class DataSourceOauthBinding(db.Model): disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) -class DataSourceApiKeyAuthBinding(db.Model): +class DataSourceApiKeyAuthBinding(Base): __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), diff --git a/api/models/task.py b/api/models/task.py index 57b147c78d..6fab2a72c2 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -3,9 +3,10 @@ from datetime import datetime, timezone from celery import states from extensions.ext_database import db +from models.base import Base -class CeleryTask(db.Model): +class CeleryTask(Base): """Task result/status.""" __tablename__ = "celery_taskmeta" @@ -29,7 +30,7 @@ class CeleryTask(db.Model): queue = db.Column(db.String(155), nullable=True) -class CeleryTaskSet(db.Model): +class CeleryTaskSet(Base): """TaskSet result.""" __tablename__ = "celery_tasksetmeta" diff --git a/api/models/tool.py b/api/models/tool.py index a81bb65174..d70c905851 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -2,6 +2,7 @@ import json from enum import Enum from extensions.ext_database import db +from models.base import Base from .types import StringUUID @@ -17,7 +18,7 @@ class ToolProviderName(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class ToolProvider(db.Model): +class ToolProvider(Base): __tablename__ = "tool_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), diff --git a/api/models/tools.py b/api/models/tools.py index 485e16b228..1e99749b24 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,8 +1,11 @@ import json from datetime import datetime +from deprecated import deprecated +from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped, mapped_column +from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db @@ -31,7 +34,7 @@ class BuiltinToolProvider(Base): # who created this tool provider user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # name of the tool provider - provider: Mapped[str] = mapped_column(db.String(40), nullable=False) + provider: Mapped[str] = mapped_column(db.String(256), nullable=False) # credential of the tool provider encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) created_at: Mapped[datetime] = mapped_column( @@ -182,7 +185,7 @@ class WorkflowToolProvider(Base): return db.session.query(App).filter(App.id == self.app_id).first() -class ToolModelInvoke(db.Model): +class ToolModelInvoke(Base): """ store the invoke logs from tool invoke """ @@ -219,7 +222,7 @@ class ToolModelInvoke(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class ToolConversationVariables(db.Model): +class ToolConversationVariables(Base): """ store the conversation variables from tool invoke """ @@ -275,3 +278,46 @@ class ToolFile(Base): mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False) # original url original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True) + + +@deprecated +class DeprecatedPublishedAppTool(Base): + """ + The table stores the apps published as a tool for each person. + """ + + __tablename__ = "tool_published_apps" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), + db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), + ) + + # id of the tool provider + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # id of the app + app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) + # who published this tool + user_id = db.Column(StringUUID, nullable=False) + # description of the tool, stored in i18n format, for human + description = db.Column(db.Text, nullable=False) + # llm_description of the tool, for LLM + llm_description = db.Column(db.Text, nullable=False) + # query description, query will be seem as a parameter of the tool, + # to describe this parameter to llm, we need this field + query_description = db.Column(db.Text, nullable=False) + # query name, the name of the query parameter + query_name = db.Column(db.String(40), nullable=False) + # name of the tool provider + tool_name = db.Column(db.String(40), nullable=False) + # author + author = db.Column(db.String(40), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + + @property + def description_i18n(self) -> I18nObject: + return I18nObject(**json.loads(self.description)) + + @property + def app(self) -> App: + return db.session.query(App).filter(App.id == self.app_id).first() diff --git a/api/models/web.py b/api/models/web.py index bc088c185d..934008a443 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,10 +1,11 @@ from extensions.ext_database import db +from models.base import Base from .model import Message from .types import StringUUID -class SavedMessage(db.Model): +class SavedMessage(Base): __tablename__ = "saved_messages" __table_args__ = ( db.PrimaryKeyConstraint("id", name="saved_message_pkey"), @@ -23,7 +24,7 @@ class SavedMessage(db.Model): return db.session.query(Message).filter(Message.id == self.message_id).first() -class PinnedConversation(db.Model): +class PinnedConversation(Base): __tablename__ = "pinned_conversations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), diff --git a/api/models/workflow.py b/api/models/workflow.py index 9c93ea4cea..0b7d255954 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,10 +2,13 @@ import json from collections.abc import Mapping, Sequence from datetime import datetime from enum import Enum -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union -from sqlalchemy import func -from sqlalchemy.orm import Mapped +if TYPE_CHECKING: + from models.model import AppMode + +from sqlalchemy import Index, PrimaryKeyConstraint, func +from sqlalchemy.orm import Mapped, mapped_column import contexts from constants import HIDDEN_VALUE @@ -13,6 +16,7 @@ from core.app.segments import SecretVariable, Variable, factory from core.helper import encrypter from extensions.ext_database import db from libs import helper +from models.base import Base from .account import Account from .types import StringUUID @@ -75,7 +79,7 @@ class WorkflowType(Enum): return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT -class Workflow(db.Model): +class Workflow(Base): """ Workflow, for `Workflow App` and `Chat App workflow mode`. @@ -345,7 +349,7 @@ class WorkflowRunStatus(Enum): raise ValueError(f"invalid workflow run status value {value}") -class WorkflowRun(db.Model): +class WorkflowRun(Base): """ Workflow Run @@ -436,7 +440,7 @@ class WorkflowRun(db.Model): return json.loads(self.outputs) if self.outputs else None @property - def message(self) -> Optional["Message"]: + def message(self): from models.model import Message return ( @@ -542,7 +546,7 @@ class WorkflowNodeExecutionStatus(Enum): raise ValueError(f"invalid workflow node execution status value {value}") -class WorkflowNodeExecution(db.Model): +class WorkflowNodeExecution(Base): """ Workflow Node Execution @@ -708,7 +712,7 @@ class WorkflowAppLogCreatedFrom(Enum): raise ValueError(f"invalid workflow app log created from value {value}") -class WorkflowAppLog(db.Model): +class WorkflowAppLog(Base): """ Workflow App execution log, excluding workflow debugging records. @@ -770,15 +774,20 @@ class WorkflowAppLog(db.Model): return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None -class ConversationVariable(db.Model): +class ConversationVariable(Base): __tablename__ = "workflow_conversation_variables" + __table_args__ = ( + PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"), + Index("workflow__conversation_variables_app_id_idx", "app_id"), + Index("workflow__conversation_variables_created_at_idx", "created_at"), + ) - id: Mapped[str] = db.Column(StringUUID, primary_key=True) - conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) - app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) - data = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column( + id: Mapped[str] = mapped_column(StringUUID, primary_key=True) + conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + data = mapped_column(db.Text, nullable=False) + created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() )