feat: support whitelist
This commit is contained in:
parent
3dd6d96b5a
commit
2e2d1659ca
27
api/models/staging.py
Normal file
27
api/models/staging.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Mapped
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.base import Base
|
||||||
|
|
||||||
|
from .types import StringUUID
|
||||||
|
|
||||||
|
|
||||||
|
class StagingAccountWhitelist(Base):
|
||||||
|
__tablename__ = "staging_account_whitelists"
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
db.PrimaryKeyConstraint("id", name="staging_account_whitelist_pkey"),
|
||||||
|
db.Index("account_email_idx", "email"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
|
email: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||||
|
disabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||||
|
created_at: Mapped[datetime] = db.Column(
|
||||||
|
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = db.Column(
|
||||||
|
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||||
|
)
|
@ -33,6 +33,7 @@ from models.account import (
|
|||||||
TenantStatus,
|
TenantStatus,
|
||||||
)
|
)
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
|
from models.staging import StagingAccountWhitelist
|
||||||
from services.errors.account import (
|
from services.errors.account import (
|
||||||
AccountAlreadyInTenantError,
|
AccountAlreadyInTenantError,
|
||||||
AccountLoginError,
|
AccountLoginError,
|
||||||
@ -296,6 +297,9 @@ class AccountService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
|
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
|
||||||
|
if not AccountService.verify_account_whitelist(account.email):
|
||||||
|
raise ValueError("Account is not whitelisted")
|
||||||
|
|
||||||
if ip_address:
|
if ip_address:
|
||||||
AccountService.update_login_info(account=account, ip_address=ip_address)
|
AccountService.update_login_info(account=account, ip_address=ip_address)
|
||||||
|
|
||||||
@ -318,6 +322,9 @@ class AccountService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def refresh_token(refresh_token: str) -> TokenPair:
|
def refresh_token(refresh_token: str) -> TokenPair:
|
||||||
|
if not AccountService.verify_account_whitelist(refresh_token):
|
||||||
|
raise ValueError("Account is not whitelisted")
|
||||||
|
|
||||||
# Verify the refresh token
|
# Verify the refresh token
|
||||||
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
|
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
|
||||||
if not account_id:
|
if not account_id:
|
||||||
@ -336,6 +343,11 @@ class AccountService:
|
|||||||
|
|
||||||
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
|
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def verify_account_whitelist(email: str) -> bool:
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
return session.query(StagingAccountWhitelist).filter_by(email=email, disabled=False).first() is not None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_logged_in_account(*, account_id: str):
|
def load_logged_in_account(*, account_id: str):
|
||||||
return AccountService.load_user(account_id)
|
return AccountService.load_user(account_id)
|
||||||
|
Loading…
Reference in New Issue
Block a user