From c84f00403528a5e49b8651ce653e4182801c01c7 Mon Sep 17 00:00:00 2001 From: Joe <1264204425@qq.com> Date: Mon, 2 Sep 2024 14:50:45 +0800 Subject: [PATCH] feat: add oauth invite redict --- api/controllers/console/auth/oauth.py | 14 ++++++++++++-- api/libs/oauth.py | 9 +++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index cdb454dd31..fe7d9757f1 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -43,7 +43,7 @@ def get_oauth_providers(): class OAuthLogin(Resource): - def get(self, provider: str): + def get(self, provider: str, invite_toke: Optional[str] = None): OAUTH_PROVIDERS = get_oauth_providers() with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) @@ -51,7 +51,7 @@ class OAuthLogin(Resource): if not oauth_provider: return {"error": "Invalid provider"}, 400 - auth_url = oauth_provider.get_authorization_url() + auth_url = oauth_provider.get_authorization_url(invite_toke) return redirect(auth_url) @@ -64,6 +64,11 @@ class OAuthCallback(Resource): return {"error": "Invalid provider"}, 400 code = request.args.get("code") + state = request.args.get("state") + invite_token = None + if state: + invite_token = state + try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) @@ -71,6 +76,11 @@ class OAuthCallback(Resource): logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") return {"error": "OAuth process failed"}, 400 + if invite_token: + return redirect( + f"{dify_config.CONSOLE_WEB_URL}/invite-settings?invite_token={invite_token}" + ) + try: account = _generate_account(provider, user_info) except services.errors.account.AccountNotFound as e: diff --git a/api/libs/oauth.py b/api/libs/oauth.py index d8ce1a1e66..6b6919de24 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,5 +1,6 @@ import urllib.parse from dataclasses import dataclass +from typing import Optional import requests @@ -40,12 +41,14 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" - def get_authorization_url(self): + def get_authorization_url(self, invite_token: Optional[str] = None): params = { "client_id": self.client_id, "redirect_uri": self.redirect_uri, "scope": "user:email", # Request only basic user information } + if invite_token: + params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): @@ -90,13 +93,15 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" - def get_authorization_url(self): + def get_authorization_url(self, invite_token: Optional[str] = None): params = { "client_id": self.client_id, "response_type": "code", "redirect_uri": self.redirect_uri, "scope": "openid email", } + if invite_token: + params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str):