diff --git a/api/models/staging.py b/api/models/staging.py new file mode 100644 index 0000000000..3e345ac47f --- /dev/null +++ b/api/models/staging.py @@ -0,0 +1,27 @@ +from datetime import datetime + +from sqlalchemy.orm import Mapped + +from extensions.ext_database import db +from models.base import Base + +from .types import StringUUID + + +class StagingAccountWhitelist(Base): + __tablename__ = "staging_account_whitelists" + + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="staging_account_whitelist_pkey"), + db.Index("account_email_idx", "email"), + ) + + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + email: Mapped[str] = db.Column(db.String(255), nullable=False) + disabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) diff --git a/api/services/account_service.py b/api/services/account_service.py index 7613f48a3e..24f005f1e1 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -33,6 +33,7 @@ from models.account import ( TenantStatus, ) from models.model import DifySetup +from models.staging import StagingAccountWhitelist from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, @@ -296,6 +297,9 @@ class AccountService: @staticmethod def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: + if not AccountService.verify_account_whitelist(account.email): + raise ValueError("Account is not whitelisted") + if ip_address: AccountService.update_login_info(account=account, ip_address=ip_address) @@ -318,6 +322,9 @@ class AccountService: @staticmethod def refresh_token(refresh_token: str) -> TokenPair: + if not AccountService.verify_account_whitelist(refresh_token): + raise ValueError("Account is not whitelisted") + # Verify the refresh token account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) if not account_id: @@ -336,6 +343,11 @@ class AccountService: return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token) + @staticmethod + def verify_account_whitelist(email: str) -> bool: + with Session(db.engine) as session: + return session.query(StagingAccountWhitelist).filter_by(email=email, disabled=False).first() is not None + @staticmethod def load_logged_in_account(*, account_id: str): return AccountService.load_user(account_id)