Migrate to DeclarativeBaseModel

This commit is contained in:
Yeuoly 2024-10-21 20:38:27 +08:00
parent 53e1b45d40
commit 11270a7ef2
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
9 changed files with 138 additions and 78 deletions

View File

@ -3,6 +3,8 @@ from functools import wraps
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
from constants.languages import supported_language
@ -54,7 +56,8 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
app = App.query.filter(App.id == args["app_id"]).first()
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app:
raise NotFound(f'App \'{args["app_id"]}\' is not found')
@ -70,7 +73,10 @@ class InsertExploreAppListApi(Resource):
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
if not recommended_app:
recommended_app = RecommendedApp(
@ -110,16 +116,26 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud
@admin_required
def delete(self, app_id):
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
if not recommended_app:
return {"result": "success"}, 204
app = App.query.filter(App.id == recommended_app.app_id).first()
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
installed_apps = InstalledApp.query.filter(
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
with Session(db.engine) as session:
installed_apps = session.execute(
select(InstalledApp).filter(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
).all()
for installed_app in installed_apps:

View File

@ -33,7 +33,10 @@ def _get_resource(resource_id, tenant_id, resource_model):
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
else:
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
if resource is None:
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")

View File

@ -3,6 +3,8 @@ import secrets
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.languages import languages
from controllers.console import api
@ -41,7 +43,8 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
account = Account.query.filter_by(email=args["email"]).first()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None
if account is None:
if FeatureService.get_system_features().is_allow_register:
@ -108,7 +111,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get("email")).first()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
if account:
account.password = base64_password_hashed
account.password_salt = base64_salt

View File

@ -5,6 +5,8 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_restful import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@ -135,7 +137,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account = Account.get_by_openid(provider, user_info.id)
if not account:
account = Account.query.filter_by(email=user_info.email).first()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
return account

View File

@ -4,6 +4,8 @@ import json
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.console import api
@ -77,7 +79,10 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")
# enable binding
@ -109,6 +114,7 @@ class DataSourceNotionListApi(Resource):
def get(self):
dataset_id = request.args.get("dataset_id", default=None, type=str)
exist_page_ids = []
with Session(db.engine) as session:
# import notion in the exist dataset
if dataset_id:
dataset = DatasetService.get_dataset(dataset_id)
@ -116,19 +122,24 @@ class DataSourceNotionListApi(Resource):
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")
documents = Document.query.filter_by(
documents = session.execute(
select(Document).filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
data_source_bindings = DataSourceOauthBinding.query.filter_by(
data_source_bindings = session.execute(
select(DataSourceOauthBinding).filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
)
).all()
if not data_source_bindings:
return {"notion_info": []}, 200
@ -159,14 +170,17 @@ class DataSourceNotionApi(Resource):
def get(self, workspace_id, page_id, page_type):
workspace_id = str(workspace_id)
page_id = str(page_id)
data_source_binding = DataSourceOauthBinding.query.filter(
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
).first()
)
).scalar_one_or_none()
if not data_source_binding:
raise NotFound("Data source binding not found.")

View File

@ -5,7 +5,8 @@ from datetime import datetime, timezone
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc
from sqlalchemy import asc, desc, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -104,7 +105,8 @@ class GetProcessRuleApi(Resource):
rules = DocumentService.DEFAULT_RULES["rules"]
if document_id:
# get the latest process rule
document = Document.query.get_or_404(document_id)
with Session(db.engine) as session:
document = session.execute(select(Document).get_or_404(document_id)).scalar_one_or_none()
dataset = DatasetService.get_dataset(document.dataset_id)
@ -167,7 +169,10 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
with Session(db.engine) as session:
query = session.execute(
select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
).all()
if search:
search = f"%{search}%"
@ -204,15 +209,22 @@ class DatasetDocumentListApi(Resource):
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
if fetch:
with Session(db.engine) as session:
for document in documents:
completed_segments = DocumentSegment.query.filter(
completed_segments = (
session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
)
.count()
)
total_segments = (
session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)

View File

@ -2,8 +2,11 @@ import os
from flask import session
from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
from libs.helper import StrLen
from models.model import DifySetup
from services.account_service import TenantService
@ -42,7 +45,11 @@ class InitValidateAPI(Resource):
def get_init_validate_status():
if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"):
return session.get("is_init_validated") or DifySetup.query.first()
if session.get("is_init_validated"):
return True
with Session(db.engine) as db_session:
return db_session.execute(select(DifySetup)).scalar_one_or_none()
return True

View File

@ -4,6 +4,7 @@ import json
from flask_login import UserMixin
from extensions.ext_database import db
from models.base import Base
from .types import StringUUID
@ -16,7 +17,7 @@ class AccountStatus(str, enum.Enum):
CLOSED = "closed"
class Account(UserMixin, db.Model):
class Account(UserMixin, Base):
__tablename__ = "accounts"
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))

View File

@ -38,7 +38,7 @@ class FileUploadConfig(BaseModel):
number_limits: int = Field(default=0, gt=0, le=10)
class DifySetup(db.Model):
class DifySetup(BaseModel):
__tablename__ = "dify_setups"
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)