From 11270a7ef21a282558d1e68b9f042f6b24d2b957 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 21 Oct 2024 20:38:27 +0800 Subject: [PATCH] Migrate to DeclarativeBaseModel --- api/controllers/console/admin.py | 30 +++-- api/controllers/console/apikey.py | 5 +- .../console/auth/forgot_password.py | 8 +- api/controllers/console/auth/oauth.py | 5 +- .../console/datasets/data_source.py | 112 ++++++++++-------- .../console/datasets/datasets_document.py | 42 ++++--- api/controllers/console/init_validate.py | 9 +- api/models/account.py | 3 +- api/models/model.py | 2 +- 9 files changed, 138 insertions(+), 78 deletions(-) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index f78ea9b288..6a61ffc234 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -3,6 +3,8 @@ from functools import wraps from flask import request from flask_restful import Resource, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized from constants.languages import supported_language @@ -54,7 +56,8 @@ class InsertExploreAppListApi(Resource): parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - app = App.query.filter(App.id == args["app_id"]).first() + with Session(db.engine) as session: + app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() if not app: raise NotFound(f'App \'{args["app_id"]}\' is not found') @@ -70,7 +73,10 @@ class InsertExploreAppListApi(Resource): privacy_policy = site.privacy_policy or args["privacy_policy"] or "" custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() + with Session(db.engine) as session: + recommended_app = session.execute( + select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) + ).scalar_one_or_none() if not recommended_app: recommended_app = RecommendedApp( @@ -110,17 +116,27 @@ class InsertExploreAppApi(Resource): @only_edition_cloud @admin_required def delete(self, app_id): - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() + with Session(db.engine) as session: + recommended_app = session.execute( + select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id)) + ).scalar_one_or_none() + if not recommended_app: return {"result": "success"}, 204 - app = App.query.filter(App.id == recommended_app.app_id).first() + with Session(db.engine) as session: + app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none() + if app: app.is_public = False - installed_apps = InstalledApp.query.filter( - InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id - ).all() + with Session(db.engine) as session: + installed_apps = session.execute( + select(InstalledApp).filter( + InstalledApp.app_id == recommended_app.app_id, + InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id, + ) + ).all() for installed_app in installed_apps: db.session.delete(installed_app) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 25930a140e..e014964bf9 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -33,7 +33,10 @@ def _get_resource(resource_id, tenant_id, resource_model): select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) ).scalar_one_or_none() else: - resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() + with Session(db.engine) as session: + resource = session.execute( + select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) + ).scalar_one_or_none() if resource is None: flask_restful.abort(404, message=f"{resource_model.__name__} not found.") diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 7fea610610..3c2de4612f 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -3,6 +3,8 @@ import secrets from flask import request from flask_restful import Resource, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session from constants.languages import languages from controllers.console import api @@ -41,7 +43,8 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - account = Account.query.filter_by(email=args["email"]).first() + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() token = None if account is None: if FeatureService.get_system_features().is_allow_register: @@ -108,7 +111,8 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = Account.query.filter_by(email=reset_data.get("email")).first() + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none() if account: account.password = base64_password_hashed account.password_salt = base64_salt diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 282e69448e..45ae77a002 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -5,6 +5,8 @@ from typing import Optional import requests from flask import current_app, redirect, request from flask_restful import Resource +from sqlalchemy import select +from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -135,7 +137,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account = Account.get_by_openid(provider, user_info.id) if not account: - account = Account.query.filter_by(email=user_info.email).first() + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none() return account diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index a2c9760782..f024e3799c 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -4,6 +4,8 @@ import json from flask import request from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.console import api @@ -77,7 +79,10 @@ class DataSourceApi(Resource): def patch(self, binding_id, action): binding_id = str(binding_id) action = str(action) - data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() + with Session(db.engine) as session: + data_source_binding = session.execute( + select(DataSourceOauthBinding).filter_by(id=binding_id) + ).scalar_one_or_none() if data_source_binding is None: raise NotFound("Data source binding not found.") # enable binding @@ -109,47 +114,53 @@ class DataSourceNotionListApi(Resource): def get(self): dataset_id = request.args.get("dataset_id", default=None, type=str) exist_page_ids = [] - # import notion in the exist dataset - if dataset_id: - dataset = DatasetService.get_dataset(dataset_id) - if not dataset: - raise NotFound("Dataset not found.") - if dataset.data_source_type != "notion_import": - raise ValueError("Dataset is not notion type.") - documents = Document.query.filter_by( - dataset_id=dataset_id, - tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", - enabled=True, + with Session(db.engine) as session: + # import notion in the exist dataset + if dataset_id: + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + if dataset.data_source_type != "notion_import": + raise ValueError("Dataset is not notion type.") + + documents = session.execute( + select(Document).filter_by( + dataset_id=dataset_id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ) + ).all() + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) + # get all authorized pages + data_source_bindings = session.execute( + select(DataSourceOauthBinding).filter_by( + tenant_id=current_user.current_tenant_id, provider="notion", disabled=False + ) ).all() - if documents: - for document in documents: - data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info["notion_page_id"]) - # get all authorized pages - data_source_bindings = DataSourceOauthBinding.query.filter_by( - tenant_id=current_user.current_tenant_id, provider="notion", disabled=False - ).all() - if not data_source_bindings: - return {"notion_info": []}, 200 - pre_import_info_list = [] - for data_source_binding in data_source_bindings: - source_info = data_source_binding.source_info - pages = source_info["pages"] - # Filter out already bound pages - for page in pages: - if page["page_id"] in exist_page_ids: - page["is_bound"] = True - else: - page["is_bound"] = False - pre_import_info = { - "workspace_name": source_info["workspace_name"], - "workspace_icon": source_info["workspace_icon"], - "workspace_id": source_info["workspace_id"], - "pages": pages, - } - pre_import_info_list.append(pre_import_info) - return {"notion_info": pre_import_info_list}, 200 + if not data_source_bindings: + return {"notion_info": []}, 200 + pre_import_info_list = [] + for data_source_binding in data_source_bindings: + source_info = data_source_binding.source_info + pages = source_info["pages"] + # Filter out already bound pages + for page in pages: + if page["page_id"] in exist_page_ids: + page["is_bound"] = True + else: + page["is_bound"] = False + pre_import_info = { + "workspace_name": source_info["workspace_name"], + "workspace_icon": source_info["workspace_icon"], + "workspace_id": source_info["workspace_id"], + "pages": pages, + } + pre_import_info_list.append(pre_import_info) + return {"notion_info": pre_import_info_list}, 200 class DataSourceNotionApi(Resource): @@ -159,14 +170,17 @@ class DataSourceNotionApi(Resource): def get(self, workspace_id, page_id, page_type): workspace_id = str(workspace_id) page_id = str(page_id) - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', - ) - ).first() + with Session(db.engine) as session: + data_source_binding = session.execute( + select(DataSourceOauthBinding).filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + ) + ) + ).scalar_one_or_none() if not data_source_binding: raise NotFound("Data source binding not found.") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 31b4f7b741..235d147559 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -5,7 +5,8 @@ from datetime import datetime, timezone from flask import request from flask_login import current_user from flask_restful import Resource, fields, marshal, marshal_with, reqparse -from sqlalchemy import asc, desc +from sqlalchemy import asc, desc, select +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound import services @@ -104,7 +105,8 @@ class GetProcessRuleApi(Resource): rules = DocumentService.DEFAULT_RULES["rules"] if document_id: # get the latest process rule - document = Document.query.get_or_404(document_id) + with Session(db.engine) as session: + document = session.execute(select(Document).get_or_404(document_id)).scalar_one_or_none() dataset = DatasetService.get_dataset(document.dataset_id) @@ -167,7 +169,10 @@ class DatasetDocumentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + with Session(db.engine) as session: + query = session.execute( + select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + ).all() if search: search = f"%{search}%" @@ -204,18 +209,25 @@ class DatasetDocumentListApi(Resource): paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items if fetch: - for document in documents: - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" - ).count() - document.completed_segments = completed_segments - document.total_segments = total_segments - data = marshal(documents, document_with_segments_fields) + with Session(db.engine) as session: + for document in documents: + completed_segments = ( + session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .count() + ) + document.completed_segments = completed_segments + document.total_segments = total_segments + data = marshal(documents, document_with_segments_fields) else: data = marshal(documents, document_fields) response = { diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index ae759bb752..b19e331d2e 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -2,8 +2,11 @@ import os from flask import session from flask_restful import Resource, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session from configs import dify_config +from extensions.ext_database import db from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService @@ -42,7 +45,11 @@ class InitValidateAPI(Resource): def get_init_validate_status(): if dify_config.EDITION == "SELF_HOSTED": if os.environ.get("INIT_PASSWORD"): - return session.get("is_init_validated") or DifySetup.query.first() + if session.get("is_init_validated"): + return True + + with Session(db.engine) as db_session: + return db_session.execute(select(DifySetup)).scalar_one_or_none() return True diff --git a/api/models/account.py b/api/models/account.py index 60b4f11aad..ae87e22649 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -4,6 +4,7 @@ import json from flask_login import UserMixin from extensions.ext_database import db +from models.base import Base from .types import StringUUID @@ -16,7 +17,7 @@ class AccountStatus(str, enum.Enum): CLOSED = "closed" -class Account(UserMixin, db.Model): +class Account(UserMixin, Base): __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) diff --git a/api/models/model.py b/api/models/model.py index 12c57ab372..0da55cb9de 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -38,7 +38,7 @@ class FileUploadConfig(BaseModel): number_limits: int = Field(default=0, gt=0, le=10) -class DifySetup(db.Model): +class DifySetup(BaseModel): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)