diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 029a99bb3f..11c0d97410 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -21,7 +21,7 @@ from events.tenant_event import tenant_was_created from libs.helper import email, get_remote_ip from libs.password import valid_password from models.account import Account -from services.account_service import AccountService, TenantService +from services.account_service import AccountService, RegisterService, TenantService from services.errors.workspace import WorkSpaceNotAllowedCreateError @@ -35,14 +35,26 @@ class LoginApi(Resource): parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json") parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") + parser.add_argument("invite_token", type=str, required=False, default=None, location="json") args = parser.parse_args() is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"]) if is_login_error_rate_limit: raise EmailPasswordLoginLimitError() + invitation = args["invite_token"] + if invitation: + invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) + try: - account = AccountService.authenticate(args["email"], args["password"]) + if invitation: + data = invitation.get("data", {}) + invitee_email = data.get("email") if data else None + if invitee_email != args["email"]: + raise InvalidEmailError() + account = AccountService.authenticate(args["email"], args["password"], args["invite_token"]) + else: + account = AccountService.authenticate(args["email"], args["password"]) except services.errors.account.AccountLoginError: raise NotAllowedRegister() except services.errors.account.AccountPasswordError: diff --git a/api/services/account_service.py b/api/services/account_service.py index 8983e047ef..8be0589a9f 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -93,7 +93,7 @@ class AccountService: return token @staticmethod - def authenticate(email: str, password: str) -> Account: + def authenticate(email: str, password: str, invite_token: str = None) -> Account: """authenticate account with email and password""" account = Account.query.filter_by(email=email).first() @@ -102,6 +102,16 @@ class AccountService: if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise AccountLoginError("Account is banned or closed.") + + if password and invite_token: + # if invite_token is valid, set password and password_salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() + password_hashed = hash_password(password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account.password = base64_password_hashed + account.password_salt = base64_salt + if account.password is None or not compare_password(password, account.password, account.password_salt): raise AccountPasswordError("Invalid email or password.") @@ -109,7 +119,8 @@ class AccountService: if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) - db.session.commit() + + db.session.commit() return account