From 2799aaa0b0de7390cf3d207a642736dfb0d781ab Mon Sep 17 00:00:00 2001 From: He Wang Date: Mon, 10 Mar 2025 21:27:39 +0800 Subject: [PATCH] make flask migration models compatible to both postgresql and mysql --- api/.env.example | 5 +- api/configs/middleware/__init__.py | 6 +- api/models/account.py | 22 +++---- api/models/api_based_extension.py | 4 +- api/models/dataset.py | 97 +++++++++++++++--------------- api/models/model.py | 75 +++++++++++------------ api/models/provider.py | 28 ++++----- api/models/source.py | 11 ++-- api/models/tools.py | 24 ++++---- api/models/types.py | 65 ++++++++++++++++++-- api/models/web.py | 10 +-- api/models/workflow.py | 32 +++++----- docker/.env.example | 6 +- 13 files changed, 225 insertions(+), 160 deletions(-) diff --git a/api/.env.example b/api/.env.example index 151ed14120..4d39b27eb1 100644 --- a/api/.env.example +++ b/api/.env.example @@ -50,13 +50,16 @@ REDIS_USE_CLUSTERS=false REDIS_CLUSTERS= REDIS_CLUSTERS_PASSWORD= -# PostgreSQL database configuration +# Database configuration, use postgresql by default DB_USERNAME=postgres DB_PASSWORD=difyai123456 DB_HOST=localhost DB_PORT=5432 DB_DATABASE=dify +# Database URI scheme +SQLALCHEMY_DATABASE_URI_SCHEME=postgresql + # Storage configuration # use for store upload files, private keys... # storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 3bd638bc74..c7df8207ec 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -175,13 +175,15 @@ class DatabaseConfig(BaseSettings): @computed_field def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: - return { + options = { "pool_size": self.SQLALCHEMY_POOL_SIZE, "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, - "connect_args": {"options": "-c timezone=UTC"}, } + if self.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": + options["connect_args"] = {"options": "-c timezone=UTC"} + return options class CeleryConfig(DatabaseConfig): diff --git a/api/models/account.py b/api/models/account.py index a0b8957fe1..7b490958db 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import Mapped, mapped_column from models.base import Base from .engine import db -from .types import StringUUID +from .types import StringUUID, adjusted_text, uuid_default, varchar_default class AccountStatus(enum.StrEnum): @@ -23,7 +23,7 @@ class Account(UserMixin, Base): __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) @@ -35,7 +35,7 @@ class Account(UserMixin, Base): last_login_at = db.Column(db.DateTime) last_login_ip = db.Column(db.String(255)) last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) + status = db.Column(db.String(16), nullable=False, **varchar_default("active")) initialized_at = db.Column(db.DateTime) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -195,12 +195,12 @@ class Tenant(db.Model): # type: ignore[name-defined] __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) name = db.Column(db.String(255), nullable=False) encrypt_public_key = db.Column(db.Text) - plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - custom_config = db.Column(db.Text) + plan = db.Column(db.String(255), nullable=False, **varchar_default("basic")) + status = db.Column(db.String(255), nullable=False, **varchar_default("normal")) + custom_config = db.Column(adjusted_text()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -229,7 +229,7 @@ class TenantAccountJoin(db.Model): # type: ignore[name-defined] db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @@ -247,7 +247,7 @@ class AccountIntegrate(db.Model): # type: ignore[name-defined] db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) account_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) @@ -267,7 +267,7 @@ class InvitationCode(db.Model): # type: ignore[name-defined] id = db.Column(db.Integer, nullable=False) batch = db.Column(db.String(255), nullable=False) code = db.Column(db.String(32), nullable=False) - status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying")) + status = db.Column(db.String(16), nullable=False, **varchar_default("unused")) used_at = db.Column(db.DateTime) used_by_tenant_id = db.Column(StringUUID) used_by_account_id = db.Column(StringUUID) @@ -292,7 +292,7 @@ class TenantPluginPermission(Base): db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) install_permission: Mapped[InstallPermission] = mapped_column( db.String(16), nullable=False, server_default="everyone" diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 6b6d808710..a527721219 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -3,7 +3,7 @@ import enum from sqlalchemy import func from .engine import db -from .types import StringUUID +from .types import StringUUID, uuid_default class APIBasedExtensionPoint(enum.Enum): @@ -20,7 +20,7 @@ class APIBasedExtension(db.Model): # type: ignore[name-defined] db.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) diff --git a/api/models/dataset.py b/api/models/dataset.py index 28589eb8c1..7d326403f5 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -12,7 +12,6 @@ from json import JSONDecodeError from typing import Any, cast from sqlalchemy import func -from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped from configs import dify_config @@ -24,7 +23,15 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode, from .account import Account from .engine import db from .model import App, Tag, TagBinding, UploadFile -from .types import StringUUID +from .types import ( + StringUUID, + adjusted_json_index, + adjusted_jsonb, + adjusted_text, + no_length_string, + uuid_default, + varchar_default, +) class DatasetPermissionEnum(enum.StrEnum): @@ -38,21 +45,21 @@ class Dataset(db.Model): # type: ignore[name-defined] __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_pkey"), db.Index("dataset_tenant_idx", "tenant_id"), - db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), + adjusted_json_index("retrieval_model_idx", "retrieval_model"), ) INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] PROVIDER_LIST = ["vendor", "external", None] - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=True) - provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying")) - permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying")) + provider = db.Column(db.String(255), nullable=False, **varchar_default("vendor")) + permission = db.Column(db.String(255), nullable=False, **varchar_default("only_me")) data_source_type = db.Column(db.String(255)) indexing_technique = db.Column(db.String(255), nullable=True) - index_struct = db.Column(db.Text, nullable=True) + index_struct = db.Column(adjusted_text(), nullable=True) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) @@ -60,7 +67,7 @@ class Dataset(db.Model): # type: ignore[name-defined] embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True) - retrieval_model = db.Column(JSONB, nullable=True) + retrieval_model = db.Column(adjusted_jsonb(), nullable=True) built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property @@ -262,9 +269,9 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined] db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, nullable=False, **uuid_default()) dataset_id = db.Column(StringUUID, nullable=False) - mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + mode = db.Column(db.String(255), nullable=False, **varchar_default("automatic")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -302,16 +309,16 @@ class Document(db.Model): # type: ignore[name-defined] db.Index("document_dataset_id_idx", "dataset_id"), db.Index("document_is_paused_idx", "is_paused"), db.Index("document_tenant_idx", "tenant_id"), - db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), + adjusted_json_index("document_metadata_idx", "doc_metadata"), ) # initial fields - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, nullable=False, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) data_source_type = db.Column(db.String(255), nullable=False) - data_source_info = db.Column(db.Text, nullable=True) + data_source_info = db.Column(adjusted_text(), nullable=True) dataset_process_rule_id = db.Column(StringUUID, nullable=True) batch = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False) @@ -349,7 +356,7 @@ class Document(db.Model): # type: ignore[name-defined] stopped_at = db.Column(db.DateTime, nullable=True) # basic fields - indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) + indexing_status = db.Column(db.String(255), nullable=False, **varchar_default("waiting")) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) disabled_at = db.Column(db.DateTime, nullable=True) disabled_by = db.Column(StringUUID, nullable=True) @@ -359,8 +366,8 @@ class Document(db.Model): # type: ignore[name-defined] archived_at = db.Column(db.DateTime, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) doc_type = db.Column(db.String(40), nullable=True) - doc_metadata = db.Column(JSONB, nullable=True) - doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_metadata = db.Column(adjusted_jsonb(), nullable=True) + doc_form = db.Column(db.String(255), nullable=False, **varchar_default("text_model")) doc_language = db.Column(db.String(255), nullable=True) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -648,7 +655,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined] ) # initial fields - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, nullable=False, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) document_id = db.Column(StringUUID, nullable=False) @@ -668,7 +675,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined] enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) disabled_at = db.Column(db.DateTime, nullable=True) disabled_by = db.Column(StringUUID, nullable=True) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) + status = db.Column(db.String(255), nullable=False, **varchar_default("waiting")) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) @@ -777,7 +784,7 @@ class ChildChunk(db.Model): # type: ignore[name-defined] ) # initial fields - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, nullable=False, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) document_id = db.Column(StringUUID, nullable=False) @@ -788,7 +795,7 @@ class ChildChunk(db.Model): # type: ignore[name-defined] # indexing fields index_node_id = db.Column(db.String(255), nullable=True) index_node_hash = db.Column(db.String(255), nullable=True) - type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + type = db.Column(db.String(255), nullable=False, **varchar_default("automatic")) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) @@ -817,7 +824,7 @@ class AppDatasetJoin(db.Model): # type: ignore[name-defined] db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, primary_key=True, nullable=False, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -834,12 +841,12 @@ class DatasetQuery(db.Model): # type: ignore[name-defined] db.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, primary_key=True, nullable=False, **uuid_default()) dataset_id = db.Column(StringUUID, nullable=False) content = db.Column(db.Text, nullable=False) source = db.Column(db.String(255), nullable=False) source_app_id = db.Column(StringUUID, nullable=True) - created_by_role = db.Column(db.String, nullable=False) + created_by_role = db.Column(no_length_string(), nullable=False) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -851,12 +858,10 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined] db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, primary_key=True, **uuid_default()) dataset_id = db.Column(StringUUID, nullable=False, unique=True) - keyword_table = db.Column(db.Text, nullable=False) - data_source_type = db.Column( - db.String(255), nullable=False, server_default=db.text("'database'::character varying") - ) + keyword_table = db.Column(adjusted_text(), nullable=False) + data_source_type = db.Column(db.String(255), nullable=False, **varchar_default("database")) @property def keyword_table_dict(self): @@ -897,14 +902,12 @@ class Embedding(db.Model): # type: ignore[name-defined] db.Index("created_at_idx", "created_at"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - model_name = db.Column( - db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") - ) + id = db.Column(StringUUID, primary_key=True, **uuid_default()) + model_name = db.Column(db.String(255), nullable=False, **varchar_default("text-embedding-ada-002")) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) + provider_name = db.Column(db.String(255), nullable=False, **varchar_default("")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -920,10 +923,10 @@ class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] db.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, primary_key=True, **uuid_default()) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) - type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) + type = db.Column(db.String(40), nullable=False, **varchar_default("dataset")) collection_name = db.Column(db.String(64), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -937,12 +940,12 @@ class TidbAuthBinding(db.Model): # type: ignore[name-defined] db.Index("tidb_auth_bindings_created_at_idx", "created_at"), db.Index("tidb_auth_bindings_status_idx", "status"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, primary_key=True, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=True) cluster_id = db.Column(db.String(255), nullable=False) cluster_name = db.Column(db.String(255), nullable=False) active = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'")) account = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -954,7 +957,7 @@ class Whitelist(db.Model): # type: ignore[name-defined] db.PrimaryKeyConstraint("id", name="whitelists_pkey"), db.Index("whitelists_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, primary_key=True, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=True) category = db.Column(db.String(255), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -969,7 +972,7 @@ class DatasetPermission(db.Model): # type: ignore[name-defined] db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) + id = db.Column(StringUUID, primary_key=True, **uuid_default()) dataset_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False) @@ -985,11 +988,11 @@ class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] db.Index("external_knowledge_apis_name_idx", "name"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, nullable=False, **uuid_default()) name = db.Column(db.String(255), nullable=False) description = db.Column(db.String(255), nullable=False) tenant_id = db.Column(StringUUID, nullable=False) - settings = db.Column(db.Text, nullable=True) + settings = db.Column(adjusted_text(), nullable=True) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) @@ -1040,11 +1043,11 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, nullable=False, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) external_knowledge_api_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) - external_knowledge_id = db.Column(db.Text, nullable=False) + external_knowledge_id = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) @@ -1060,7 +1063,7 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] db.Index("dataset_auto_disable_log_created_atx", "created_at"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) document_id = db.Column(StringUUID, nullable=False) @@ -1076,7 +1079,7 @@ class RateLimitLog(db.Model): # type: ignore[name-defined] db.Index("rate_limit_log_operation_idx", "operation"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) subscription_plan = db.Column(db.String(255), nullable=False) operation = db.Column(db.String(255), nullable=False) @@ -1091,7 +1094,7 @@ class DatasetMetadata(db.Model): # type: ignore[name-defined] db.Index("dataset_metadata_dataset_idx", "dataset_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) @@ -1112,7 +1115,7 @@ class DatasetMetadataBinding(db.Model): # type: ignore[name-defined] db.Index("dataset_metadata_binding_document_idx", "document_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) metadata_id = db.Column(StringUUID, nullable=False) diff --git a/api/models/model.py b/api/models/model.py index 0a5256c335..d2b0eebfaf 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,6 +1,5 @@ import json import re -import uuid from collections.abc import Mapping from datetime import datetime from enum import Enum @@ -33,7 +32,7 @@ from models.workflow import WorkflowRunStatus from .account import Account, Tenant from .engine import db -from .types import StringUUID +from .types import StringUUID, no_length_string, text_default, uuid_default, varchar_default if TYPE_CHECKING: from .workflow import Workflow @@ -78,17 +77,17 @@ class App(Base): __tablename__ = "apps" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) + description = db.Column(db.Text, nullable=False, **varchar_default("")) mode: Mapped[str] = mapped_column(db.String(255), nullable=False) icon_type = db.Column(db.String(255), nullable=True) # image, emoji icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) app_model_config_id = db.Column(StringUUID, nullable=True) workflow_id = db.Column(StringUUID, nullable=True) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + status = db.Column(db.String(255), nullable=False, **varchar_default("normal")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text("0")) @@ -301,7 +300,7 @@ class AppModelConfig(Base): __tablename__ = "app_model_configs" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) @@ -323,7 +322,7 @@ class AppModelConfig(Base): agent_mode = db.Column(db.Text) sensitive_word_avoidance = db.Column(db.Text) retriever_resource = db.Column(db.Text) - prompt_type = db.Column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) + prompt_type = db.Column(db.String(255), nullable=False, **varchar_default("simple")) chat_prompt_config = db.Column(db.Text) completion_prompt_config = db.Column(db.Text) dataset_configs = db.Column(db.Text) @@ -555,7 +554,7 @@ class RecommendedApp(Base): db.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, primary_key=True, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) description = db.Column(db.JSON, nullable=False) copyright = db.Column(db.String(255), nullable=False) @@ -565,7 +564,7 @@ class RecommendedApp(Base): position = db.Column(db.Integer, nullable=False, default=0) is_listed = db.Column(db.Boolean, nullable=False, default=True) install_count = db.Column(db.Integer, nullable=False, default=0) - language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) + language = db.Column(db.String(255), nullable=False, **varchar_default("en-US")) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -584,7 +583,7 @@ class InstalledApp(Base): db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) app_owner_tenant_id = db.Column(StringUUID, nullable=False) @@ -611,7 +610,7 @@ class Conversation(db.Model): # type: ignore[name-defined] db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) @@ -851,7 +850,7 @@ class Message(db.Model): # type: ignore[name-defined] Index("message_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) @@ -871,7 +870,7 @@ class Message(db.Model): # type: ignore[name-defined] provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + status = db.Column(db.String(255), nullable=False, **varchar_default("normal")) error = db.Column(db.Text) message_metadata = db.Column(db.Text) invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) @@ -1200,7 +1199,7 @@ class MessageFeedback(db.Model): # type: ignore[name-defined] db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) @@ -1247,7 +1246,7 @@ class MessageFile(db.Model): # type: ignore[name-defined] self.created_by_role = created_by_role.value self.created_by = created_by - id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = db.Column(StringUUID, **uuid_default()) message_id: Mapped[str] = db.Column(StringUUID, nullable=False) type: Mapped[str] = db.Column(db.String(255), nullable=False) transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False) @@ -1268,7 +1267,7 @@ class MessageAnnotation(db.Model): # type: ignore[name-defined] db.Index("message_annotation_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) message_id = db.Column(StringUUID, nullable=True) @@ -1300,7 +1299,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] db.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) annotation_id = db.Column(StringUUID, nullable=False) source = db.Column(db.Text, nullable=False) @@ -1335,7 +1334,7 @@ class AppAnnotationSetting(db.Model): # type: ignore[name-defined] db.Index("app_annotation_settings_app_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) collection_binding_id = db.Column(StringUUID, nullable=False) @@ -1383,7 +1382,7 @@ class OperationLog(Base): db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) @@ -1401,14 +1400,14 @@ class EndUser(Base, UserMixin): db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(255), nullable=False) external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - session_id: Mapped[str] = mapped_column() + session_id: Mapped[str] = mapped_column(db.String(255), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1421,7 +1420,7 @@ class Site(Base): db.Index("site_code_idx", "code", "status"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) title = db.Column(db.String(255), nullable=False) icon_type = db.Column(db.String(255), nullable=True) @@ -1439,7 +1438,7 @@ class Site(Base): customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + status = db.Column(db.String(255), nullable=False, **varchar_default("normal")) created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) @@ -1479,7 +1478,7 @@ class ApiToken(Base): db.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=True) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) @@ -1503,7 +1502,7 @@ class UploadFile(Base): db.Index("upload_file_tenant_idx", "tenant_id"), ) - id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = db.Column(StringUUID, **uuid_default()) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) storage_type: Mapped[str] = db.Column(db.String(255), nullable=False) key: Mapped[str] = db.Column(db.String(255), nullable=False) @@ -1511,9 +1510,7 @@ class UploadFile(Base): size: Mapped[int] = db.Column(db.Integer, nullable=False) extension: Mapped[str] = db.Column(db.String(255), nullable=False) mime_type: Mapped[str] = db.Column(db.String(255), nullable=True) - created_by_role: Mapped[str] = db.Column( - db.String(255), nullable=False, server_default=db.text("'account'::character varying") - ) + created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False, **varchar_default("account")) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @@ -1565,7 +1562,7 @@ class ApiRequest(Base): db.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, nullable=False, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) api_token_id = db.Column(StringUUID, nullable=False) path = db.Column(db.String(255), nullable=False) @@ -1582,7 +1579,7 @@ class MessageChain(Base): db.Index("message_chain_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, nullable=False, **uuid_default()) message_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) input = db.Column(db.Text, nullable=True) @@ -1598,14 +1595,14 @@ class MessageAgentThought(Base): db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, nullable=False, **uuid_default()) message_id = db.Column(StringUUID, nullable=False) message_chain_id = db.Column(StringUUID, nullable=True) position = db.Column(db.Integer, nullable=False) thought = db.Column(db.Text, nullable=True) tool = db.Column(db.Text, nullable=True) - tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) + tool_labels_str = db.Column(db.Text, nullable=False, **text_default("{}")) + tool_meta_str = db.Column(db.Text, nullable=False, **text_default("{}")) tool_input = db.Column(db.Text, nullable=True) observation = db.Column(db.Text, nullable=True) # plugin_id = db.Column(StringUUID, nullable=True) ## for future design @@ -1621,9 +1618,9 @@ class MessageAgentThought(Base): answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) tokens = db.Column(db.Integer, nullable=True) total_price = db.Column(db.Numeric, nullable=True) - currency = db.Column(db.String, nullable=True) + currency = db.Column(no_length_string(), nullable=True) latency = db.Column(db.Float, nullable=True) - created_by_role = db.Column(db.String, nullable=False) + created_by_role = db.Column(no_length_string(), nullable=False) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1711,7 +1708,7 @@ class DatasetRetrieverResource(Base): db.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, nullable=False, **uuid_default()) message_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) @@ -1741,7 +1738,7 @@ class Tag(Base): TAG_TYPE_LIST = ["knowledge", "app"] - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) name = db.Column(db.String(255), nullable=False) @@ -1757,7 +1754,7 @@ class TagBinding(Base): db.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=True) tag_id = db.Column(StringUUID, nullable=True) target_id = db.Column(StringUUID, nullable=True) @@ -1772,7 +1769,7 @@ class TraceAppConfig(Base): db.Index("trace_app_config_app_id_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) tracing_config = db.Column(db.JSON, nullable=True) diff --git a/api/models/provider.py b/api/models/provider.py index 567400702d..581e83254d 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -5,7 +5,7 @@ from sqlalchemy import func from models.base import Base from .engine import db -from .types import StringUUID +from .types import StringUUID, adjusted_text, uuid_default, varchar_default class ProviderType(Enum): @@ -52,15 +52,15 @@ class Provider(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) - provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) - encrypted_config = db.Column(db.Text, nullable=True) + provider_type = db.Column(db.String(40), nullable=False, **varchar_default("custom")) + encrypted_config = db.Column(adjusted_text(), nullable=True) is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used = db.Column(db.DateTime, nullable=True) - quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) + quota_type = db.Column(db.String(40), nullable=True, **varchar_default("")) quota_limit = db.Column(db.BigInteger, nullable=True) quota_used = db.Column(db.BigInteger, default=0) @@ -105,12 +105,12 @@ class ProviderModel(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) + encrypted_config = db.Column(adjusted_text(), nullable=True) is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -123,7 +123,7 @@ class TenantDefaultModel(Base): db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) @@ -139,7 +139,7 @@ class TenantPreferredModelProvider(Base): db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) @@ -154,7 +154,7 @@ class ProviderOrder(Base): db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) account_id = db.Column(StringUUID, nullable=False) @@ -164,7 +164,7 @@ class ProviderOrder(Base): quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) currency = db.Column(db.String(40)) total_amount = db.Column(db.Integer) - payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) + payment_status = db.Column(db.String(40), nullable=False, **varchar_default("wait_pay")) paid_at = db.Column(db.DateTime) pay_failed_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime) @@ -183,7 +183,7 @@ class ProviderModelSetting(Base): db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) @@ -205,13 +205,13 @@ class LoadBalancingModelConfig(Base): db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) name = db.Column(db.String(255), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) + encrypted_config = db.Column(adjusted_text(), nullable=True) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index b9d7d91346..2ebb858df8 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,12 +1,11 @@ import json from sqlalchemy import func -from sqlalchemy.dialects.postgresql import JSONB from models.base import Base from .engine import db -from .types import StringUUID +from .types import StringUUID, adjusted_json_index, adjusted_jsonb, uuid_default class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] @@ -14,14 +13,14 @@ class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] __table_args__ = ( db.PrimaryKeyConstraint("id", name="source_binding_pkey"), db.Index("source_binding_tenant_id_idx", "tenant_id"), - db.Index("source_info_idx", "source_info", postgresql_using="gin"), + adjusted_json_index("source_info_idx", "source_info"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) - source_info = db.Column(JSONB, nullable=False) + source_info = db.Column(adjusted_jsonb(), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) @@ -35,7 +34,7 @@ class DataSourceApiKeyAuthBinding(Base): db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) tenant_id = db.Column(StringUUID, nullable=False) category = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) diff --git a/api/models/tools.py b/api/models/tools.py index aef1490729..5cc1bef04e 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -14,7 +14,7 @@ from models.base import Base from .engine import db from .model import Account, App, Tenant -from .types import StringUUID +from .types import StringUUID, no_length_string, uuid_default class BuiltinToolProvider(Base): @@ -30,7 +30,7 @@ class BuiltinToolProvider(Base): ) # id of the tool provider - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider @@ -62,7 +62,7 @@ class ApiToolProvider(Base): db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) # name of the api provider name = db.Column(db.String(255), nullable=False) # icon @@ -122,7 +122,7 @@ class ToolLabelBinding(Base): db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) # tool id tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False) # tool type @@ -143,7 +143,7 @@ class WorkflowToolProvider(Base): db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) # name of the workflow provider name: Mapped[str] = mapped_column(db.String(255), nullable=False) # label of the workflow provider @@ -161,7 +161,7 @@ class WorkflowToolProvider(Base): # description of the provider description: Mapped[str] = mapped_column(db.Text, nullable=False) # parameter configuration - parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]") + parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]") # privacy policy privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="") @@ -201,7 +201,7 @@ class ToolModelInvoke(Base): __tablename__ = "tool_model_invokes" __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) # who invoke this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -244,7 +244,7 @@ class ToolConversationVariables(Base): db.Index("conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) # conversation user id user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -273,7 +273,7 @@ class ToolFile(Base): db.Index("tool_file_conversation_id_idx", "conversation_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) # conversation user id user_id: Mapped[str] = mapped_column(StringUUID) # tenant id @@ -287,7 +287,7 @@ class ToolFile(Base): # original url original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True) # name - name: Mapped[str] = mapped_column(default="") + name: Mapped[str] = mapped_column(no_length_string(), default="") # size size: Mapped[int] = mapped_column(default=-1) @@ -326,14 +326,14 @@ class DeprecatedPublishedAppTool(Base): def description_i18n(self) -> I18nObject: return I18nObject(**json.loads(self.description)) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) user_id: Mapped[str] = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) file_key: Mapped[str] = db.Column(db.String(255), nullable=False) mimetype: Mapped[str] = db.Column(db.String(255), nullable=False) original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True) - name: Mapped[str] = mapped_column(default="") + name: Mapped[str] = mapped_column(no_length_string(), default="") size: Mapped[int] = mapped_column(default=-1) def __init__( diff --git a/api/models/types.py b/api/models/types.py index cb6773e70c..30cb2fdfbd 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -1,5 +1,11 @@ -from sqlalchemy import CHAR, TypeDecorator -from sqlalchemy.dialects.postgresql import UUID +import uuid + +from sqlalchemy import CHAR, JSON, TypeDecorator +from sqlalchemy.dialects import mysql, postgresql + +from configs import dify_config + +from .engine import db class StringUUID(TypeDecorator): @@ -9,14 +15,14 @@ class StringUUID(TypeDecorator): def process_bind_param(self, value, dialect): if value is None: return value - elif dialect.name == "postgresql": + elif dialect.name in {"postgresql", "mysql"}: return str(value) else: return value.hex def load_dialect_impl(self, dialect): if dialect.name == "postgresql": - return dialect.type_descriptor(UUID()) + return dialect.type_descriptor(postgresql.UUID()) else: return dialect.type_descriptor(CHAR(36)) @@ -24,3 +30,54 @@ class StringUUID(TypeDecorator): if value is None: return value return str(value) + + +def adjusted_jsonb(): + if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": + return postgresql.JSONB + else: + return JSON + + +def adjusted_json_index(index_name, column_name): + index_name = index_name or f"{column_name}_idx" + + if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": + return db.Index(index_name, column_name, postgresql_using="gin") + else: + return None + + +def no_length_string(): + if "mysql" in dify_config.SQLALCHEMY_DATABASE_URI_SCHEME: + return db.String(255) + else: + return db.String + + +def adjusted_text(): + if "mysql" in dify_config.SQLALCHEMY_DATABASE_URI_SCHEME: + return mysql.LONGTEXT + else: + return db.TEXT + + +def uuid_default(): + if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": + return {"server_default": db.text("uuid_generate_v4()")} + else: + return {"default": lambda: uuid.uuid4()} + + +def varchar_default(varchar): + if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": + return {"server_default": db.text(f"'{varchar}'::character varying")} + else: + return {"default": varchar} + + +def text_default(varchar): + if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": + return {"server_default": db.text(f"'{varchar}'::text")} + else: + return {"default": varchar} diff --git a/api/models/web.py b/api/models/web.py index fe2f0c47f8..5c1f5f12a0 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -5,7 +5,7 @@ from models.base import Base from .engine import db from .model import Message -from .types import StringUUID +from .types import StringUUID, uuid_default, varchar_default class SavedMessage(Base): @@ -15,10 +15,10 @@ class SavedMessage(Base): db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) - created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) + created_by_role = db.Column(db.String(255), nullable=False, **varchar_default("end_user")) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -34,9 +34,9 @@ class PinnedConversation(Base): db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = db.Column(StringUUID, **uuid_default()) app_id = db.Column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) - created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) + created_by_role = db.Column(db.String(255), nullable=False, **varchar_default("end_user")) created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index ed6820702c..29c7bd4c1f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -25,7 +25,7 @@ from models.enums import CreatedByRole from .account import Account from .engine import db -from .types import StringUUID +from .types import StringUUID, adjusted_text, no_length_string, uuid_default if TYPE_CHECKING: from models.model import AppMode @@ -105,15 +105,15 @@ class Workflow(Base): db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(db.String(255), nullable=False) - version: Mapped[str] - marked_name: Mapped[str] = mapped_column(default="", server_default="") - marked_comment: Mapped[str] = mapped_column(default="", server_default="") - graph: Mapped[str] = mapped_column(sa.Text) - _features: Mapped[str] = mapped_column("features", sa.TEXT) + version: Mapped[str] = mapped_column(db.String(255)) + marked_name: Mapped[str] = mapped_column(no_length_string(), default="", server_default="") + marked_comment: Mapped[str] = mapped_column(no_length_string(), default="", server_default="") + graph: Mapped[str] = mapped_column(adjusted_text()) + _features: Mapped[str] = mapped_column("features", adjusted_text()) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) @@ -124,10 +124,10 @@ class Workflow(Base): server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( - "environment_variables", db.Text, nullable=False, server_default="{}" + "environment_variables", adjusted_text(), nullable=False, default="{}" ) _conversation_variables: Mapped[str] = mapped_column( - "conversation_variables", db.Text, nullable=False, server_default="{}" + "conversation_variables", adjusted_text(), nullable=False, default="{}" ) @classmethod @@ -400,7 +400,7 @@ class WorkflowRun(Base): db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) sequence_number: Mapped[int] = mapped_column() @@ -408,8 +408,8 @@ class WorkflowRun(Base): type: Mapped[str] = mapped_column(db.String(255)) triggered_from: Mapped[str] = mapped_column(db.String(255)) version: Mapped[str] = mapped_column(db.String(255)) - graph: Mapped[Optional[str]] = mapped_column(db.Text) - inputs: Mapped[Optional[str]] = mapped_column(db.Text) + graph: Mapped[Optional[str]] = mapped_column(adjusted_text()) + inputs: Mapped[Optional[str]] = mapped_column(adjusted_text()) status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") error: Mapped[Optional[str]] = mapped_column(db.Text) @@ -629,7 +629,7 @@ class WorkflowNodeExecution(Base): ), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) @@ -642,7 +642,7 @@ class WorkflowNodeExecution(Base): node_type: Mapped[str] = mapped_column(db.String(255)) title: Mapped[str] = mapped_column(db.String(255)) inputs: Mapped[Optional[str]] = mapped_column(db.Text) - process_data: Mapped[Optional[str]] = mapped_column(db.Text) + process_data: Mapped[Optional[str]] = mapped_column(adjusted_text()) outputs: Mapped[Optional[str]] = mapped_column(db.Text) status: Mapped[str] = mapped_column(db.String(255)) error: Mapped[Optional[str]] = mapped_column(db.Text) @@ -758,7 +758,7 @@ class WorkflowAppLog(Base): db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, **uuid_default()) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id = db.Column(StringUUID, nullable=False) @@ -796,7 +796,7 @@ class ConversationVariable(Base): id: Mapped[str] = mapped_column(StringUUID, primary_key=True) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - data = mapped_column(db.Text, nullable=False) + data = mapped_column(adjusted_text(), nullable=False) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() diff --git a/docker/.env.example b/docker/.env.example index 6efad1bc9c..224dcb8e8c 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -177,7 +177,7 @@ API_TOOL_DEFAULT_READ_TIMEOUT=60 # ------------------------------ # Database Configuration -# The database uses PostgreSQL. Please use the public schema. +# The database uses PostgreSQL by default. Please use the public schema. # It is consistent with the configuration in the 'db' service below. # ------------------------------ @@ -186,6 +186,10 @@ DB_PASSWORD=difyai123456 DB_HOST=db DB_PORT=5432 DB_DATABASE=dify + +# Database URI scheme +SQLALCHEMY_DATABASE_URI_SCHEME=postgresql + # The size of the database connection pool. # The default is 30 connections, which can be appropriately increased. SQLALCHEMY_POOL_SIZE=30