From 4893631d656b6f55f8bee50d1bde4fd2268fb6e3 Mon Sep 17 00:00:00 2001 From: Joe <1264204425@qq.com> Date: Mon, 9 Sep 2024 18:19:55 +0800 Subject: [PATCH] fix: oauth error when not allowed create workspace fix: oauth error when not allowed create workspace --- api/controllers/console/auth/oauth.py | 5 +++++ api/services/account_service.py | 18 ++++++++---------- api/services/errors/workspace.py | 5 +++++ 3 files changed, 18 insertions(+), 10 deletions(-) create mode 100644 api/services/errors/workspace.py diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 2e0eeb2895..b653c91dfd 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -15,6 +15,7 @@ from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models.account import Account, AccountStatus from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountNotFound +from services.errors.workspace import WorkSpaceNotAllowedCreateError from .. import api @@ -90,6 +91,10 @@ class OAuthCallback(Resource): account = _generate_account(provider, user_info) except AccountNotFound: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=AccountNotFound") + except WorkSpaceNotAllowedCreateError: + return redirect( + f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found, please contact system admin to invite you to join in a workspace." + ) # Check account status if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: diff --git a/api/services/account_service.py b/api/services/account_service.py index 2f16e34097..7cbbc8c428 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -37,6 +37,7 @@ from services.errors.account import ( RoleAlreadyAssignedError, TenantNotFound, ) +from services.errors.workspace import WorkSpaceNotAllowedCreateError from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task @@ -604,7 +605,6 @@ class RegisterService: provider: Optional[str] = None, language: Optional[str] = None, status: Optional[AccountStatus] = None, - is_invite_member: Optional[bool] = False, ) -> Account: db.session.begin_nested() """Register account""" @@ -618,13 +618,13 @@ class RegisterService: if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) - should_create_workspace = not is_invite_member or (is_invite_member and dify_config.ALLOW_CREATE_WORKSPACE) + if not dify_config.ALLOW_CREATE_WORKSPACE: + raise WorkSpaceNotAllowedCreateError() - if should_create_workspace: - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role="owner") - account.current_tenant = tenant - tenant_was_created.send(tenant) + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + tenant_was_created.send(tenant) db.session.commit() except Exception as e: @@ -645,9 +645,7 @@ class RegisterService: TenantService.check_member_permission(tenant, inviter, None, "add") name = email.split("@")[0] - account = cls.register( - email=email, name=name, language=language, status=AccountStatus.PENDING, is_invite_member=True - ) + account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING) # Create new tenant member for invited tenant TenantService.create_tenant_member(tenant, account, role) TenantService.switch_tenant(account, tenant.id) diff --git a/api/services/errors/workspace.py b/api/services/errors/workspace.py new file mode 100644 index 0000000000..600ebad8c1 --- /dev/null +++ b/api/services/errors/workspace.py @@ -0,0 +1,5 @@ +from services.errors.base import BaseServiceError + + +class WorkSpaceNotAllowedCreateError(BaseServiceError): + pass