dify/api/controllers/console/auth/oauth.py

148 lines
5.6 KiB
Python
Raw Normal View History

2023-05-15 08:51:32 +08:00
import logging
from datetime import datetime, timezone
2023-05-15 08:51:32 +08:00
from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_restful import Resource
2024-09-02 16:03:44 +08:00
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from constants.languages import languages
from extensions.ext_database import db
from libs.helper import get_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
2023-05-15 08:51:32 +08:00
from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
2023-05-15 08:51:32 +08:00
from .. import api
def get_oauth_providers():
    """Build the OAuth provider clients that are enabled by configuration.

    A provider is only instantiated when both its client id and client secret
    are set in ``dify_config``; otherwise its entry is ``None``.

    Returns:
        A dict with the keys ``"github"`` and ``"google"`` mapping to the
        configured provider client or ``None``.
    """
    with current_app.app_context():
        # Start from "not configured" and only build a client when both
        # credentials are present.
        github_client = None
        if dify_config.GITHUB_CLIENT_ID and dify_config.GITHUB_CLIENT_SECRET:
            github_client = GitHubOAuth(
                client_id=dify_config.GITHUB_CLIENT_ID,
                client_secret=dify_config.GITHUB_CLIENT_SECRET,
                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
            )

        google_client = None
        if dify_config.GOOGLE_CLIENT_ID and dify_config.GOOGLE_CLIENT_SECRET:
            google_client = GoogleOAuth(
                client_id=dify_config.GOOGLE_CLIENT_ID,
                client_secret=dify_config.GOOGLE_CLIENT_SECRET,
                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
            )

        return {"github": github_client, "google": google_client}
class OAuthLogin(Resource):
    """Console endpoint that starts the OAuth login flow for a provider."""

    def get(self, provider: str):
        """Redirect the caller to the provider's authorization URL.

        Args:
            provider: Provider key taken from the URL; looked up in the dict
                returned by ``get_oauth_providers()``.

        Returns:
            A redirect to the provider's authorization URL, or
            ``({"error": "Invalid provider"}, 400)`` when the provider is
            unknown or not configured.
        """
        # A missing or empty ``invite_token`` query parameter becomes None.
        invite_token = request.args.get("invite_token") or None

        OAUTH_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = OAUTH_PROVIDERS.get(provider)

        # The provider must be validated before it is touched in any way.
        # A previous debug ``print(vars(oauth_provider))`` sat above this check:
        # it raised TypeError for a missing provider (``vars(None)``), masking
        # the 400 below, and wrote the provider object's attributes (it is
        # constructed with the client secret) to stdout.
        if not oauth_provider:
            return {"error": "Invalid provider"}, 400

        auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
        return redirect(auth_url)
class OAuthCallback(Resource):
    """Console endpoint the OAuth provider redirects back to after consent."""

    def get(self, provider: str):
        """Complete the OAuth flow and sign the user in.

        Exchanges the ``code`` query parameter for an access token, fetches
        the provider's user info, then either hands over to the invitation
        flow or resolves the account and logs it in.

        Args:
            provider: Provider key taken from the URL; looked up in the dict
                returned by ``get_oauth_providers()``.

        Returns:
            A redirect to the console web UI (with a ``console_token`` on
            success, or a ``message`` on failure), or a JSON error body with
            status 400 / 403.
        """
        providers = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = providers.get(provider)
        if not oauth_provider:
            return {"error": "Invalid provider"}, 400

        code = request.args.get("code")
        # The ``state`` query parameter is treated as the invite token;
        # a missing or empty value means "no invitation".
        invite_token = request.args.get("state") or None

        try:
            access_token = oauth_provider.get_access_token(code)
            user_info = oauth_provider.get_user_info(access_token)
        except requests.exceptions.HTTPError as e:
            logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
            return {"error": "OAuth process failed"}, 400

        # Invitation flow: the invited e-mail must match the OAuth identity.
        if invite_token and RegisterService.is_valid_invite_token(invite_token):
            invitation = RegisterService._get_invitation_by_token(token=invite_token)
            if invitation and invitation.get("email", None) != user_info.email:
                return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=InvalidToken")
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")

        try:
            account = _generate_account(provider, user_info)
        except Unauthorized:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=AccountNotFound")

        # Banned or closed accounts are refused outright.
        blocked_statuses = (AccountStatus.BANNED.value, AccountStatus.CLOSED.value)
        if account.status in blocked_statuses:
            return {"error": "Account is banned or closed."}, 403

        # A pending account is activated on its first successful login;
        # the timestamp is stored as naive UTC.
        if account.status == AccountStatus.PENDING.value:
            account.status = AccountStatus.ACTIVE.value
            account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
            db.session.commit()

        try:
            TenantService.create_owner_tenant_if_not_exist(account)
        except Unauthorized:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=WorkspaceNotFound")

        token = AccountService.login(account, ip_address=get_remote_ip(request))
        return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
2023-05-15 08:51:32 +08:00
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
    """Find the account for an OAuth identity.

    Looks the account up by the provider's open id first and, when that
    yields nothing, falls back to a lookup by e-mail address.

    Args:
        provider: OAuth provider key.
        user_info: User info returned by the provider (``id`` and ``email``
            are read).

    Returns:
        The matching account, or ``None`` when neither lookup finds one.
    """
    return (
        Account.get_by_openid(provider, user_info.id)
        or Account.query.filter_by(email=user_info.email).first()
    )
def _generate_account(provider: str, user_info: OAuthUserInfo):
    """Return the account for an OAuth identity, registering it if needed.

    An existing account (matched by open id or e-mail) is reused. Otherwise a
    new account is registered without a password and its interface language
    is chosen from the request's ``Accept-Language`` header. In both cases
    the account is then linked to the provider identity.

    Args:
        provider: OAuth provider key.
        user_info: User info returned by the provider (``id``, ``name`` and
            ``email`` are read).

    Returns:
        The existing or newly registered account.
    """
    account = _get_account_by_openid_or_email(provider, user_info)

    if not account:
        # No match: register a new account, defaulting the name to "Dify".
        account = RegisterService.register(
            email=user_info.email,
            name=user_info.name or "Dify",
            password=None,
            open_id=user_info.id,
            provider=provider,
        )

        # Use the browser's preferred language when it is supported,
        # otherwise the first configured language.
        best_lang = request.accept_languages.best_match(languages)
        lang_supported = best_lang and best_lang in languages
        account.interface_language = best_lang if lang_supported else languages[0]
        db.session.commit()

    # Bind the provider identity to the account.
    AccountService.link_account_integrate(provider, user_info.id, account)
    return account
# Route registration on the console API.
# OAuthLogin starts the flow; OAuthCallback is the provider-specific path used
# as the tail of the redirect_uri built in get_oauth_providers().
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")