feat(api/auth): switch-to-stateful-authentication (#5438)

This commit is contained in:
-LAN- 2024-06-21 12:39:07 +08:00 committed by GitHub
parent 26b6fd2236
commit 1336b844fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 89 additions and 43 deletions

View File

@ -2,7 +2,7 @@ import os
from configs.app_configs import DifyConfigs
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
if not os.environ.get("DEBUG") or os.environ.get("DEBUG", "false").lower() != 'true':
from gevent import monkey
monkey.patch_all()
@ -152,7 +152,8 @@ def initialize_extensions(app):
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint in ['console', 'inner_api']:
if request.blueprint not in ['console', 'inner_api']:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '')
if not auth_header:
@ -170,9 +171,7 @@ def load_user_from_request(request_from_flask_login):
decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')
return AccountService.load_user(user_id)
else:
return None
return AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
@login_manager.unauthorized_handler

View File

@ -1,3 +1,5 @@
from typing import cast
import flask_login
from flask import current_app, request
from flask_restful import Resource, reqparse
@ -5,8 +7,9 @@ from flask_restful import Resource, reqparse
import services
from controllers.console import api
from controllers.console.setup import setup_required
from libs.helper import email
from libs.helper import email, get_remote_ip
from libs.password import valid_password
from models.account import Account
from services.account_service import AccountService, TenantService
@ -34,10 +37,7 @@ class LoginApi(Resource):
if len(tenants) == 0:
return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
AccountService.update_last_login(account, request)
# todo: return the user info
token = AccountService.get_account_jwt_token(account)
token = AccountService.login(account, ip_address=get_remote_ip(request))
return {'result': 'success', 'data': token}
@ -46,6 +46,9 @@ class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
token = request.headers.get('Authorization', '').split(' ')[1]
AccountService.logout(account=account, token=token)
flask_login.logout_user()
return {'result': 'success'}

View File

@ -8,6 +8,7 @@ from flask_restful import Resource
from constants.languages import languages
from extensions.ext_database import db
from libs.helper import get_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
@ -78,9 +79,7 @@ class OAuthCallback(Resource):
TenantService.create_owner_tenant_if_not_exist(account)
AccountService.update_last_login(account, request)
token = AccountService.get_account_jwt_token(account)
token = AccountService.login(account, ip_address=get_remote_ip(request))
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')

View File

@ -4,7 +4,7 @@ from flask import current_app, request
from flask_restful import Resource, reqparse
from extensions.ext_database import db
from libs.helper import email, str_len
from libs.helper import email, get_remote_ip, str_len
from libs.password import valid_password
from models.model import DifySetup
from services.account_service import AccountService, RegisterService, TenantService
@ -61,7 +61,7 @@ class SetupApi(Resource):
TenantService.create_owner_tenant_if_not_exist(account)
setup()
AccountService.update_last_login(account, request)
AccountService.update_last_login(account, ip_address=get_remote_ip(request))
return {'result': 'success'}, 201

View File

@ -140,7 +140,7 @@ def generate_string(n):
return result
def get_remote_ip(request):
def get_remote_ip(request) -> str:
if request.headers.get('CF-Connecting-IP'):
return request.headers.get('Cf-Connecting-Ip')
elif request.headers.getlist("X-Forwarded-For"):

View File

@ -1 +1,3 @@
import services.errors
from . import errors
__all__ = ['errors']

View File

@ -13,7 +13,6 @@ from werkzeug.exceptions import Unauthorized
from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created
from extensions.ext_redis import redis_client
from libs.helper import get_remote_ip
from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password
from libs.rsa import generate_key_pair
@ -67,10 +66,10 @@ class AccountService:
@staticmethod
def get_account_jwt_token(account):
def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
payload = {
"user_id": account.id,
"exp": datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=30),
"exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
"iss": current_app.config['EDITION'],
"sub": 'Console API Passport',
}
@ -195,14 +194,35 @@ class AccountService:
return account
@staticmethod
def update_last_login(account: Account, request) -> None:
def update_last_login(account: Account, *, ip_address: str) -> None:
"""Update last login time and ip"""
account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
account.last_login_ip = get_remote_ip(request)
account.last_login_ip = ip_address
db.session.add(account)
db.session.commit()
logging.info(f'Account {account.id} logged in successfully.')
@staticmethod
def login(account: Account, *, ip_address: Optional[str] = None):
if ip_address:
AccountService.update_last_login(account, ip_address=ip_address)
exp = timedelta(days=30)
token = AccountService.get_account_jwt_token(account, exp=exp)
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds()))
return token
@staticmethod
def logout(*, account: Account, token: str):
redis_client.delete(_get_login_cache_key(account_id=account.id, token=token))
@staticmethod
def load_logged_in_account(*, account_id: str, token: str):
if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)):
return None
return AccountService.load_user(account_id)
def _get_login_cache_key(*, account_id: str, token: str):
return f"account_login:{account_id}:{token}"
class TenantService:

View File

@ -1,6 +1,29 @@
__all__ = [
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
'app', 'completion', 'audio', 'file'
]
from . import (
account,
app,
app_model_config,
audio,
base,
completion,
conversation,
dataset,
document,
file,
index,
message,
)
from . import *
__all__ = [
"base",
"conversation",
"message",
"index",
"app_model_config",
"account",
"document",
"dataset",
"app",
"completion",
"audio",
"file",
]