Migrate to DeclarativeBaseModel
This commit is contained in:
parent
53e1b45d40
commit
11270a7ef2
@ -3,6 +3,8 @@ from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from constants.languages import supported_language
|
||||
@ -54,7 +56,8 @@ class InsertExploreAppListApi(Resource):
|
||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app = App.query.filter(App.id == args["app_id"]).first()
|
||||
with Session(db.engine) as session:
|
||||
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
|
||||
if not app:
|
||||
raise NotFound(f'App \'{args["app_id"]}\' is not found')
|
||||
|
||||
@ -70,7 +73,10 @@ class InsertExploreAppListApi(Resource):
|
||||
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
||||
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
||||
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
||||
with Session(db.engine) as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
recommended_app = RecommendedApp(
|
||||
@ -110,16 +116,26 @@ class InsertExploreAppApi(Resource):
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
|
||||
with Session(db.engine) as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
return {"result": "success"}, 204
|
||||
|
||||
app = App.query.filter(App.id == recommended_app.app_id).first()
|
||||
with Session(db.engine) as session:
|
||||
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
|
||||
|
||||
if app:
|
||||
app.is_public = False
|
||||
|
||||
installed_apps = InstalledApp.query.filter(
|
||||
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
|
||||
with Session(db.engine) as session:
|
||||
installed_apps = session.execute(
|
||||
select(InstalledApp).filter(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
for installed_app in installed_apps:
|
||||
|
@ -33,7 +33,10 @@ def _get_resource(resource_id, tenant_id, resource_model):
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
else:
|
||||
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if resource is None:
|
||||
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
|
||||
|
@ -3,6 +3,8 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
@ -41,7 +43,8 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = Account.query.filter_by(email=args["email"]).first()
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
token = None
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
@ -108,7 +111,8 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
account = Account.query.filter_by(email=reset_data.get("email")).first()
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
|
||||
if account:
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
|
@ -5,6 +5,8 @@ from typing import Optional
|
||||
import requests
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restful import Resource
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -135,7 +137,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
account = Account.get_by_openid(provider, user_info.id)
|
||||
|
||||
if not account:
|
||||
account = Account.query.filter_by(email=user_info.email).first()
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
|
||||
|
||||
return account
|
||||
|
||||
|
@ -4,6 +4,8 @@ import json
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
@ -77,7 +79,10 @@ class DataSourceApi(Resource):
|
||||
def patch(self, binding_id, action):
|
||||
binding_id = str(binding_id)
|
||||
action = str(action)
|
||||
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
||||
).scalar_one_or_none()
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
# enable binding
|
||||
@ -109,6 +114,7 @@ class DataSourceNotionListApi(Resource):
|
||||
def get(self):
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
exist_page_ids = []
|
||||
with Session(db.engine) as session:
|
||||
# import notion in the exist dataset
|
||||
if dataset_id:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -116,19 +122,24 @@ class DataSourceNotionListApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
if dataset.data_source_type != "notion_import":
|
||||
raise ValueError("Dataset is not notion type.")
|
||||
documents = Document.query.filter_by(
|
||||
|
||||
documents = session.execute(
|
||||
select(Document).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
)
|
||||
).all()
|
||||
if documents:
|
||||
for document in documents:
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# get all authorized pages
|
||||
data_source_bindings = DataSourceOauthBinding.query.filter_by(
|
||||
data_source_bindings = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
)
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
@ -159,14 +170,17 @@ class DataSourceNotionApi(Resource):
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
).first()
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not data_source_binding:
|
||||
raise NotFound("Data source binding not found.")
|
||||
|
||||
|
@ -5,7 +5,8 @@ from datetime import datetime, timezone
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import asc, desc
|
||||
from sqlalchemy import asc, desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -104,7 +105,8 @@ class GetProcessRuleApi(Resource):
|
||||
rules = DocumentService.DEFAULT_RULES["rules"]
|
||||
if document_id:
|
||||
# get the latest process rule
|
||||
document = Document.query.get_or_404(document_id)
|
||||
with Session(db.engine) as session:
|
||||
document = session.execute(select(Document).get_or_404(document_id)).scalar_one_or_none()
|
||||
|
||||
dataset = DatasetService.get_dataset(document.dataset_id)
|
||||
|
||||
@ -167,7 +169,10 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
|
||||
with Session(db.engine) as session:
|
||||
query = session.execute(
|
||||
select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
|
||||
).all()
|
||||
|
||||
if search:
|
||||
search = f"%{search}%"
|
||||
@ -204,15 +209,22 @@ class DatasetDocumentListApi(Resource):
|
||||
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
documents = paginated_documents.items
|
||||
if fetch:
|
||||
with Session(db.engine) as session:
|
||||
for document in documents:
|
||||
completed_segments = DocumentSegment.query.filter(
|
||||
completed_segments = (
|
||||
session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
).count()
|
||||
total_segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
|
||||
).count()
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.count()
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
data = marshal(documents, document_with_segments_fields)
|
||||
|
@ -2,8 +2,11 @@ import os
|
||||
|
||||
from flask import session
|
||||
from flask_restful import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import StrLen
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
@ -42,7 +45,11 @@ class InitValidateAPI(Resource):
|
||||
def get_init_validate_status():
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
if os.environ.get("INIT_PASSWORD"):
|
||||
return session.get("is_init_validated") or DifySetup.query.first()
|
||||
if session.get("is_init_validated"):
|
||||
return True
|
||||
|
||||
with Session(db.engine) as db_session:
|
||||
return db_session.execute(select(DifySetup)).scalar_one_or_none()
|
||||
|
||||
return True
|
||||
|
||||
|
@ -4,6 +4,7 @@ import json
|
||||
from flask_login import UserMixin
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.base import Base
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
@ -16,7 +17,7 @@ class AccountStatus(str, enum.Enum):
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class Account(UserMixin, db.Model):
|
||||
class Account(UserMixin, Base):
|
||||
__tablename__ = "accounts"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
|
||||
|
||||
|
@ -38,7 +38,7 @@ class FileUploadConfig(BaseModel):
|
||||
number_limits: int = Field(default=0, gt=0, le=10)
|
||||
|
||||
|
||||
class DifySetup(db.Model):
|
||||
class DifySetup(BaseModel):
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user