add JSONType and json_index for db compatible
This commit is contained in:
parent
abeaea4f79
commit
89de8ea481
@ -12,7 +12,6 @@ from json import JSONDecodeError
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
from configs import dify_config
|
||||
@ -24,7 +23,7 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
|
||||
from .account import Account
|
||||
from .engine import db
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import StringUUID
|
||||
from .types import JSONType, StringUUID, json_index
|
||||
|
||||
|
||||
class DatasetPermissionEnum(enum.StrEnum):
|
||||
@ -38,7 +37,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
|
||||
db.Index("dataset_tenant_idx", "tenant_id"),
|
||||
db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
|
||||
json_index("retrieval_model_idx", "retrieval_model"),
|
||||
)
|
||||
|
||||
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
|
||||
@ -60,7 +59,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
|
||||
embedding_model = db.Column(db.String(255), nullable=True)
|
||||
embedding_model_provider = db.Column(db.String(255), nullable=True)
|
||||
collection_binding_id = db.Column(StringUUID, nullable=True)
|
||||
retrieval_model = db.Column(JSONB, nullable=True)
|
||||
retrieval_model = db.Column(JSONType, nullable=True)
|
||||
built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
|
||||
@property
|
||||
|
@ -1,12 +1,11 @@
|
||||
import json
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from models.base import Base
|
||||
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
from .types import JSONType, StringUUID, json_index
|
||||
|
||||
|
||||
class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
|
||||
@ -14,14 +13,14 @@ class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
||||
db.Index("source_binding_tenant_id_idx", "tenant_id"),
|
||||
db.Index("source_info_idx", "source_info", postgresql_using="gin"),
|
||||
json_index("source_info_idx", "source_info"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
access_token = db.Column(db.String(255), nullable=False)
|
||||
provider = db.Column(db.String(255), nullable=False)
|
||||
source_info = db.Column(JSONB, nullable=False)
|
||||
source_info = db.Column(JSONType, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||
|
@ -1,5 +1,11 @@
|
||||
from sqlalchemy import CHAR, TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
import json
|
||||
|
||||
from sqlalchemy import CHAR, JSON, TypeDecorator
|
||||
from sqlalchemy.dialects import mysql, postgresql
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
from .engine import db
|
||||
|
||||
|
||||
class StringUUID(TypeDecorator):
|
||||
@ -9,14 +15,14 @@ class StringUUID(TypeDecorator):
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
elif dialect.name == "postgresql":
|
||||
elif dialect.name in {"postgresql", "mysql"}:
|
||||
return str(value)
|
||||
else:
|
||||
return value.hex
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(UUID())
|
||||
return dialect.type_descriptor(postgresql.UUID())
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(36))
|
||||
|
||||
@ -24,3 +30,37 @@ class StringUUID(TypeDecorator):
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
class JSONType(TypeDecorator):
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
else:
|
||||
return json.dumps(value)
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(postgresql.JSONB())
|
||||
elif dialect.name == "mysql":
|
||||
return dialect.type_descriptor(mysql.JSON())
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dialect: {dialect.name}")
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
else:
|
||||
return json.loads(value)
|
||||
|
||||
|
||||
def json_index(index_name, column_name):
|
||||
index_name = index_name or f"{column_name}_idx"
|
||||
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
return db.Index(index_name, column_name, postgresql_using="gin")
|
||||
else:
|
||||
return None
|
||||
|
Loading…
Reference in New Issue
Block a user