diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index e934903910..a23ad5ef47 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -10,7 +10,12 @@ from controllers.console import api from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError -from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_rate_limit_check, + enterprise_license_required, + setup_required, +) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner from core.model_runtime.entities.model_entities import ModelType @@ -96,6 +101,7 @@ class DatasetListApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def post(self): parser = reqparse.RequestParser() parser.add_argument( @@ -210,6 +216,7 @@ class DatasetApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -313,6 +320,7 @@ class DatasetApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id): dataset_id_str = str(dataset_id) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 7ba9f5e121..6e0e8f1903 100644 --- a/api/controllers/console/datasets/datasets_document.py 
+++ b/api/controllers/console/datasets/datasets_document.py @@ -26,6 +26,7 @@ from controllers.console.datasets.error import ( ) from controllers.console.wraps import ( account_initialization_required, + cloud_edition_billing_rate_limit_check, cloud_edition_billing_resource_check, setup_required, ) @@ -242,6 +243,7 @@ class DatasetDocumentListApi(Resource): @account_initialization_required @marshal_with(documents_and_batch_fields) @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): dataset_id = str(dataset_id) @@ -297,6 +299,7 @@ class DatasetDocumentListApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -320,6 +323,7 @@ class DatasetInitApi(Resource): @account_initialization_required @marshal_with(dataset_and_document_fields) @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def post(self): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: @@ -694,6 +698,7 @@ class DocumentProcessingApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) document_id = str(document_id) @@ -730,6 +735,7 @@ class DocumentDeleteApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id): dataset_id = str(dataset_id) document_id = str(document_id) @@ -798,6 +804,7 @@ class DocumentStatusApi(DocumentResource): @login_required @account_initialization_required 
@cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -893,6 +900,7 @@ class DocumentPauseApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id): """pause document.""" dataset_id = str(dataset_id) @@ -925,6 +933,7 @@ class DocumentRecoverApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id): """recover document.""" dataset_id = str(dataset_id) @@ -954,6 +963,7 @@ class DocumentRetryApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): """retry document.""" diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index d2c94045ad..4642ed3573 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -19,6 +19,7 @@ from controllers.console.datasets.error import ( from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_knowledge_limit_check, + cloud_edition_billing_rate_limit_check, cloud_edition_billing_resource_check, setup_required, ) @@ -106,6 +107,7 @@ class DatasetDocumentSegmentListApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) @@ -137,6 +139,7 @@ class DatasetDocumentSegmentApi(Resource): @login_required @account_initialization_required 
@cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -191,6 +194,7 @@ class DatasetDocumentSegmentAddApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_knowledge_limit_check("add_segment") + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) @@ -240,6 +244,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -299,6 +304,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -336,6 +342,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_knowledge_limit_check("add_segment") + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) @@ -402,6 +409,7 @@ class ChildChunkAddApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_knowledge_limit_check("add_segment") + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -499,6 +507,7 @@ class ChildChunkAddApi(Resource): @login_required 
@account_initialization_required @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -542,6 +551,7 @@ class ChildChunkUpdateApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id, segment_id, child_chunk_id): # check dataset dataset_id = str(dataset_id) @@ -586,6 +596,7 @@ class ChildChunkUpdateApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id, child_chunk_id): # check dataset dataset_id = str(dataset_id) diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 18b746f547..d344e9d126 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -2,7 +2,11 @@ from flask_restful import Resource # type: ignore from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_rate_limit_check, + setup_required, +) from libs.login import login_required @@ -10,6 +14,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): dataset_id_str = str(dataset_id) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index a6f64700f2..ed6e16b035 100644 --- a/api/controllers/console/wraps.py +++ 
b/api/controllers/console/wraps.py @@ -1,5 +1,6 @@ import json import os +import time from functools import wraps from flask import abort, request @@ -8,6 +9,8 @@ from flask_login import current_user # type: ignore from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import RateLimitLog from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus from services.operation_service import OperationService @@ -67,7 +70,9 @@ def cloud_edition_billing_resource_check(resource: str): elif resource == "apps" and 0 < apps.limit <= apps.size: abort(403, "The number of apps has reached the limit of your subscription.") elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: - abort(403, "The capacity of the vector space has reached the limit of your subscription.") + abort( + 403, "The capacity of the knowledge storage space has reached the limit of your subscription." 
def cloud_edition_billing_rate_limit_check(resource: str):
    """Build a decorator that enforces the tenant's knowledge request rate limit.

    Only ``resource == "knowledge"`` is rate limited; for any other value the
    wrapped view is called unchanged. When the limit returned by
    ``FeatureService.get_knowledge_rate_limit`` is enabled, every request is
    recorded in a Redis sorted set keyed by tenant (score = request time in
    milliseconds) and entries older than 60 seconds are pruned, which yields a
    sliding one-minute window. Rejected requests are recorded as well, so they
    keep counting against the window.

    Args:
        resource: Name of the rate-limited resource.

    Returns:
        A decorator for view functions. Once the number of requests in the
        window exceeds the tenant's limit, the decorated view persists a
        ``RateLimitLog`` row and aborts with HTTP 403 instead of running.
    """
    # Function-scope import so the block does not rely on the module's import list.
    import uuid

    window_ms = 60000

    def interceptor(view):
        @wraps(view)
        def decorated(*args, **kwargs):
            if resource == "knowledge":
                tenant_id = current_user.current_tenant_id
                knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
                if knowledge_rate_limit.enabled:
                    current_time = int(time.time() * 1000)
                    key = f"rate_limit_{tenant_id}"

                    # The member must be unique per request: using the bare
                    # timestamp made two requests in the same millisecond
                    # overwrite each other and be counted once.
                    member = f"{current_time}:{uuid.uuid4().hex}"
                    redis_client.zadd(key, {member: current_time})

                    # Drop everything that has fallen out of the window.
                    redis_client.zremrangebyscore(key, 0, current_time - window_ms)

                    request_count = redis_client.zcard(key)

                    # Without a TTL the key of an idle tenant would live forever;
                    # after one full window every entry is stale anyway.
                    redis_client.expire(key, window_ms // 1000)

                    if request_count > knowledge_rate_limit.limit:
                        # add ratelimit record
                        rate_limit_log = RateLimitLog(
                            tenant_id=tenant_id,
                            subscription_plan=knowledge_rate_limit.subscription_plan,
                            operation="knowledge",
                        )
                        db.session.add(rate_limit_log)
                        db.session.commit()
                        abort(
                            403, "Sorry, you have reached the knowledge base request rate limit of your subscription."
                        )
            return view(*args, **kwargs)

        return decorated

    return interceptor
you have reached the knowledge base request rate limit of your subscription." + ) + return view(*args, **kwargs) + + return decorated + + return interceptor + + def validate_dataset_token(view=None): def decorator(view): @wraps(view) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index d77def2c40..11c56e3172 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,4 +1,5 @@ import logging +import time from collections.abc import Mapping, Sequence from typing import Any, cast @@ -19,8 +20,10 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.dataset import Dataset, Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset, Document, RateLimitLog from models.workflow import WorkflowNodeExecutionStatus +from services.feature_service import FeatureService from .entities import KnowledgeRetrievalNodeData from .exc import ( @@ -61,6 +64,31 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." 
"""add_rate_limit_logs

Revision ID: f051706725cc
Revises: d20049ed0af6
Create Date: 2025-01-14 06:17:35.536388

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
# NOTE: `down_revision` is the authoritative parent; the "Revises" line in the
# module docstring above is informational and must be kept in sync with it.
revision = 'f051706725cc'
down_revision = 'd20049ed0af6'
branch_labels = None
depends_on = None


def upgrade():
    """Create the ``rate_limit_logs`` table and its two lookup indexes."""
    # ### commands auto generated by Alembic - please adjust! ###
    # One row per logged rate-limit event: tenant, subscription plan, operation
    # name and a server-generated creation timestamp. The primary key defaults
    # to uuid_generate_v4() on the database side.
    op.create_table('rate_limit_logs',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('subscription_plan', sa.String(length=255), nullable=False),
    sa.Column('operation', sa.String(length=255), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='rate_limit_log_pkey')
    )
    # Non-unique indexes on the `operation` and `tenant_id` columns.
    with op.batch_alter_table('rate_limit_logs', schema=None) as batch_op:
        batch_op.create_index('rate_limit_log_operation_idx', ['operation'], unique=False)
        batch_op.create_index('rate_limit_log_tenant_idx', ['tenant_id'], unique=False)
    # ### end Alembic commands ###


def downgrade():
    """Drop the indexes and then the ``rate_limit_logs`` table (reverse of ``upgrade``)."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('rate_limit_logs', schema=None) as batch_op:
        batch_op.drop_index('rate_limit_log_tenant_idx')
        batch_op.drop_index('rate_limit_log_operation_idx')

    op.drop_table('rate_limit_logs')
    # ### end Alembic commands ###
db.Column(StringUUID, nullable=False) + subscription_plan = db.Column(db.String(255), nullable=False) + operation = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/services/billing_service.py b/api/services/billing_service.py index ad141035cc..ab68aad45a 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -22,6 +22,17 @@ class BillingService: billing_info = cls._send_request("GET", "/subscription/info", params=params) return billing_info + @classmethod + def get_knowledge_rate_limit(cls, tenant_id: str): + params = {"tenant_id": tenant_id} + + knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params) + + return { + "limit": knowledge_rate_limit.get("limit", 10), + "subscription_plan": knowledge_rate_limit.get("subscription_plan", "sandbox"), + } + @classmethod def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id} diff --git a/api/services/feature_service.py b/api/services/feature_service.py index a42b3020cd..113cd52552 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -41,6 +41,7 @@ class FeatureModel(BaseModel): members: LimitationModel = LimitationModel(size=0, limit=1) apps: LimitationModel = LimitationModel(size=0, limit=10) vector_space: LimitationModel = LimitationModel(size=0, limit=5) + knowledge_rate_limit: int = 10 annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) docs_processing: str = "standard" @@ -52,6 +53,12 @@ class FeatureModel(BaseModel): model_config = ConfigDict(protected_namespaces=()) +class KnowledgeRateLimitModel(BaseModel): + enabled: bool = False + limit: int = 10 + 
subscription_plan: str = "" + + class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" @@ -81,6 +88,16 @@ class FeatureService: return features + @classmethod + def get_knowledge_rate_limit(cls, tenant_id: str): + knowledge_rate_limit = KnowledgeRateLimitModel() + if dify_config.BILLING_ENABLED and tenant_id: + knowledge_rate_limit.enabled = True + limit_info = BillingService.get_knowledge_rate_limit(tenant_id) + knowledge_rate_limit.limit = limit_info.get("limit", 10) + knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", "sandbox") + return knowledge_rate_limit + @classmethod def get_system_features(cls) -> SystemFeatureModel: system_features = SystemFeatureModel() @@ -149,6 +166,9 @@ class FeatureService: if "model_load_balancing_enabled" in billing_info: features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] + if "knowledge_rate_limit" in billing_info: + features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"] + @classmethod def _fulfill_params_from_enterprise(cls, features): enterprise_info = EnterpriseService.get_info()