From 955e2871f40f3c4b37bef36b13ceec53f793f561 Mon Sep 17 00:00:00 2001 From: Joe <1264204425@qq.com> Date: Mon, 2 Sep 2024 11:09:40 +0800 Subject: [PATCH] feat: add oauth account not found --- api/controllers/console/auth/oauth.py | 12 +++++++++--- api/services/account_service.py | 6 +++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index ae1b49f3ec..cdb454dd31 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -6,6 +6,7 @@ import requests from flask import current_app, redirect, request from flask_restful import Resource +import services from configs import dify_config from constants.languages import languages from extensions.ext_database import db @@ -13,6 +14,7 @@ from libs.helper import get_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models.account import Account, AccountStatus from services.account_service import AccountService, RegisterService, TenantService +from services.errors.account import AccountNotFound from .. import api @@ -69,7 +71,10 @@ class OAuthCallback(Resource): logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") return {"error": "OAuth process failed"}, 400 - account = _generate_account(provider, user_info) + try: + account = _generate_account(provider, user_info) + except services.errors.account.AccountNotFound as e: + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=AccountNotFound") # Check account status if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: return {"error": "Account is banned or closed."}, 403 @@ -99,8 +104,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): # Get account by openid or email. account = _get_account_by_openid_or_email(provider, user_info) - if not account: - # Create account + if not account and dify_config.ALLOW_REGISTER: account_name = user_info.name if user_info.name else "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider @@ -114,6 +118,8 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): interface_language = languages[0] account.interface_language = interface_language db.session.commit() + else: + raise AccountNotFound() # Link account AccountService.link_account_integrate(provider, user_info.id, account) diff --git a/api/services/account_service.py b/api/services/account_service.py index ecf959eb4b..019460261c 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -23,6 +23,7 @@ from models.model import DifySetup from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, + AccountNotFound, AccountNotLinkTenantError, AccountRegisterError, CannotOperateSelfError, @@ -92,7 +93,7 @@ class AccountService: account = Account.query.filter_by(email=email).first() if not account: - raise AccountLoginError("Invalid email or password.") + raise AccountNotFound() if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: raise AccountLoginError("Account is banned or closed.") @@ -330,6 +331,9 @@ class TenantService: @staticmethod def create_owner_tenant_if_not_exist(account: Account): """Create owner tenant if not exist""" + if not dify_config.ALLOW_CREATE_WORKSPACE: + raise Unauthorized("Create workspace is not allowed.") + available_ta = ( TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() )