diff --git a/api/app.py b/api/app.py index 52461aac93..82ec64e6b7 100644 --- a/api/app.py +++ b/api/app.py @@ -2,7 +2,7 @@ import os from configs.app_configs import DifyConfigs -if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': +if not os.environ.get("DEBUG") or os.environ.get("DEBUG", "false").lower() != 'true': from gevent import monkey monkey.patch_all() @@ -152,27 +152,26 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint in ['console', 'inner_api']: - # Check if the user_id contains a dot, indicating the old format - auth_header = request.headers.get('Authorization', '') - if not auth_header: - auth_token = request.args.get('_token') - if not auth_token: - raise Unauthorized('Invalid Authorization token.') - else: - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') - - decoded = PassportService().verify(auth_token) - user_id = decoded.get('user_id') - - return AccountService.load_user(user_id) - else: + if request.blueprint not in ['console', 'inner_api']: return None + # Check if the user_id contains a dot, indicating the old format + auth_header = request.headers.get('Authorization', '') + if not auth_header: + auth_token = request.args.get('_token') + if not auth_token: + raise Unauthorized('Invalid Authorization token.') + else: + if ' ' not in auth_header: + raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != 'bearer': + raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + + decoded = PassportService().verify(auth_token) + user_id = decoded.get('user_id') + + return AccountService.load_logged_in_account(account_id=user_id, token=auth_token) @login_manager.unauthorized_handler diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 8a24e58413..67d6dc8e95 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,3 +1,5 @@ +from typing import cast + import flask_login from flask import current_app, request from flask_restful import Resource, reqparse @@ -5,8 +7,9 @@ from flask_restful import Resource, reqparse import services from controllers.console import api from controllers.console.setup import setup_required -from libs.helper import email +from libs.helper import email, get_remote_ip from libs.password import valid_password +from models.account import Account from services.account_service import AccountService, TenantService @@ -34,10 +37,7 @@ class LoginApi(Resource): if len(tenants) == 0: return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} - AccountService.update_last_login(account, request) - - # todo: return the user info - token = AccountService.get_account_jwt_token(account) + token = AccountService.login(account, ip_address=get_remote_ip(request)) return {'result': 'success', 'data': token} @@ -46,6 +46,9 @@ class LogoutApi(Resource): @setup_required def get(self): + account = cast(Account, flask_login.current_user) + token = request.headers.get('Authorization', '').split(' ')[1] + AccountService.logout(account=account, token=token) flask_login.logout_user() return {'result': 'success'} diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index e5b80e9a57..2e4a627e06 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -8,6 +8,7 @@ from flask_restful import Resource from constants.languages import languages from extensions.ext_database import db +from libs.helper import get_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models.account import Account, AccountStatus from services.account_service import AccountService, RegisterService, TenantService @@ -78,9 +79,7 @@ class OAuthCallback(Resource): TenantService.create_owner_tenant_if_not_exist(account) - AccountService.update_last_login(account, request) - - token = AccountService.get_account_jwt_token(account) + token = AccountService.login(account, ip_address=get_remote_ip(request)) return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}') diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 1911559cff..a8fdde2791 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -4,7 +4,7 @@ from flask import current_app, request from flask_restful import Resource, reqparse from extensions.ext_database import db -from libs.helper import email, str_len +from libs.helper import email, get_remote_ip, str_len from libs.password import valid_password from models.model import DifySetup from services.account_service import AccountService, RegisterService, TenantService @@ -61,7 +61,7 @@ class SetupApi(Resource): TenantService.create_owner_tenant_if_not_exist(account) setup() - AccountService.update_last_login(account, request) + AccountService.update_last_login(account, ip_address=get_remote_ip(request)) return {'result': 'success'}, 201 diff --git a/api/libs/helper.py b/api/libs/helper.py index fa326c5a53..ebabb2ea47 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -140,7 +140,7 @@ def generate_string(n): return result -def get_remote_ip(request): +def get_remote_ip(request) -> str: if request.headers.get('CF-Connecting-IP'): return request.headers.get('Cf-Connecting-Ip') elif request.headers.getlist("X-Forwarded-For"): diff --git a/api/services/__init__.py b/api/services/__init__.py index 20e68ab6d9..6891436314 100644 --- a/api/services/__init__.py +++ b/api/services/__init__.py @@ -1 +1,3 @@ -import services.errors +from . import errors + +__all__ = ['errors'] diff --git a/api/services/account_service.py b/api/services/account_service.py index 7551c9cb4b..2c401aad91 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -13,7 +13,6 @@ from werkzeug.exceptions import Unauthorized from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created from extensions.ext_redis import redis_client -from libs.helper import get_remote_ip from libs.passport import PassportService from libs.password import compare_password, hash_password, valid_password from libs.rsa import generate_key_pair @@ -67,10 +66,10 @@ class AccountService: @staticmethod - def get_account_jwt_token(account): + def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): payload = { "user_id": account.id, - "exp": datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=30), + "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, "iss": current_app.config['EDITION'], "sub": 'Console API Passport', } @@ -195,14 +194,35 @@ class AccountService: return account @staticmethod - def update_last_login(account: Account, request) -> None: + def update_last_login(account: Account, *, ip_address: str) -> None: """Update last login time and ip""" account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) - account.last_login_ip = get_remote_ip(request) + account.last_login_ip = ip_address db.session.add(account) db.session.commit() logging.info(f'Account {account.id} logged in successfully.') + @staticmethod + def login(account: Account, *, ip_address: Optional[str] = None): + if ip_address: + AccountService.update_last_login(account, ip_address=ip_address) + exp = timedelta(days=30) + token = AccountService.get_account_jwt_token(account, exp=exp) + redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds())) + return token + + @staticmethod + def logout(*, account: Account, token: str): + redis_client.delete(_get_login_cache_key(account_id=account.id, token=token)) + + @staticmethod + def load_logged_in_account(*, account_id: str, token: str): + if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)): + return None + return AccountService.load_user(account_id) + +def _get_login_cache_key(*, account_id: str, token: str): + return f"account_login:{account_id}:{token}" class TenantService: diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index 493919d373..bb5711145c 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -1,6 +1,29 @@ -__all__ = [ - 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', - 'app', 'completion', 'audio', 'file' -] +from . import ( + account, + app, + app_model_config, + audio, + base, + completion, + conversation, + dataset, + document, + file, + index, + message, +) -from . import * +__all__ = [ + "base", + "conversation", + "message", + "index", + "app_model_config", + "account", + "document", + "dataset", + "app", + "completion", + "audio", + "file", +]