Feat/new saas billing (#14996)

This commit is contained in:
Jyong 2025-03-10 19:50:11 +08:00 committed by GitHub
parent c8cc31af88
commit 9b2a9260ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 235 additions and 4 deletions

View File

@ -10,7 +10,12 @@ from controllers.console import api
from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
enterprise_license_required,
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
@ -96,6 +101,7 @@ class DatasetListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
@ -210,6 +216,7 @@ class DatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -313,6 +320,7 @@ class DatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)

View File

@ -26,6 +26,7 @@ from controllers.console.datasets.error import (
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
@ -242,6 +243,7 @@ class DatasetDocumentListApi(Resource):
@account_initialization_required
@marshal_with(documents_and_batch_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
dataset_id = str(dataset_id)
@ -297,6 +299,7 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -320,6 +323,7 @@ class DatasetInitApi(Resource):
@account_initialization_required
@marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
@ -694,6 +698,7 @@ class DocumentProcessingApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
@ -730,6 +735,7 @@ class DocumentDeleteApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
@ -798,6 +804,7 @@ class DocumentStatusApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -893,6 +900,7 @@ class DocumentPauseApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id = str(dataset_id)
@ -925,6 +933,7 @@ class DocumentRecoverApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id = str(dataset_id)
@ -954,6 +963,7 @@ class DocumentRetryApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""retry document."""

View File

@ -19,6 +19,7 @@ from controllers.console.datasets.error import (
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
@ -106,6 +107,7 @@ class DatasetDocumentSegmentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
@ -137,6 +139,7 @@ class DatasetDocumentSegmentApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -191,6 +194,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
@ -240,6 +244,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@ -299,6 +304,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@ -336,6 +342,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
@ -402,6 +409,7 @@ class ChildChunkAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@ -499,6 +507,7 @@ class ChildChunkAddApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@ -542,6 +551,7 @@ class ChildChunkUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
@ -586,6 +596,7 @@ class ChildChunkUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)

View File

@ -2,7 +2,11 @@ from flask_restful import Resource # type: ignore
from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from libs.login import login_required
@ -10,6 +14,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
dataset_id_str = str(dataset_id)

View File

@ -1,5 +1,6 @@
import json
import os
import time
from functools import wraps
from flask import abort, request
@ -8,6 +9,8 @@ from flask_login import current_user # type: ignore
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
@ -67,7 +70,9 @@ def cloud_edition_billing_resource_check(resource: str):
elif resource == "apps" and 0 < apps.limit <= apps.size:
abort(403, "The number of apps has reached the limit of your subscription.")
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
abort(
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
)
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places,
# so we need to check the source of the request from datasets
@ -112,6 +117,41 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
return interceptor
def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{current_user.current_tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=current_user.current_tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
abort(
403, "Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)
return decorated
return interceptor
def cloud_utm_record(view):
@wraps(view)
def decorated(*args, **kwargs):

View File

@ -1,3 +1,4 @@
import time
from collections.abc import Callable
from datetime import UTC, datetime, timedelta
from enum import Enum
@ -13,8 +14,10 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import RateLimitLog
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
@ -139,6 +142,43 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
return interceptor
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token(api_token_type)
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{api_token.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=api_token.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
raise Forbidden(
"Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)
return decorated
return interceptor
def validate_dataset_token(view=None):
def decorator(view):
@wraps(view)

View File

@ -1,4 +1,5 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
@ -19,8 +20,10 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models.dataset import Dataset, Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, RateLimitLog
from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData
from .exc import (
@ -61,6 +64,31 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
)
# check rate limit
if self.tenant_id:
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)
# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)

View File

@ -0,0 +1,43 @@
"""add_rate_limit_logs
Revision ID: f051706725cc
Revises: 923752d42eb6
Create Date: 2025-01-14 06:17:35.536388
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'f051706725cc'
down_revision = 'd20049ed0af6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('rate_limit_logs',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('subscription_plan', sa.String(length=255), nullable=False),
sa.Column('operation', sa.String(length=255), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='rate_limit_log_pkey')
)
with op.batch_alter_table('rate_limit_logs', schema=None) as batch_op:
batch_op.create_index('rate_limit_log_operation_idx', ['operation'], unique=False)
batch_op.create_index('rate_limit_log_tenant_idx', ['tenant_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('rate_limit_logs', schema=None) as batch_op:
batch_op.drop_index('rate_limit_log_tenant_idx')
batch_op.drop_index('rate_limit_log_operation_idx')
op.drop_table('rate_limit_logs')
# ### end Alembic commands ###

View File

@ -930,3 +930,18 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
document_id = db.Column(StringUUID, nullable=False)
notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class RateLimitLog(db.Model): # type: ignore[name-defined]
__tablename__ = "rate_limit_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
db.Index("rate_limit_log_tenant_idx", "tenant_id"),
db.Index("rate_limit_log_operation_idx", "operation"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
subscription_plan = db.Column(db.String(255), nullable=False)
operation = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

View File

@ -22,6 +22,17 @@ class BillingService:
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)
return {
"limit": knowledge_rate_limit.get("limit", 10),
"subscription_plan": knowledge_rate_limit.get("subscription_plan", "sandbox"),
}
@classmethod
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}

View File

@ -41,6 +41,7 @@ class FeatureModel(BaseModel):
members: LimitationModel = LimitationModel(size=0, limit=1)
apps: LimitationModel = LimitationModel(size=0, limit=10)
vector_space: LimitationModel = LimitationModel(size=0, limit=5)
knowledge_rate_limit: int = 10
annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
docs_processing: str = "standard"
@ -52,6 +53,12 @@ class FeatureModel(BaseModel):
model_config = ConfigDict(protected_namespaces=())
class KnowledgeRateLimitModel(BaseModel):
enabled: bool = False
limit: int = 10
subscription_plan: str = ""
class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
@ -81,6 +88,16 @@ class FeatureService:
return features
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
knowledge_rate_limit = KnowledgeRateLimitModel()
if dify_config.BILLING_ENABLED and tenant_id:
knowledge_rate_limit.enabled = True
limit_info = BillingService.get_knowledge_rate_limit(tenant_id)
knowledge_rate_limit.limit = limit_info.get("limit", 10)
knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", "sandbox")
return knowledge_rate_limit
@classmethod
def get_system_features(cls) -> SystemFeatureModel:
system_features = SystemFeatureModel()
@ -149,6 +166,9 @@ class FeatureService:
if "model_load_balancing_enabled" in billing_info:
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
if "knowledge_rate_limit" in billing_info:
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
@classmethod
def _fulfill_params_from_enterprise(cls, features):
enterprise_info = EnterpriseService.get_info()