From 89de8ea48181a3084087caa3b4854fa4338a17ee Mon Sep 17 00:00:00 2001
From: He Wang
Date: Thu, 13 Mar 2025 20:40:04 +0800
Subject: [PATCH] add JSONType and json_index for db compatible

---
 api/models/dataset.py |  7 +++---
 api/models/source.py  |  7 +++---
 api/models/types.py   | 56 +++++++++++++++++++++++++++++++++++++++----
 3 files changed, 58 insertions(+), 12 deletions(-)

diff --git a/api/models/dataset.py b/api/models/dataset.py
index 28589eb8c1..11fa7205aa 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -12,7 +12,6 @@ from json import JSONDecodeError
 from typing import Any, cast
 
 from sqlalchemy import func
-from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.orm import Mapped
 
 from configs import dify_config
@@ -24,7 +23,7 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
 from .account import Account
 from .engine import db
 from .model import App, Tag, TagBinding, UploadFile
-from .types import StringUUID
+from .types import JSONType, StringUUID, json_index
 
 
 class DatasetPermissionEnum(enum.StrEnum):
@@ -38,7 +37,7 @@ class Dataset(db.Model):  # type: ignore[name-defined]
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="dataset_pkey"),
         db.Index("dataset_tenant_idx", "tenant_id"),
-        db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
+        json_index("retrieval_model_idx", "retrieval_model"),
     )
 
     INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
@@ -60,7 +59,7 @@ class Dataset(db.Model):  # type: ignore[name-defined]
     embedding_model = db.Column(db.String(255), nullable=True)
     embedding_model_provider = db.Column(db.String(255), nullable=True)
     collection_binding_id = db.Column(StringUUID, nullable=True)
-    retrieval_model = db.Column(JSONB, nullable=True)
+    retrieval_model = db.Column(JSONType, nullable=True)
     built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
 
     @property
diff --git a/api/models/source.py b/api/models/source.py
index b9d7d91346..e28f75941b 100644
--- a/api/models/source.py
+++ b/api/models/source.py
@@ -1,12 +1,11 @@
 import json
 
 from sqlalchemy import func
-from sqlalchemy.dialects.postgresql import JSONB
 
 from models.base import Base
 
 from .engine import db
-from .types import StringUUID
+from .types import JSONType, StringUUID, json_index
 
 
 class DataSourceOauthBinding(db.Model):  # type: ignore[name-defined]
@@ -14,14 +13,14 @@ class DataSourceOauthBinding(db.Model):  # type: ignore[name-defined]
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
         db.Index("source_binding_tenant_id_idx", "tenant_id"),
-        db.Index("source_info_idx", "source_info", postgresql_using="gin"),
+        json_index("source_info_idx", "source_info"),
     )
 
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id = db.Column(StringUUID, nullable=False)
     access_token = db.Column(db.String(255), nullable=False)
     provider = db.Column(db.String(255), nullable=False)
-    source_info = db.Column(JSONB, nullable=False)
+    source_info = db.Column(JSONType, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
diff --git a/api/models/types.py b/api/models/types.py
index cb6773e70c..80ad693e96 100644
--- a/api/models/types.py
+++ b/api/models/types.py
@@ -1,5 +1,9 @@
-from sqlalchemy import CHAR, TypeDecorator
-from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy import CHAR, JSON, TypeDecorator
+from sqlalchemy.dialects import mysql, postgresql
+
+from configs import dify_config
+
+from .engine import db
 
 
 class StringUUID(TypeDecorator):
@@ -9,14 +13,14 @@ class StringUUID(TypeDecorator):
     def process_bind_param(self, value, dialect):
         if value is None:
             return value
-        elif dialect.name == "postgresql":
+        elif dialect.name in {"postgresql", "mysql"}:
             return str(value)
         else:
             return value.hex
 
     def load_dialect_impl(self, dialect):
         if dialect.name == "postgresql":
-            return dialect.type_descriptor(UUID())
+            return dialect.type_descriptor(postgresql.UUID())
         else:
             return dialect.type_descriptor(CHAR(36))
 
@@ -24,3 +28,47 @@ class StringUUID(TypeDecorator):
         if value is None:
             return value
         return str(value)
+
+
+class JSONType(TypeDecorator):
+    """Portable JSON column type.
+
+    Resolves to ``JSONB`` on PostgreSQL, to the native ``JSON`` type on MySQL
+    and to SQLAlchemy's generic ``JSON`` type on any other dialect.
+
+    Serialization is deliberately left to the underlying dialect type: a
+    ``TypeDecorator`` passes the result of ``process_bind_param`` on to the
+    ``impl`` bind processor, so calling ``json.dumps`` here as well would
+    store a double-encoded JSON string instead of a JSON document.
+    """
+
+    impl = JSON
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect):
+        """Return the dialect-specific type used to store the value."""
+        if dialect.name == "postgresql":
+            return dialect.type_descriptor(postgresql.JSONB())
+        elif dialect.name == "mysql":
+            return dialect.type_descriptor(mysql.JSON())
+        else:
+            return dialect.type_descriptor(JSON())
+
+
+def json_index(index_name, column_name):
+    """Build the index for a JSON column where the database supports one.
+
+    :param index_name: name of the index; defaults to
+        ``"{column_name}_idx"`` when empty.
+    :param column_name: name of the JSON column to index.
+    :return: a GIN ``Index`` when the configured database is PostgreSQL,
+        otherwise ``None``.
+    """
+    index_name = index_name or f"{column_name}_idx"
+
+    # The URI scheme may carry a driver suffix, e.g. "postgresql+psycopg2".
+    backend = dify_config.SQLALCHEMY_DATABASE_URI_SCHEME.split("+", 1)[0]
+    if backend == "postgresql":
+        return db.Index(index_name, column_name, postgresql_using="gin")
+    else:
+        return None