feat: support whitelist

This commit is contained in:
Yeuoly 2024-12-13 20:09:05 +08:00
parent 3dd6d96b5a
commit 2e2d1659ca
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
2 changed files with 39 additions and 0 deletions

27
api/models/staging.py Normal file
View File

@ -0,0 +1,27 @@
from datetime import datetime
from sqlalchemy.orm import Mapped
from extensions.ext_database import db
from models.base import Base
from .types import StringUUID
class StagingAccountWhitelist(Base):
__tablename__ = "staging_account_whitelists"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="staging_account_whitelist_pkey"),
db.Index("account_email_idx", "email"),
)
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
email: Mapped[str] = db.Column(db.String(255), nullable=False)
disabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at: Mapped[datetime] = db.Column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = db.Column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)

View File

@ -33,6 +33,7 @@ from models.account import (
TenantStatus, TenantStatus,
) )
from models.model import DifySetup from models.model import DifySetup
from models.staging import StagingAccountWhitelist
from services.errors.account import ( from services.errors.account import (
AccountAlreadyInTenantError, AccountAlreadyInTenantError,
AccountLoginError, AccountLoginError,
@ -296,6 +297,9 @@ class AccountService:
@staticmethod @staticmethod
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
if not AccountService.verify_account_whitelist(account.email):
raise ValueError("Account is not whitelisted")
if ip_address: if ip_address:
AccountService.update_login_info(account=account, ip_address=ip_address) AccountService.update_login_info(account=account, ip_address=ip_address)
@ -318,6 +322,9 @@ class AccountService:
@staticmethod @staticmethod
def refresh_token(refresh_token: str) -> TokenPair: def refresh_token(refresh_token: str) -> TokenPair:
if not AccountService.verify_account_whitelist(refresh_token):
raise ValueError("Account is not whitelisted")
# Verify the refresh token # Verify the refresh token
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
if not account_id: if not account_id:
@ -336,6 +343,11 @@ class AccountService:
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token) return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
@staticmethod
def verify_account_whitelist(email: str) -> bool:
with Session(db.engine) as session:
return session.query(StagingAccountWhitelist).filter_by(email=email, disabled=False).first() is not None
@staticmethod @staticmethod
def load_logged_in_account(*, account_id: str): def load_logged_in_account(*, account_id: str):
return AccountService.load_user(account_id) return AccountService.load_user(account_id)