diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index d3b843cd56..9b19231f60 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -36,11 +36,16 @@ class LoginApi(Resource): parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") args = parser.parse_args() + is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"]) + if is_login_error_rate_limit: + raise EmailOrPasswordMismatchError() + try: account = AccountService.authenticate(args["email"], args["password"]) except services.errors.account.AccountLoginError: raise NotAllowedRegister() except services.errors.account.AccountPasswordError: + AccountService.add_login_error_rate_limit(args["email"]) raise EmailOrPasswordMismatchError() except services.errors.account.AccountNotFoundError: if not dify_config.ALLOW_REGISTER: @@ -57,7 +62,7 @@ class LoginApi(Resource): } token = AccountService.login(account, ip_address=get_remote_ip(request)) - + AccountService.reset_login_error_rate_limit(args["email"]) return {"result": "success", "data": token} @@ -154,7 +159,7 @@ class EmailCodeLoginApi(Resource): "?message=Workspace not found, please contact system admin to invite you to join in a workspace." ) token = AccountService.login(account, ip_address=get_remote_ip(request)) - + AccountService.reset_login_error_rate_limit(args["email"]) return {"result": "success", "data": token} diff --git a/api/services/account_service.py b/api/services/account_service.py index d9f67764df..7bd51454cb 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -48,6 +48,7 @@ class AccountService: email_code_login_rate_limiter = RateLimiter( prefix="email_code_login_rate_limit", max_attempts=5, time_window=60 * 5 ) + LOGIN_MAX_ERROR_LIMITS = 5 @staticmethod def load_user(user_id: str) -> None | Account: @@ -317,6 +318,32 @@ class AccountService: return account + @staticmethod + def add_login_error_rate_limit(email: str) -> None: + key = f"login_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, 60 * 60 * 24, count) + + @staticmethod + def is_login_error_rate_limit(email: str) -> bool: + key = f"login_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + + count = int(count) + if count > AccountService.LOGIN_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + def reset_login_error_rate_limit(email: str): + key = f"login_error_rate_limit:{email}" + redis_client.delete(key) + def _get_login_cache_key(*, account_id: str, token: str): return f"account_login:{account_id}:{token}"