add JSONType and json_index for db compatible

This commit is contained in:
He Wang 2025-03-13 20:40:04 +08:00
parent abeaea4f79
commit 89de8ea481
3 changed files with 50 additions and 12 deletions

View File

@ -12,7 +12,6 @@ from json import JSONDecodeError
from typing import Any, cast
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped
from configs import dify_config
@ -24,7 +23,7 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
from .account import Account
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
from .types import JSONType, StringUUID, json_index
class DatasetPermissionEnum(enum.StrEnum):
@ -38,7 +37,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
db.Index("dataset_tenant_idx", "tenant_id"),
db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
json_index("retrieval_model_idx", "retrieval_model"),
)
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
@ -60,7 +59,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(StringUUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
retrieval_model = db.Column(JSONType, nullable=True)
built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
@property

View File

@ -1,12 +1,11 @@
import json
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from models.base import Base
from .engine import db
from .types import StringUUID
from .types import JSONType, StringUUID, json_index
class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
@ -14,14 +13,14 @@ class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
__table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
db.Index("source_binding_tenant_id_idx", "tenant_id"),
db.Index("source_info_idx", "source_info", postgresql_using="gin"),
json_index("source_info_idx", "source_info"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
access_token = db.Column(db.String(255), nullable=False)
provider = db.Column(db.String(255), nullable=False)
source_info = db.Column(JSONB, nullable=False)
source_info = db.Column(JSONType, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))

View File

@ -1,5 +1,11 @@
from sqlalchemy import CHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
import json
from sqlalchemy import CHAR, JSON, TypeDecorator
from sqlalchemy.dialects import mysql, postgresql
from configs import dify_config
from .engine import db
class StringUUID(TypeDecorator):
@ -9,14 +15,14 @@ class StringUUID(TypeDecorator):
def process_bind_param(self, value, dialect):
if value is None:
return value
elif dialect.name == "postgresql":
elif dialect.name in {"postgresql", "mysql"}:
return str(value)
else:
return value.hex
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID())
return dialect.type_descriptor(postgresql.UUID())
else:
return dialect.type_descriptor(CHAR(36))
@ -24,3 +30,37 @@ class StringUUID(TypeDecorator):
if value is None:
return value
return str(value)
class JSONType(TypeDecorator):
impl = JSON
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return value
else:
return json.dumps(value)
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(postgresql.JSONB())
elif dialect.name == "mysql":
return dialect.type_descriptor(mysql.JSON())
else:
raise NotImplementedError(f"Unsupported dialect: {dialect.name}")
def process_result_value(self, value, dialect):
if value is None:
return value
else:
return json.loads(value)
def json_index(index_name, column_name):
index_name = index_name or f"{column_name}_idx"
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
return db.Index(index_name, column_name, postgresql_using="gin")
else:
return None