From d1dbbc1e33a9dc70e44137919b48f0f54af56d7c Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 5 Jun 2024 00:13:04 +0800 Subject: [PATCH] feat: backend model load balancing support (#4927) --- api/config.py | 57 +- api/controllers/console/__init__.py | 2 +- api/controllers/console/feature.py | 7 +- api/controllers/console/version.py | 37 +- .../workspace/load_balancing_config.py | 106 ++++ api/controllers/console/workspace/models.py | 118 +++- api/core/app/apps/base_app_runner.py | 25 +- .../easy_ui_based_generate_task_pipeline.py | 18 +- api/core/application_manager.py | 0 api/core/entities/model_entities.py | 12 +- api/core/entities/provider_configuration.py | 281 ++++++++- api/core/entities/provider_entities.py | 19 + api/core/extension/extensible.py | 2 +- api/core/helper/model_provider_cache.py | 1 + .../{utils => helper}/module_import_helper.py | 0 api/core/{utils => helper}/position_helper.py | 0 api/core/indexing_runner.py | 19 +- api/core/memory/token_buffer_memory.py | 13 +- api/core/model_manager.py | 292 ++++++++- .../model_providers/__base/ai_model.py | 2 +- .../model_providers/__base/model_provider.py | 2 +- .../model_providers/model_provider_factory.py | 4 +- api/core/prompt/prompt_transform.py | 14 +- api/core/provider_manager.py | 167 +++++- api/core/rag/docstore/dataset_docstore.py | 9 +- api/core/rag/splitter/fixed_text_splitter.py | 9 +- api/core/tools/provider/builtin/_positions.py | 2 +- .../tools/provider/builtin_tool_provider.py | 8 +- api/core/tools/tool_manager.py | 20 +- .../tools/utils/model_invocation_utils.py | 17 +- .../question_classifier_node.py | 12 +- .../4e99a8df00ff_add_load_balancing.py | 126 ++++ api/models/provider.py | 53 +- api/services/dataset_service.py | 18 +- .../entities/model_provider_entities.py | 13 +- api/services/feature_service.py | 37 +- api/services/model_load_balancing_service.py | 565 ++++++++++++++++++ api/services/model_provider_service.py | 64 +- api/services/workflow_service.py | 2 +- 
.../batch_create_segment_to_index_task.py | 8 +- .../utils/test_module_import_helper.py | 2 +- .../workflow/nodes/test_llm.py | 8 +- .../nodes/test_parameter_extractor.py | 3 +- .../core/prompt/test_prompt_transform.py | 11 +- .../unit_tests/core/test_model_manager.py | 77 +++ .../unit_tests/core/test_provider_manager.py | 183 ++++++ .../position_helper/test_position_helper.py | 2 +- 47 files changed, 2191 insertions(+), 256 deletions(-) create mode 100644 api/controllers/console/workspace/load_balancing_config.py delete mode 100644 api/core/application_manager.py rename api/core/{utils => helper}/module_import_helper.py (100%) rename api/core/{utils => helper}/position_helper.py (100%) create mode 100644 api/migrations/versions/4e99a8df00ff_add_load_balancing.py create mode 100644 api/services/model_load_balancing_service.py create mode 100644 api/tests/unit_tests/core/test_model_manager.py create mode 100644 api/tests/unit_tests/core/test_provider_manager.py diff --git a/api/config.py b/api/config.py index 40f659a8df..1ff43329db 100644 --- a/api/config.py +++ b/api/config.py @@ -70,6 +70,7 @@ DEFAULTS = { 'INVITE_EXPIRY_HOURS': 72, 'BILLING_ENABLED': 'False', 'CAN_REPLACE_LOGO': 'False', + 'MODEL_LB_ENABLED': 'False', 'ETL_TYPE': 'dify', 'KEYWORD_STORE': 'jieba', 'BATCH_UPLOAD_LIMIT': 20, @@ -123,6 +124,7 @@ class Config: self.LOG_FILE = get_env('LOG_FILE') self.LOG_FORMAT = get_env('LOG_FORMAT') self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT') + self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED') # The backend URL prefix of the console API. # used to concatenate the login authorization callback or notion integration callback. @@ -210,27 +212,41 @@ class Config: if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://') + # ------------------------ + # Code Execution Sandbox Configurations. 
+ # ------------------------ + self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT') + self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY') + # ------------------------ # File Storage Configurations. # ------------------------ self.STORAGE_TYPE = get_env('STORAGE_TYPE') self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH') + + # S3 Storage settings self.S3_ENDPOINT = get_env('S3_ENDPOINT') self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME') self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY') self.S3_SECRET_KEY = get_env('S3_SECRET_KEY') self.S3_REGION = get_env('S3_REGION') self.S3_ADDRESS_STYLE = get_env('S3_ADDRESS_STYLE') + + # Azure Blob Storage settings self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME') self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY') self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME') self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL') + + # Aliyun Storage settings self.ALIYUN_OSS_BUCKET_NAME = get_env('ALIYUN_OSS_BUCKET_NAME') self.ALIYUN_OSS_ACCESS_KEY = get_env('ALIYUN_OSS_ACCESS_KEY') self.ALIYUN_OSS_SECRET_KEY = get_env('ALIYUN_OSS_SECRET_KEY') self.ALIYUN_OSS_ENDPOINT = get_env('ALIYUN_OSS_ENDPOINT') self.ALIYUN_OSS_REGION = get_env('ALIYUN_OSS_REGION') self.ALIYUN_OSS_AUTH_VERSION = get_env('ALIYUN_OSS_AUTH_VERSION') + + # Google Cloud Storage settings self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME') self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64') @@ -240,6 +256,7 @@ class Config: # ------------------------ self.VECTOR_STORE = get_env('VECTOR_STORE') self.KEYWORD_STORE = get_env('KEYWORD_STORE') + # qdrant settings self.QDRANT_URL = get_env('QDRANT_URL') self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') @@ -323,6 +340,19 @@ class Config: self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT')) self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT')) 
self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT')) + self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT') + + # RAG ETL Configurations. + self.ETL_TYPE = get_env('ETL_TYPE') + self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL') + self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY') + self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') + + # Indexing Configurations. + self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH') + + # Tool Configurations. + self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') self.WORKFLOW_MAX_EXECUTION_STEPS = int(get_env('WORKFLOW_MAX_EXECUTION_STEPS')) self.WORKFLOW_MAX_EXECUTION_TIME = int(get_env('WORKFLOW_MAX_EXECUTION_TIME')) @@ -378,24 +408,15 @@ class Config: self.HOSTED_FETCH_APP_TEMPLATES_MODE = get_env('HOSTED_FETCH_APP_TEMPLATES_MODE') self.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = get_env('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN') - self.ETL_TYPE = get_env('ETL_TYPE') - self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL') - self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY') + # Model Load Balancing Configurations. + self.MODEL_LB_ENABLED = get_bool_env('MODEL_LB_ENABLED') + + # Platform Billing Configurations. self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED') - self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO') - self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT') - - self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT') - self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY') - - self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED') - self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') - - self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') + # ------------------------ + # Enterprise feature Configurations. 
+ # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** + # ------------------------ self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED') - - # ------------------------ - # Indexing Configurations. - # ------------------------ - self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH') + self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO') diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 72ec05f654..306b7384cf 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -54,4 +54,4 @@ from .explore import ( from .tag import tags # Import workspace controllers -from .workspace import account, members, model_providers, models, tool_providers, workspace +from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 7334f85a57..44d9d67522 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,14 +1,19 @@ from flask_login import current_user from flask_restful import Resource +from libs.login import login_required from services.feature_service import FeatureService from . 
import api -from .wraps import cloud_utm_record +from .setup import setup_required +from .wraps import account_initialization_required, cloud_utm_record class FeatureApi(Resource): + @setup_required + @login_required + @account_initialization_required @cloud_utm_record def get(self): return FeatureService.get_features(current_user.current_tenant_id).dict() diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index a50e4c41a8..faf36c4f40 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -17,13 +17,19 @@ class VersionApi(Resource): args = parser.parse_args() check_update_url = current_app.config['CHECK_UPDATE_URL'] - if not check_update_url: - return { - 'version': '0.0.0', - 'release_date': '', - 'release_notes': '', - 'can_auto_update': False + result = { + 'version': current_app.config['CURRENT_VERSION'], + 'release_date': '', + 'release_notes': '', + 'can_auto_update': False, + 'features': { + 'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'], + 'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED'] } + } + + if not check_update_url: + return result try: response = requests.get(check_update_url, { @@ -31,20 +37,15 @@ class VersionApi(Resource): }) except Exception as error: logging.warning("Check update version error: {}.".format(str(error))) - return { - 'version': args.get('current_version'), - 'release_date': '', - 'release_notes': '', - 'can_auto_update': False - } + result['version'] = args.get('current_version') + return result content = json.loads(response.content) - return { - 'version': content['version'], - 'release_date': content['releaseDate'], - 'release_notes': content['releaseNotes'], - 'can_auto_update': content['canAutoUpdate'] - } + result['version'] = content['version'] + result['release_date'] = content['releaseDate'] + result['release_notes'] = content['releaseNotes'] + result['can_auto_update'] = content['canAutoUpdate'] + return result 
api.add_resource(VersionApi, '/version') diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py new file mode 100644 index 0000000000..50514e39f6 --- /dev/null +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -0,0 +1,106 @@ +from flask_restful import Resource, reqparse +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from libs.login import current_user, login_required +from models.account import TenantAccountRole +from services.model_load_balancing_service import ModelLoadBalancingService + + +class LoadBalancingCredentialsValidateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str): + if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): + raise Forbidden() + + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('model', type=str, required=True, nullable=False, location='json') + parser.add_argument('model_type', type=str, required=True, nullable=False, + choices=[mt.value for mt in ModelType], location='json') + parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + # validate model load balancing credentials + model_load_balancing_service = ModelLoadBalancingService() + + result = True + error = None + + try: + model_load_balancing_service.validate_load_balancing_credentials( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'], + credentials=args['credentials'] + ) + except CredentialsValidateFailedError 
as ex: + result = False + error = str(ex) + + response = {'result': 'success' if result else 'error'} + + if not result: + response['error'] = error + + return response + + +class LoadBalancingConfigCredentialsValidateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider: str, config_id: str): + if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): + raise Forbidden() + + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('model', type=str, required=True, nullable=False, location='json') + parser.add_argument('model_type', type=str, required=True, nullable=False, + choices=[mt.value for mt in ModelType], location='json') + parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + args = parser.parse_args() + + # validate model load balancing config credentials + model_load_balancing_service = ModelLoadBalancingService() + + result = True + error = None + + try: + model_load_balancing_service.validate_load_balancing_credentials( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'], + credentials=args['credentials'], + config_id=config_id, + ) + except CredentialsValidateFailedError as ex: + result = False + error = str(ex) + + response = {'result': 'success' if result else 'error'} + + if not result: + response['error'] = error + + return response + + +# Load Balancing Config +api.add_resource(LoadBalancingCredentialsValidateApi, + '/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate') + +api.add_resource(LoadBalancingConfigCredentialsValidateApi, + '/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate') diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 23239b1902..76ae6a4ab9 100644 --- 
a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import login_required from models.account import TenantAccountRole +from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService @@ -104,21 +105,56 @@ class ModelProviderModelApi(Resource): parser.add_argument('model', type=str, required=True, nullable=False, location='json') parser.add_argument('model_type', type=str, required=True, nullable=False, choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json') + parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json') + parser.add_argument('config_from', type=str, required=False, nullable=True, location='json') args = parser.parse_args() - model_provider_service = ModelProviderService() + model_load_balancing_service = ModelLoadBalancingService() - try: - model_provider_service.save_model_credentials( + if ('load_balancing' in args and args['load_balancing'] and + 'enabled' in args['load_balancing'] and args['load_balancing']['enabled']): + if 'configs' not in args['load_balancing']: + raise ValueError('invalid load balancing configs') + + # save load balancing configs + model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, provider=provider, model=args['model'], model_type=args['model_type'], - credentials=args['credentials'] + configs=args['load_balancing']['configs'] ) - except CredentialsValidateFailedError as ex: - raise ValueError(str(ex)) + + # enable load balancing + 
model_load_balancing_service.enable_model_load_balancing( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'] + ) + else: + # disable load balancing + model_load_balancing_service.disable_model_load_balancing( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'] + ) + + if args.get('config_from', '') != 'predefined-model': + model_provider_service = ModelProviderService() + + try: + model_provider_service.save_model_credentials( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'], + credentials=args['credentials'] + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) return {'result': 'success'}, 200 @@ -170,11 +206,73 @@ class ModelProviderModelCredentialApi(Resource): model=args['model'] ) + model_load_balancing_service = ModelLoadBalancingService() + is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'] + ) + return { - "credentials": credentials + "credentials": credentials, + "load_balancing": { + "enabled": is_load_balancing_enabled, + "configs": load_balancing_configs + } } +class ModelProviderModelEnableApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def patch(self, provider: str): + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('model', type=str, required=True, nullable=False, location='json') + parser.add_argument('model_type', type=str, required=True, nullable=False, + choices=[mt.value for mt in ModelType], location='json') + args = parser.parse_args() + + model_provider_service = ModelProviderService() + model_provider_service.enable_model( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'] + ) 
+ + return {'result': 'success'} + + +class ModelProviderModelDisableApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def patch(self, provider: str): + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('model', type=str, required=True, nullable=False, location='json') + parser.add_argument('model_type', type=str, required=True, nullable=False, + choices=[mt.value for mt in ModelType], location='json') + args = parser.parse_args() + + model_provider_service = ModelProviderService() + model_provider_service.disable_model( + tenant_id=tenant_id, + provider=provider, + model=args['model'], + model_type=args['model_type'] + ) + + return {'result': 'success'} + + class ModelProviderModelValidateApi(Resource): @setup_required @@ -259,6 +357,10 @@ class ModelProviderAvailableModelApi(Resource): api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers//models') +api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers//models/enable', + endpoint='model-provider-model-enable') +api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers//models/disable', + endpoint='model-provider-model-disable') api.add_resource(ModelProviderModelCredentialApi, '/workspaces/current/model-providers//models/credentials') api.add_resource(ModelProviderModelValidateApi, diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index f1f426b27e..545463c8bd 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,6 @@ import time from collections.abc import Generator -from typing import Optional, Union, cast +from typing import Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -16,11 +16,11 @@ from 
core.app.features.hosting_moderation.hosting_moderation import HostingModer from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.input_moderation import InputModeration from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig @@ -45,8 +45,11 @@ class AppRunner: :param query: query :return: """ - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) + # Invoke model + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model + ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) @@ -73,9 +76,7 @@ class AppRunner: query=query ) - prompt_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, + prompt_tokens = model_instance.get_llm_num_tokens( prompt_messages ) @@ -89,8 +90,10 @@ class AppRunner: def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit - model_type_instance = model_config.provider_model_bundle.model_type_instance - 
model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model + ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) @@ -107,9 +110,7 @@ class AppRunner: if max_tokens is None: max_tokens = 0 - prompt_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, + prompt_tokens = model_instance.get_llm_num_tokens( prompt_messages ) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index a7dbb4754c..f71470edb2 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -37,6 +37,7 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -317,29 +318,30 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan """ model_config = self._model_config model = model_config.model - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model + ) # calculate num tokens prompt_tokens = 0 if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: - prompt_tokens = model_type_instance.get_num_tokens( - model, - model_config.credentials, + prompt_tokens 
= model_instance.get_llm_num_tokens( self._task_state.llm_result.prompt_messages ) completion_tokens = 0 if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: - completion_tokens = model_type_instance.get_num_tokens( - model, - model_config.credentials, + completion_tokens = model_instance.get_llm_num_tokens( [self._task_state.llm_result.message] ) credentials = model_config.credentials # transform usage + model_type_instance = model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) self._task_state.llm_result.usage = model_type_instance._calc_response_usage( model, credentials, diff --git a/api/core/application_manager.py b/api/core/application_manager.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 05719e5b8d..9a797c1c95 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -16,6 +16,7 @@ class ModelStatus(Enum): NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" NO_PERMISSION = "no-permission" + DISABLED = "disabled" class SimpleModelProviderEntity(BaseModel): @@ -43,12 +44,19 @@ class SimpleModelProviderEntity(BaseModel): ) -class ModelWithProviderEntity(ProviderModel): +class ProviderModelWithStatusEntity(ProviderModel): + """ + Model class for model response. + """ + status: ModelStatus + load_balancing_enabled: bool = False + + +class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ Model with provider entity. 
""" provider: SimpleModelProviderEntity - status: ModelStatus class DefaultModelProviderEntity(BaseModel): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 303034693d..ec1b2d0d48 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,6 +1,7 @@ import datetime import json import logging +from collections import defaultdict from collections.abc import Iterator from json import JSONDecodeError from typing import Optional @@ -8,7 +9,12 @@ from typing import Optional from pydantic import BaseModel from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity -from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus +from core.entities.provider_entities import ( + CustomConfiguration, + ModelSettings, + SystemConfiguration, + SystemConfigurationStatus, +) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import FetchFrom, ModelType @@ -22,7 +28,14 @@ from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.model_provider import ModelProvider from extensions.ext_database import db -from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider +from models.provider import ( + LoadBalancingModelConfig, + Provider, + ProviderModel, + ProviderModelSetting, + ProviderType, + TenantPreferredModelProvider, +) logger = logging.getLogger(__name__) @@ -39,6 +52,7 @@ class ProviderConfiguration(BaseModel): using_provider_type: ProviderType system_configuration: SystemConfiguration custom_configuration: CustomConfiguration + model_settings: list[ModelSettings] def __init__(self, 
**data): super().__init__(**data) @@ -62,6 +76,14 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ + if self.model_settings: + # check if model is disabled by admin + for model_setting in self.model_settings: + if (model_setting.model_type == model_type + and model_setting.model == model): + if not model_setting.enabled: + raise ValueError(f'Model {model} is disabled.') + if self.using_provider_type == ProviderType.SYSTEM: restrict_models = [] for quota_configuration in self.system_configuration.quota_configurations: @@ -80,15 +102,17 @@ class ProviderConfiguration(BaseModel): return copy_credentials else: + credentials = None if self.custom_configuration.models: for model_configuration in self.custom_configuration.models: if model_configuration.model_type == model_type and model_configuration.model == model: - return model_configuration.credentials + credentials = model_configuration.credentials + break if self.custom_configuration.provider: - return self.custom_configuration.provider.credentials - else: - return None + credentials = self.custom_configuration.provider.credentials + + return credentials def get_system_configuration_status(self) -> SystemConfigurationStatus: """ @@ -130,7 +154,7 @@ class ProviderConfiguration(BaseModel): return credentials # Obfuscate credentials - return self._obfuscated_credentials( + return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas if self.provider.provider_credential_schema else [] @@ -151,7 +175,7 @@ class ProviderConfiguration(BaseModel): ).first() # Get provider credential secret variables - provider_credential_secret_variables = self._extract_secret_variables( + provider_credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas if self.provider.provider_credential_schema else [] ) @@ -274,7 +298,7 @@ class 
ProviderConfiguration(BaseModel): return credentials # Obfuscate credentials - return self._obfuscated_credentials( + return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [] @@ -302,7 +326,7 @@ class ProviderConfiguration(BaseModel): ).first() # Get provider credential secret variables - provider_credential_secret_variables = self._extract_secret_variables( + provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [] ) @@ -402,6 +426,160 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache.delete() + def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Enable model. + :param model_type: model type + :param model: model name + :return: + """ + model_setting = db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model + ).first() + + if model_setting: + model_setting.enabled = True + model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.commit() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=True + ) + db.session.add(model_setting) + db.session.commit() + + return model_setting + + def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Disable model. 
+ :param model_type: model type + :param model: model name + :return: + """ + model_setting = db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model + ).first() + + if model_setting: + model_setting.enabled = False + model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.commit() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + enabled=False + ) + db.session.add(model_setting) + db.session.commit() + + return model_setting + + def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]: + """ + Get provider model setting. + :param model_type: model type + :param model: model name + :return: + """ + return db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model + ).first() + + def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Enable model load balancing. 
+ :param model_type: model type + :param model: model name + :return: + """ + load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model + ).count() + + if load_balancing_config_count <= 1: + raise ValueError('Model load balancing configuration must be more than 1.') + + model_setting = db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model + ).first() + + if model_setting: + model_setting.load_balancing_enabled = True + model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.commit() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + load_balancing_enabled=True + ) + db.session.add(model_setting) + db.session.commit() + + return model_setting + + def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: + """ + Disable model load balancing. 
+ :param model_type: model type + :param model: model name + :return: + """ + model_setting = db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model + ).first() + + if model_setting: + model_setting.load_balancing_enabled = False + model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.commit() + else: + model_setting = ProviderModelSetting( + tenant_id=self.tenant_id, + provider_name=self.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + load_balancing_enabled=False + ) + db.session.add(model_setting) + db.session.commit() + + return model_setting + def get_provider_instance(self) -> ModelProvider: """ Get provider instance. @@ -453,7 +631,7 @@ class ProviderConfiguration(BaseModel): db.session.commit() - def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: + def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ Extract secret input form variables. @@ -467,7 +645,7 @@ class ProviderConfiguration(BaseModel): return secret_input_form_variables - def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: + def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: """ Obfuscated credentials. 
@@ -476,7 +654,7 @@ class ProviderConfiguration(BaseModel): :return: """ # Get provider credential secret variables - credential_secret_variables = self._extract_secret_variables( + credential_secret_variables = self.extract_secret_variables( credential_form_schemas ) @@ -522,15 +700,22 @@ class ProviderConfiguration(BaseModel): else: model_types = provider_instance.get_provider_schema().supported_model_types + # Group model settings by model type and model + model_setting_map = defaultdict(dict) + for model_setting in self.model_settings: + model_setting_map[model_setting.model_type][model_setting.model] = model_setting + if self.using_provider_type == ProviderType.SYSTEM: provider_models = self._get_system_provider_models( model_types=model_types, - provider_instance=provider_instance + provider_instance=provider_instance, + model_setting_map=model_setting_map ) else: provider_models = self._get_custom_provider_models( model_types=model_types, - provider_instance=provider_instance + provider_instance=provider_instance, + model_setting_map=model_setting_map ) if only_active: @@ -541,18 +726,27 @@ class ProviderConfiguration(BaseModel): def _get_system_provider_models(self, model_types: list[ModelType], - provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ + -> list[ModelWithProviderEntity]: """ Get system provider models. 
:param model_types: model types :param provider_instance: provider instance + :param model_setting_map: model setting map :return: """ provider_models = [] for model_type in model_types: - provider_models.extend( - [ + for m in provider_instance.models(model_type): + status = ModelStatus.ACTIVE + if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: + model_setting = model_setting_map[m.model_type][m.model] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + + provider_models.append( ModelWithProviderEntity( model=m.model, label=m.label, @@ -562,11 +756,9 @@ class ProviderConfiguration(BaseModel): model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE + status=status ) - for m in provider_instance.models(model_type) - ] - ) + ) if self.provider.provider not in original_provider_configurate_methods: original_provider_configurate_methods[self.provider.provider] = [] @@ -586,7 +778,8 @@ class ProviderConfiguration(BaseModel): break if should_use_custom_model: - if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: + if original_provider_configurate_methods[self.provider.provider] == [ + ConfigurateMethod.CUSTOMIZABLE_MODEL]: # only customizable model for restrict_model in restrict_models: copy_credentials = self.system_configuration.credentials.copy() @@ -611,6 +804,13 @@ class ProviderConfiguration(BaseModel): if custom_model_schema.model_type not in model_types: continue + status = ModelStatus.ACTIVE + if (custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + provider_models.append( ModelWithProviderEntity( 
model=custom_model_schema.model, @@ -621,7 +821,7 @@ class ProviderConfiguration(BaseModel): model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE + status=status ) ) @@ -632,16 +832,20 @@ class ProviderConfiguration(BaseModel): m.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: m.status = ModelStatus.QUOTA_EXCEEDED + return provider_models def _get_custom_provider_models(self, model_types: list[ModelType], - provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ + -> list[ModelWithProviderEntity]: """ Get custom provider models. :param model_types: model types :param provider_instance: provider instance + :param model_setting_map: model setting map :return: """ provider_models = [] @@ -656,6 +860,16 @@ class ProviderConfiguration(BaseModel): models = provider_instance.models(model_type) for m in models: + status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE + load_balancing_enabled = False + if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: + model_setting = model_setting_map[m.model_type][m.model] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + + if len(model_setting.load_balancing_configs) > 1: + load_balancing_enabled = True + provider_models.append( ModelWithProviderEntity( model=m.model, @@ -666,7 +880,8 @@ class ProviderConfiguration(BaseModel): model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE + status=status, + load_balancing_enabled=load_balancing_enabled ) ) @@ -690,6 +905,17 @@ class ProviderConfiguration(BaseModel): if not custom_model_schema: continue + status = ModelStatus.ACTIVE + 
load_balancing_enabled = False + if (custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + + if len(model_setting.load_balancing_configs) > 1: + load_balancing_enabled = True + provider_models.append( ModelWithProviderEntity( model=custom_model_schema.model, @@ -700,7 +926,8 @@ class ProviderConfiguration(BaseModel): model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE + status=status, + load_balancing_enabled=load_balancing_enabled ) ) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 114dfaf911..1eaa6ea02c 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -72,3 +72,22 @@ class CustomConfiguration(BaseModel): """ provider: Optional[CustomProviderConfiguration] = None models: list[CustomModelConfiguration] = [] + + +class ModelLoadBalancingConfiguration(BaseModel): + """ + Class for model load balancing configuration. + """ + id: str + name: str + credentials: dict + + +class ModelSettings(BaseModel): + """ + Model class for model settings. 
+ """ + model: str + model_type: ModelType + enabled: bool = True + load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index d2ec555d6c..3a37c6492e 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -7,7 +7,7 @@ from typing import Any, Optional from pydantic import BaseModel -from core.utils.position_helper import sort_to_dict_by_position_map +from core.helper.position_helper import sort_to_dict_by_position_map class ExtensionModule(enum.Enum): diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 81e589f65b..29cb4acc7d 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -9,6 +9,7 @@ from extensions.ext_redis import redis_client class ProviderCredentialsCacheType(Enum): PROVIDER = "provider" MODEL = "provider_model" + LOAD_BALANCING_MODEL = "load_balancing_provider_model" class ProviderCredentialsCache: diff --git a/api/core/utils/module_import_helper.py b/api/core/helper/module_import_helper.py similarity index 100% rename from api/core/utils/module_import_helper.py rename to api/core/helper/module_import_helper.py diff --git a/api/core/utils/position_helper.py b/api/core/helper/position_helper.py similarity index 100% rename from api/core/utils/position_helper.py rename to api/core/helper/position_helper.py diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index d4c2bc5ad5..7fa1d7d4b9 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -286,11 +286,7 @@ class IndexingRunner: if len(preview_texts) < 5: preview_texts.append(document.page_content) if indexing_technique == 'high_quality' or embedding_model_instance: - embedding_model_type_instance = embedding_model_instance.model_type_instance - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) - tokens += 
embedding_model_type_instance.get_num_tokens( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, + tokens += embedding_model_instance.get_text_embedding_num_tokens( texts=[self.filter_string(document.page_content)] ) @@ -658,10 +654,6 @@ class IndexingRunner: tokens = 0 chunk_size = 10 - embedding_model_type_instance = None - if embedding_model_instance: - embedding_model_type_instance = embedding_model_instance.model_type_instance - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) # create keyword index create_keyword_thread = threading.Thread(target=self._process_keyword_index, args=(current_app._get_current_object(), @@ -674,8 +666,7 @@ class IndexingRunner: chunk_documents = documents[i:i + chunk_size] futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor, chunk_documents, dataset, - dataset_document, embedding_model_instance, - embedding_model_type_instance)) + dataset_document, embedding_model_instance)) for future in futures: tokens += future.result() @@ -716,7 +707,7 @@ class IndexingRunner: db.session.commit() def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document, - embedding_model_instance, embedding_model_type_instance): + embedding_model_instance): with flask_app.app_context(): # check document is paused self._check_document_paused_status(dataset_document.id) @@ -724,9 +715,7 @@ class IndexingRunner: tokens = 0 if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance: tokens += sum( - embedding_model_type_instance.get_num_tokens( - embedding_model_instance.model, - embedding_model_instance.credentials, + embedding_model_instance.get_text_embedding_num_tokens( [document.page_content] ) for document in chunk_documents diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index cd0b2508d4..6b53104c70 100644 --- 
a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -9,8 +9,6 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers import model_provider_factory from extensions.ext_database import db from models.model import AppMode, Conversation, Message @@ -78,12 +76,7 @@ class TokenBufferMemory: return [] # prune the chat message if it exceeds the max token limit - provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider) - model_type_instance = provider_instance.get_model_instance(ModelType.LLM) - - curr_message_tokens = model_type_instance.get_num_tokens( - self.model_instance.model, - self.model_instance.credentials, + curr_message_tokens = self.model_instance.get_llm_num_tokens( prompt_messages ) @@ -91,9 +84,7 @@ class TokenBufferMemory: pruned_memory = [] while curr_message_tokens > max_token_limit and prompt_messages: pruned_memory.append(prompt_messages.pop(0)) - curr_message_tokens = model_type_instance.get_num_tokens( - self.model_instance.model, - self.model_instance.credentials, + curr_message_tokens = self.model_instance.get_llm_num_tokens( prompt_messages ) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 8c06339927..8da8442e60 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,7 +1,10 @@ +import logging +import os from collections.abc import Generator from typing import IO, Optional, Union, cast -from core.entities.provider_configuration import ProviderModelBundle +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError from core.model_runtime.callbacks.base_callback import Callback from 
core.model_runtime.entities.llm_entities import LLMResult @@ -9,6 +12,7 @@ from core.model_runtime.entities.message_entities import PromptMessage, PromptMe from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.moderation_model import ModerationModel from core.model_runtime.model_providers.__base.rerank_model import RerankModel @@ -16,6 +20,10 @@ from core.model_runtime.model_providers.__base.speech2text_model import Speech2T from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.provider_manager import ProviderManager +from extensions.ext_redis import redis_client +from models.provider import ProviderType + +logger = logging.getLogger(__name__) class ModelInstance: @@ -29,6 +37,12 @@ class ModelInstance: self.provider = provider_model_bundle.configuration.provider.provider self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) self.model_type_instance = self.provider_model_bundle.model_type_instance + self.load_balancing_manager = self._get_load_balancing_manager( + configuration=provider_model_bundle.configuration, + model_type=provider_model_bundle.model_type_instance.model_type, + model=model, + credentials=self.credentials + ) def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: """ @@ -37,8 +51,10 @@ class ModelInstance: :param model: model name :return: """ - credentials = provider_model_bundle.configuration.get_current_credentials( - 
model_type=provider_model_bundle.model_type_instance.model_type, + configuration = provider_model_bundle.configuration + model_type = provider_model_bundle.model_type_instance.model_type + credentials = configuration.get_current_credentials( + model_type=model_type, model=model ) @@ -47,6 +63,43 @@ class ModelInstance: return credentials + def _get_load_balancing_manager(self, configuration: ProviderConfiguration, + model_type: ModelType, + model: str, + credentials: dict) -> Optional["LBModelManager"]: + """ + Get load balancing model credentials + :param configuration: provider configuration + :param model_type: model type + :param model: model name + :param credentials: model credentials + :return: + """ + if configuration.model_settings and configuration.using_provider_type == ProviderType.CUSTOM: + current_model_setting = None + # check if model is disabled by admin + for model_setting in configuration.model_settings: + if (model_setting.model_type == model_type + and model_setting.model == model): + current_model_setting = model_setting + break + + # check if load balancing is enabled + if current_model_setting and current_model_setting.load_balancing_configs: + # use load balancing proxy to choose credentials + lb_model_manager = LBModelManager( + tenant_id=configuration.tenant_id, + provider=configuration.provider.provider, + model_type=model_type, + model=model, + load_balancing_configs=current_model_setting.load_balancing_configs, + managed_credentials=credentials if configuration.custom_configuration.provider else None + ) + + return lb_model_manager + + return None + def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ @@ -67,7 +120,8 @@ class ModelInstance: raise Exception("Model type instance is not LargeLanguageModel") 
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, prompt_messages=prompt_messages, @@ -79,6 +133,27 @@ class ModelInstance: callbacks=callbacks ) + def get_llm_num_tokens(self, prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Get number of tokens for llm + + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + if not isinstance(self.model_type_instance, LargeLanguageModel): + raise Exception("Model type instance is not LargeLanguageModel") + + self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) + return self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + tools=tools + ) + def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ -> TextEmbeddingResult: """ @@ -92,13 +167,32 @@ class ModelInstance: raise Exception("Model type instance is not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, texts=texts, user=user ) + def get_text_embedding_num_tokens(self, texts: list[str]) -> int: + """ + Get number of tokens for text embedding + + :param texts: texts to embed + :return: + """ + if not isinstance(self.model_type_instance, TextEmbeddingModel): + raise Exception("Model type instance is not TextEmbeddingModel") + + self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) + return self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + 
credentials=self.credentials, + texts=texts + ) + def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, user: Optional[str] = None) \ @@ -117,7 +211,8 @@ class ModelInstance: raise Exception("Model type instance is not RerankModel") self.model_type_instance = cast(RerankModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, query=query, @@ -140,7 +235,8 @@ class ModelInstance: raise Exception("Model type instance is not ModerationModel") self.model_type_instance = cast(ModerationModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, text=text, @@ -160,7 +256,8 @@ class ModelInstance: raise Exception("Model type instance is not Speech2TextModel") self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, file=file, @@ -183,7 +280,8 @@ class ModelInstance: raise Exception("Model type instance is not TTSModel") self.model_type_instance = cast(TTSModel, self.model_type_instance) - return self.model_type_instance.invoke( + return self._round_robin_invoke( + function=self.model_type_instance.invoke, model=self.model, credentials=self.credentials, content_text=content_text, @@ -193,6 +291,43 @@ class ModelInstance: streaming=streaming ) + def _round_robin_invoke(self, function: callable, *args, **kwargs): + """ + Round-robin invoke + :param function: function to invoke + :param args: function args + :param kwargs: function kwargs + :return: + """ + if not self.load_balancing_manager: + return function(*args, **kwargs) + + 
last_exception = None + while True: + lb_config = self.load_balancing_manager.fetch_next() + if not lb_config: + if not last_exception: + raise ProviderTokenNotInitError("Model credentials is not initialized.") + else: + raise last_exception + + try: + if 'credentials' in kwargs: + del kwargs['credentials'] + return function(*args, **kwargs, credentials=lb_config.credentials) + except InvokeRateLimitError as e: + # expire in 60 seconds + self.load_balancing_manager.cooldown(lb_config, expire=60) + last_exception = e + continue + except (InvokeAuthorizationError, InvokeConnectionError) as e: + # expire in 10 seconds + self.load_balancing_manager.cooldown(lb_config, expire=10) + last_exception = e + continue + except Exception as e: + raise e + def get_tts_voices(self, language: str) -> list: """ Invoke large language tts model voices @@ -226,6 +361,7 @@ class ModelManager: """ if not provider: return self.get_default_model_instance(tenant_id, model_type) + provider_model_bundle = self._provider_manager.get_provider_model_bundle( tenant_id=tenant_id, provider=provider, @@ -255,3 +391,141 @@ class ModelManager: model_type=model_type, model=default_model_entity.model ) + + +class LBModelManager: + def __init__(self, tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + load_balancing_configs: list[ModelLoadBalancingConfiguration], + managed_credentials: Optional[dict] = None) -> None: + """ + Load balancing model manager + :param load_balancing_configs: all load balancing configurations + :param managed_credentials: credentials if load balancing configuration name is __inherit__ + """ + self._tenant_id = tenant_id + self._provider = provider + self._model_type = model_type + self._model = model + self._load_balancing_configs = load_balancing_configs + + for load_balancing_config in self._load_balancing_configs: + if load_balancing_config.name == "__inherit__": + if not managed_credentials: + # remove __inherit__ if managed credentials is not 
provided + self._load_balancing_configs.remove(load_balancing_config) + else: + load_balancing_config.credentials = managed_credentials + + def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: + """ + Get next model load balancing config + Strategy: Round Robin + :return: + """ + cache_key = "model_lb_index:{}:{}:{}:{}".format( + self._tenant_id, + self._provider, + self._model_type.value, + self._model + ) + + cooldown_load_balancing_configs = [] + max_index = len(self._load_balancing_configs) + + while True: + current_index = redis_client.incr(cache_key) + if current_index >= 10000000: + current_index = 1 + redis_client.set(cache_key, current_index) + + redis_client.expire(cache_key, 3600) + if current_index > max_index: + current_index = current_index % max_index + + real_index = current_index - 1 + if real_index > max_index: + real_index = 0 + + config = self._load_balancing_configs[real_index] + + if self.in_cooldown(config): + cooldown_load_balancing_configs.append(config) + if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs): + # all configs are in cooldown + return None + + continue + + if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n" + f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" + f"model_type: {self._model_type.value}\nmodel: {self._model}") + + return config + + return None + + def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: + """ + Cooldown model load balancing config + :param config: model load balancing config + :param expire: cooldown time + :return: + """ + cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( + self._tenant_id, + self._provider, + self._model_type.value, + self._model, + config.id + ) + + redis_client.setex(cooldown_cache_key, expire, 'true') + + def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: + """ + Check if model load 
balancing config is in cooldown + :param config: model load balancing config + :return: + """ + cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( + self._tenant_id, + self._provider, + self._model_type.value, + self._model, + config.id + ) + + return redis_client.exists(cooldown_cache_key) + + @classmethod + def get_config_in_cooldown_and_ttl(cls, tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + config_id: str) -> tuple[bool, int]: + """ + Get model load balancing config is in cooldown and ttl + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param config_id: model load balancing config id + :return: + """ + cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( + tenant_id, + provider, + model_type.value, + model, + config_id + ) + + ttl = redis_client.ttl(cooldown_cache_key) + if ttl == -2: + return False, 0 + + return True, ttl diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index cd243ca223..919e72554c 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -3,6 +3,7 @@ import os from abc import ABC, abstractmethod from typing import Optional +from core.helper.position_helper import get_position_map, sort_by_position_map from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.model_entities import ( @@ -17,7 +18,6 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.tools.utils.yaml_utils import load_yaml_file -from core.utils.position_helper import 
get_position_map, sort_by_position_map class AIModel(ABC): diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 9ab78b7610..a893d023c0 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -1,11 +1,11 @@ import os from abc import ABC, abstractmethod +from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel from core.tools.utils.yaml_utils import load_yaml_file -from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source class ModelProvider(ABC): diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 44a1cf2e84..26c4199d16 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -4,13 +4,13 @@ from typing import Optional from pydantic import BaseModel +from core.helper.module_import_helper import load_single_subclass_from_source +from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import 
ProviderCredentialSchemaValidator -from core.utils.module_import_helper import load_single_subclass_from_source -from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map logger = logging.getLogger(__name__) diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 9bf2ae090f..d8e2d2f76d 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,10 +1,10 @@ -from typing import Optional, cast +from typing import Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig @@ -25,12 +25,12 @@ class PromptTransform: model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model + ) - curr_message_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, + curr_message_tokens = model_instance.get_llm_num_tokens( prompt_messages ) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 0281ddad0a..5f00958ed3 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -11,6 +11,8 @@ from core.entities.provider_entities import ( CustomConfiguration, CustomModelConfiguration, CustomProviderConfiguration, + 
ModelLoadBalancingConfiguration, + ModelSettings, QuotaConfiguration, SystemConfiguration, ) @@ -26,13 +28,16 @@ from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db from models.provider import ( + LoadBalancingModelConfig, Provider, ProviderModel, + ProviderModelSetting, ProviderQuotaType, ProviderType, TenantDefaultModel, TenantPreferredModelProvider, ) +from services.feature_service import FeatureService class ProviderManager: @@ -98,6 +103,13 @@ class ProviderManager: # Get All preferred provider types of the workspace provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) + # Get All provider model settings + provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) + + # Get All load balancing configs + provider_name_to_provider_load_balancing_model_configs_dict \ + = self._get_all_provider_load_balancing_configs(tenant_id) + provider_configurations = ProviderConfigurations( tenant_id=tenant_id ) @@ -147,13 +159,28 @@ class ProviderManager: if system_configuration.enabled and has_valid_quota: using_provider_type = ProviderType.SYSTEM + # Get provider load balancing configs + provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name) + + # Get provider load balancing configs + provider_load_balancing_configs \ + = provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name) + + # Convert to model settings + model_settings = self._to_model_settings( + provider_entity=provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=provider_load_balancing_configs + ) + provider_configuration = ProviderConfiguration( tenant_id=tenant_id, provider=provider_entity, preferred_provider_type=preferred_provider_type, using_provider_type=using_provider_type, 
system_configuration=system_configuration, - custom_configuration=custom_configuration + custom_configuration=custom_configuration, + model_settings=model_settings ) provider_configurations[provider_name] = provider_configuration @@ -338,7 +365,7 @@ class ProviderManager: """ Get All preferred provider types of the workspace. - :param tenant_id: + :param tenant_id: workspace id :return: """ preferred_provider_types = db.session.query(TenantPreferredModelProvider) \ @@ -353,6 +380,48 @@ class ProviderManager: return provider_name_to_preferred_provider_type_records_dict + def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]: + """ + Get All provider model settings of the workspace. + + :param tenant_id: workspace id + :return: + """ + provider_model_settings = db.session.query(ProviderModelSetting) \ + .filter( + ProviderModelSetting.tenant_id == tenant_id + ).all() + + provider_name_to_provider_model_settings_dict = defaultdict(list) + for provider_model_setting in provider_model_settings: + (provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name] + .append(provider_model_setting)) + + return provider_name_to_provider_model_settings_dict + + def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: + """ + Get All provider load balancing configs of the workspace. 
+ + :param tenant_id: workspace id + :return: + """ + model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled + if not model_load_balancing_enabled: + return dict() + + provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == tenant_id + ).all() + + provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) + for provider_load_balancing_config in provider_load_balancing_configs: + (provider_name_to_provider_load_balancing_model_configs_dict[provider_load_balancing_config.provider_name] + .append(provider_load_balancing_config)) + + return provider_name_to_provider_load_balancing_model_configs_dict + def _init_trial_provider_records(self, tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: """ @@ -726,3 +795,97 @@ class ProviderManager: secret_input_form_variables.append(credential_form_schema.variable) return secret_input_form_variables + + def _to_model_settings(self, provider_entity: ProviderEntity, + provider_model_settings: Optional[list[ProviderModelSetting]] = None, + load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \ + -> list[ModelSettings]: + """ + Convert to model settings. 
+ + :param provider_model_settings: provider model settings include enabled, load balancing enabled + :param load_balancing_model_configs: load balancing model configs + :return: + """ + # Get provider model credential secret variables + model_credential_secret_variables = self._extract_secret_variables( + provider_entity.model_credential_schema.credential_form_schemas + if provider_entity.model_credential_schema else [] + ) + + model_settings = [] + if not provider_model_settings: + return model_settings + + for provider_model_setting in provider_model_settings: + load_balancing_configs = [] + if provider_model_setting.load_balancing_enabled and load_balancing_model_configs: + for load_balancing_model_config in load_balancing_model_configs: + if (load_balancing_model_config.model_name == provider_model_setting.model_name + and load_balancing_model_config.model_type == provider_model_setting.model_type): + if not load_balancing_model_config.enabled: + continue + + if not load_balancing_model_config.encrypted_config: + if load_balancing_model_config.name == "__inherit__": + load_balancing_configs.append(ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials={} + )) + continue + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=load_balancing_model_config.tenant_id, + identity_id=load_balancing_model_config.id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + ) + + # Get cached provider model credentials + cached_provider_model_credentials = provider_model_credentials_cache.get() + + if not cached_provider_model_credentials: + try: + provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config) + except JSONDecodeError: + continue + + # Get decoding rsa key and cipher for decrypting credentials + if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: + self.decoding_rsa_key, self.decoding_cipher_rsa = 
encrypter.get_decrypt_decoding( + load_balancing_model_config.tenant_id) + + for variable in model_credential_secret_variables: + if variable in provider_model_credentials: + try: + provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( + provider_model_credentials.get(variable), + self.decoding_rsa_key, + self.decoding_cipher_rsa + ) + except ValueError: + pass + + # cache provider model credentials + provider_model_credentials_cache.set( + credentials=provider_model_credentials + ) + else: + provider_model_credentials = cached_provider_model_credentials + + load_balancing_configs.append(ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials=provider_model_credentials + )) + + model_settings.append( + ModelSettings( + model=provider_model_setting.model_name, + model_type=ModelType.value_of(provider_model_setting.model_type), + enabled=provider_model_setting.enabled, + load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [] + ) + ) + + return model_settings diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 95d8b9371d..96a15be742 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -1,11 +1,10 @@ from collections.abc import Sequence -from typing import Any, Optional, cast +from typing import Any, Optional from sqlalchemy import func from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment @@ -95,11 +94,7 @@ class DatasetDocumentStore: # calc embedding use tokens if embedding_model: - model_type_instance = embedding_model.model_type_instance - 
model_type_instance = cast(TextEmbeddingModel, model_type_instance) - tokens = model_type_instance.get_num_tokens( - model=embedding_model.model, - credentials=embedding_model.credentials, + tokens = embedding_model.get_text_embedding_num_tokens( texts=[doc.page_content] ) else: diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index fe6a89ebda..fd714edf5e 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -1,10 +1,9 @@ """Functionality for splitting text.""" from __future__ import annotations -from typing import Any, Optional, cast +from typing import Any, Optional from core.model_manager import ModelInstance -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.rag.splitter.text_splitter import ( TS, @@ -35,11 +34,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): return 0 if embedding_model_instance: - embedding_model_type_instance = embedding_model_instance.model_type_instance - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) - return embedding_model_type_instance.get_num_tokens( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, + return embedding_model_instance.get_text_embedding_num_tokens( texts=[text] ) else: diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index 851736dc7a..ae806eaff4 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,7 +1,7 @@ import os.path +from core.helper.position_helper import get_position_map, sort_by_position_map from core.tools.entities.api_entities import UserToolProvider -from core.utils.position_helper import get_position_map, 
sort_by_position_map class BuiltinToolProviderSort: diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index a7aa62b1ba..d076cb384f 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -2,6 +2,7 @@ from abc import abstractmethod from os import listdir, path from typing import Any +from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( @@ -14,7 +15,6 @@ from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.yaml_utils import load_yaml_file -from core.utils.module_import_helper import load_single_subclass_from_source class BuiltinToolProviderController(ToolProviderController): @@ -82,7 +82,7 @@ class BuiltinToolProviderController(ToolProviderController): return {} return self.credentials_schema.copy() - + def get_tools(self) -> list[Tool]: """ returns a list of tools that the provider can provide @@ -127,7 +127,7 @@ class BuiltinToolProviderController(ToolProviderController): :return: type of the provider """ return ToolProviderType.BUILT_IN - + @property def tool_labels(self) -> list[str]: """ @@ -137,7 +137,7 @@ class BuiltinToolProviderController(ToolProviderController): """ label_enums = self._get_tool_labels() return [default_tool_label_dict[label].name for label in label_enums] - + def _get_tool_labels(self) -> list[ToolLabelEnum]: """ returns the labels of the provider diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 9def1f4740..a0ca9f692a 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -10,6 +10,7 @@ 
from flask import current_app from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom +from core.helper.module_import_helper import load_single_subclass_from_source from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject @@ -31,7 +32,6 @@ from core.tools.utils.configuration import ( ToolParameterConfigurationManager, ) from core.tools.utils.tool_parameter_converter import ToolParameterConverter -from core.utils.module_import_helper import load_single_subclass_from_source from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -102,10 +102,10 @@ class ToolManager: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') @classmethod - def get_tool_runtime(cls, provider_type: str, + def get_tool_runtime(cls, provider_type: str, provider_id: str, - tool_name: str, - tenant_id: str, + tool_name: str, + tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ -> Union[BuiltinTool, ApiTool]: @@ -222,7 +222,7 @@ class ToolManager: get the agent tool runtime """ tool_entity = cls.get_tool_runtime( - provider_type=agent_tool.provider_type, + provider_type=agent_tool.provider_type, provider_id=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, @@ -235,7 +235,7 @@ class ToolManager: # check file types if parameter.type == ToolParameter.ToolParameterType.FILE: raise ValueError(f"file type parameter {parameter.name} not supported in agent") - + if parameter.form == ToolParameter.ToolParameterForm.FORM: # save tool parameter to tool entity memory value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters) @@ -403,7 +403,7 @@ 
class ToolManager: # get builtin providers builtin_providers = cls.list_builtin_providers() - + # get db builtin providers db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ filter(BuiltinToolProvider.tenant_id == tenant_id).all() @@ -428,7 +428,7 @@ class ToolManager: if 'api' in filters: db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ filter(ApiToolProvider.tenant_id == tenant_id).all() - + api_provider_controllers = [{ 'provider': provider, 'controller': ToolTransformService.api_provider_to_controller(provider) @@ -450,7 +450,7 @@ class ToolManager: # get workflow providers workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \ filter(WorkflowToolProvider.tenant_id == tenant_id).all() - + workflow_provider_controllers = [] for provider in workflow_providers: try: @@ -460,7 +460,7 @@ class ToolManager: except Exception as e: # app has been deleted pass - + labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers) for provider_controller in workflow_provider_controllers: diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 6526df6aa5..9e8ef47823 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -73,10 +73,8 @@ class ModelInvocationUtils: if not model_instance: raise InvokeModelError('Model not found') - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - # get tokens - tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages) + tokens = model_instance.get_llm_num_tokens(prompt_messages) return tokens @@ -108,13 +106,8 @@ class ModelInvocationUtils: tenant_id=tenant_id, model_type=ModelType.LLM, ) - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - - # get model credentials - model_credentials = model_instance.credentials - # get 
prompt tokens - prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) model_parameters = { 'temperature': 0.8, @@ -144,9 +137,7 @@ class ModelInvocationUtils: db.session.commit() try: - response: LLMResult = llm_model.invoke( - model=model_instance.model, - credentials=model_credentials, + response: LLMResult = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=model_parameters, tools=[], stop=[], stream=False, user=user_id, callbacks=[] @@ -176,4 +167,4 @@ class ModelInvocationUtils: db.session.commit() - return response \ No newline at end of file + return response diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 1f59242e98..06c8e41959 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -4,9 +4,9 @@ from typing import Optional, Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate @@ -200,12 +200,12 @@ class QuestionClassifierNode(LLMNode): model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: 
- model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model + ) - curr_message_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, + curr_message_tokens = model_instance.get_llm_num_tokens( prompt_messages ) diff --git a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py new file mode 100644 index 0000000000..67d7b9fbf5 --- /dev/null +++ b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py @@ -0,0 +1,126 @@ +"""add load balancing + +Revision ID: 4e99a8df00ff +Revises: 47cc7df8c4f3 +Create Date: 2024-05-10 12:08:09.812736 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '4e99a8df00ff' +down_revision = '64a70a7aab8b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('load_balancing_model_configs', + sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('model_name', sa.String(length=255), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey') + ) + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.create_index('load_balancing_model_config_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) + + op.create_table('provider_model_settings', + sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('provider_name', sa.String(length=255), nullable=False), + sa.Column('model_name', sa.String(length=255), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('load_balancing_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', 
name='provider_model_setting_pkey') + ) + with op.batch_alter_table('provider_model_settings', schema=None) as batch_op: + batch_op.create_index('provider_model_setting_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) + + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('provider_orders', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=255), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('provider_orders', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.alter_column('provider_name', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('provider_model_settings', schema=None) as batch_op: + batch_op.drop_index('provider_model_setting_tenant_provider_model_idx') + + op.drop_table('provider_model_settings') + with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: + batch_op.drop_index('load_balancing_model_config_tenant_provider_model_idx') + + op.drop_table('load_balancing_model_configs') + # ### end Alembic commands ### diff --git a/api/models/provider.py b/api/models/provider.py index eb6ec4beb4..4c14c33f09 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -47,7 +47,7 @@ class Provider(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) + provider_name = db.Column(db.String(255), nullable=False) provider_type = 
db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) encrypted_config = db.Column(db.Text, nullable=True) is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) @@ -94,7 +94,7 @@ class ProviderModel(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) + provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) @@ -112,7 +112,7 @@ class TenantDefaultModel(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) + provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -128,7 +128,7 @@ class TenantPreferredModelProvider(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) + provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -143,7 +143,7 @@ class ProviderOrder(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(40), nullable=False) 
+ provider_name = db.Column(db.String(255), nullable=False) account_id = db.Column(StringUUID, nullable=False) payment_product_id = db.Column(db.String(191), nullable=False) payment_id = db.Column(db.String(191)) @@ -157,3 +157,46 @@ class ProviderOrder(db.Model): refunded_at = db.Column(db.DateTime) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class ProviderModelSetting(db.Model): + """ + Provider model settings for record the model enabled status and load balancing status. + """ + __tablename__ = 'provider_model_settings' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'), + db.Index('provider_model_setting_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + ) + + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + model_name = db.Column(db.String(255), nullable=False) + model_type = db.Column(db.String(40), nullable=False) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class LoadBalancingModelConfig(db.Model): + """ + Configurations for load balancing models. 
+ """ + __tablename__ = 'load_balancing_model_configs' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'), + db.Index('load_balancing_model_config_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + ) + + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + provider_name = db.Column(db.String(255), nullable=False) + model_name = db.Column(db.String(255), nullable=False) + model_type = db.Column(db.String(40), nullable=False) + name = db.Column(db.String(255), nullable=False) + encrypted_config = db.Column(db.Text, nullable=True) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 267db740f8..06d3e9ec40 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -4,7 +4,7 @@ import logging import random import time import uuid -from typing import Optional, cast +from typing import Optional from flask import current_app from flask_login import current_user @@ -13,7 +13,6 @@ from sqlalchemy import func from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.models.document import Document as RAGDocument from events.dataset_event import dataset_was_deleted @@ -1144,10 +1143,7 @@ class SegmentService: model=dataset.embedding_model ) # calc embedding use tokens - model_type_instance = 
cast(TextEmbeddingModel, embedding_model.model_type_instance) - tokens = model_type_instance.get_num_tokens( - model=embedding_model.model, - credentials=embedding_model.credentials, + tokens = embedding_model.get_text_embedding_num_tokens( texts=[content] ) lock_name = 'add_segment_lock_document_id_{}'.format(document.id) @@ -1215,10 +1211,7 @@ class SegmentService: tokens = 0 if dataset.indexing_technique == 'high_quality' and embedding_model: # calc embedding use tokens - model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance) - tokens = model_type_instance.get_num_tokens( - model=embedding_model.model, - credentials=embedding_model.credentials, + tokens = embedding_model.get_text_embedding_num_tokens( texts=[content] ) segment_document = DocumentSegment( @@ -1321,10 +1314,7 @@ class SegmentService: ) # calc embedding use tokens - model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance) - tokens = model_type_instance.get_num_tokens( - model=embedding_model.model, - credentials=embedding_model.credentials, + tokens = embedding_model.get_text_embedding_num_tokens( texts=[content] ) segment.content = content diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 6cdd5090ae..77bb5e08c3 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -4,10 +4,10 @@ from typing import Optional from flask import current_app from pydantic import BaseModel -from core.entities.model_entities import ModelStatus, ModelWithProviderEntity +from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.entities.provider_entities import QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType, ProviderModel +from core.model_runtime.entities.model_entities import ModelType 
from core.model_runtime.entities.provider_entities import ( ConfigurateMethod, ModelCredentialSchema, @@ -79,13 +79,6 @@ class ProviderResponse(BaseModel): ) -class ModelResponse(ProviderModel): - """ - Model class for model response. - """ - status: ModelStatus - - class ProviderWithModelsResponse(BaseModel): """ Model class for provider with models response. @@ -95,7 +88,7 @@ class ProviderWithModelsResponse(BaseModel): icon_small: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None status: CustomConfigurationStatus - models: list[ModelResponse] + models: list[ProviderModelWithStatusEntity] def __init__(self, **data) -> None: super().__init__(**data) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 29842d68b7..36cbc3902b 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -29,6 +29,7 @@ class FeatureModel(BaseModel): documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) docs_processing: str = 'standard' can_replace_logo: bool = False + model_load_balancing_enabled: bool = False class SystemFeatureModel(BaseModel): @@ -63,6 +64,7 @@ class FeatureService: @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO'] + features.model_load_balancing_enabled = current_app.config['MODEL_LB_ENABLED'] @classmethod def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): @@ -72,23 +74,34 @@ class FeatureService: features.billing.subscription.plan = billing_info['subscription']['plan'] features.billing.subscription.interval = billing_info['subscription']['interval'] - features.members.size = billing_info['members']['size'] - features.members.limit = billing_info['members']['limit'] + if 'members' in billing_info: + features.members.size = billing_info['members']['size'] + features.members.limit = billing_info['members']['limit'] - features.apps.size = 
billing_info['apps']['size'] - features.apps.limit = billing_info['apps']['limit'] + if 'apps' in billing_info: + features.apps.size = billing_info['apps']['size'] + features.apps.limit = billing_info['apps']['limit'] - features.vector_space.size = billing_info['vector_space']['size'] - features.vector_space.limit = billing_info['vector_space']['limit'] + if 'vector_space' in billing_info: + features.vector_space.size = billing_info['vector_space']['size'] + features.vector_space.limit = billing_info['vector_space']['limit'] - features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] - features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] + if 'documents_upload_quota' in billing_info: + features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] + features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] - features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] - features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] + if 'annotation_quota_limit' in billing_info: + features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] + features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] - features.docs_processing = billing_info['docs_processing'] - features.can_replace_logo = billing_info['can_replace_logo'] + if 'docs_processing' in billing_info: + features.docs_processing = billing_info['docs_processing'] + + if 'can_replace_logo' in billing_info: + features.can_replace_logo = billing_info['can_replace_logo'] + + if 'model_load_balancing_enabled' in billing_info: + features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled'] @classmethod def _fulfill_params_from_enterprise(cls, features): diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py new file mode 
100644 index 0000000000..c684c2862b --- /dev/null +++ b/api/services/model_load_balancing_service.py @@ -0,0 +1,565 @@ +import datetime +import json +import logging +from json import JSONDecodeError +from typing import Optional + +from core.entities.provider_configuration import ProviderConfiguration +from core.helper import encrypter +from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from core.model_manager import LBModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from core.model_runtime.model_providers import model_provider_factory +from core.provider_manager import ProviderManager +from extensions.ext_database import db +from models.provider import LoadBalancingModelConfig + +logger = logging.getLogger(__name__) + + +class ModelLoadBalancingService: + + def __init__(self) -> None: + self.provider_manager = ProviderManager() + + def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + """ + enable model load balancing. + + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Enable model load balancing + provider_configuration.enable_model_load_balancing( + model=model, + model_type=ModelType.value_of(model_type) + ) + + def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + """ + disable model load balancing. 
+ + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # disable model load balancing + provider_configuration.disable_model_load_balancing( + model=model, + model_type=ModelType.value_of(model_type) + ) + + def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \ + -> tuple[bool, list[dict]]: + """ + Get load balancing configurations. + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Convert model type to ModelType + model_type = ModelType.value_of(model_type) + + # Get provider model setting + provider_model_setting = provider_configuration.get_provider_model_setting( + model_type=model_type, + model=model, + ) + + is_load_balancing_enabled = False + if provider_model_setting and provider_model_setting.load_balancing_enabled: + is_load_balancing_enabled = True + + # Get load balancing configurations + load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == 
model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model + ).order_by(LoadBalancingModelConfig.created_at).all() + + if provider_configuration.custom_configuration.provider: + # check if the inherit configuration exists, + # inherit is represented for the provider or model custom credentials + inherit_config_exists = False + for load_balancing_config in load_balancing_configs: + if load_balancing_config.name == '__inherit__': + inherit_config_exists = True + break + + if not inherit_config_exists: + # Initialize the inherit configuration + inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type) + + # prepend the inherit configuration + load_balancing_configs.insert(0, inherit_config) + else: + # move the inherit configuration to the first + for i, load_balancing_config in enumerate(load_balancing_configs): + if load_balancing_config.name == '__inherit__': + inherit_config = load_balancing_configs.pop(i) + load_balancing_configs.insert(0, inherit_config) + + # Get credential form schemas from model credential schema or provider credential schema + credential_schemas = self._get_credential_schema(provider_configuration) + + # Get decoding rsa key and cipher for decrypting credentials + decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + + # fetch status and ttl for each config + datas = [] + for load_balancing_config in load_balancing_configs: + in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl( + tenant_id=tenant_id, + provider=provider, + model=model, + model_type=model_type, + config_id=load_balancing_config.id + ) + + try: + if load_balancing_config.encrypted_config: + credentials = json.loads(load_balancing_config.encrypted_config) + else: + credentials = {} + except JSONDecodeError: + credentials = {} + + # Get provider credential secret variables + credential_secret_variables = provider_configuration.extract_secret_variables( + 
credential_schemas.credential_form_schemas + ) + + # decrypt credentials + for variable in credential_secret_variables: + if variable in credentials: + try: + credentials[variable] = encrypter.decrypt_token_with_decoding( + credentials.get(variable), + decoding_rsa_key, + decoding_cipher_rsa + ) + except ValueError: + pass + + # Obfuscate credentials + credentials = provider_configuration.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=credential_schemas.credential_form_schemas + ) + + datas.append({ + 'id': load_balancing_config.id, + 'name': load_balancing_config.name, + 'credentials': credentials, + 'enabled': load_balancing_config.enabled, + 'in_cooldown': in_cooldown, + 'ttl': ttl + }) + + return is_load_balancing_enabled, datas + + def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \ + -> Optional[dict]: + """ + Get load balancing configuration. + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :param config_id: load balancing config id + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Convert model type to ModelType + model_type = ModelType.value_of(model_type) + + # Get load balancing configurations + load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id + ).first() 
+ + if not load_balancing_model_config: + return None + + try: + if load_balancing_model_config.encrypted_config: + credentials = json.loads(load_balancing_model_config.encrypted_config) + else: + credentials = {} + except JSONDecodeError: + credentials = {} + + # Get credential form schemas from model credential schema or provider credential schema + credential_schemas = self._get_credential_schema(provider_configuration) + + # Obfuscate credentials + credentials = provider_configuration.obfuscated_credentials( + credentials=credentials, + credential_form_schemas=credential_schemas.credential_form_schemas + ) + + return { + 'id': load_balancing_model_config.id, + 'name': load_balancing_model_config.name, + 'credentials': credentials, + 'enabled': load_balancing_model_config.enabled + } + + def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \ + -> LoadBalancingModelConfig: + """ + Initialize the inherit configuration. + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Initialize the inherit configuration + inherit_config = LoadBalancingModelConfig( + tenant_id=tenant_id, + provider_name=provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + name='__inherit__' + ) + db.session.add(inherit_config) + db.session.commit() + + return inherit_config + + def update_load_balancing_configs(self, tenant_id: str, + provider: str, + model: str, + model_type: str, + configs: list[dict]) -> None: + """ + Update load balancing configurations. 
+ :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :param configs: load balancing configs + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Convert model type to ModelType + model_type = ModelType.value_of(model_type) + + if not isinstance(configs, list): + raise ValueError('Invalid load balancing configs') + + current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model + ).all() + + # id as key, config as value + current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} + updated_config_ids = set() + + for config in configs: + if not isinstance(config, dict): + raise ValueError('Invalid load balancing config') + + config_id = config.get('id') + name = config.get('name') + credentials = config.get('credentials') + enabled = config.get('enabled') + + if not name: + raise ValueError('Invalid load balancing config name') + + if enabled is None: + raise ValueError('Invalid load balancing config enabled') + + # is config exists + if config_id: + config_id = str(config_id) + + if config_id not in current_load_balancing_configs_dict: + raise ValueError('Invalid load balancing config id: {}'.format(config_id)) + + updated_config_ids.add(config_id) + + load_balancing_config = current_load_balancing_configs_dict[config_id] + + # check duplicate name + for 
current_load_balancing_config in current_load_balancing_configs: + if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: + raise ValueError('Load balancing config name {} already exists'.format(name)) + + if credentials: + if not isinstance(credentials, dict): + raise ValueError('Invalid load balancing config credentials') + + # validate custom provider config + credentials = self._custom_credentials_validate( + tenant_id=tenant_id, + provider_configuration=provider_configuration, + model_type=model_type, + model=model, + credentials=credentials, + load_balancing_model_config=load_balancing_config, + validate=False + ) + + # update load balancing config + load_balancing_config.encrypted_config = json.dumps(credentials) + + load_balancing_config.name = name + load_balancing_config.enabled = enabled + load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.commit() + + self._clear_credentials_cache(tenant_id, config_id) + else: + # create load balancing config + if name == '__inherit__': + raise ValueError('Invalid load balancing config name') + + # check duplicate name + for current_load_balancing_config in current_load_balancing_configs: + if current_load_balancing_config.name == name: + raise ValueError('Load balancing config name {} already exists'.format(name)) + + if not credentials: + raise ValueError('Invalid load balancing config credentials') + + if not isinstance(credentials, dict): + raise ValueError('Invalid load balancing config credentials') + + # validate custom provider config + credentials = self._custom_credentials_validate( + tenant_id=tenant_id, + provider_configuration=provider_configuration, + model_type=model_type, + model=model, + credentials=credentials, + validate=False + ) + + # create load balancing config + load_balancing_model_config = LoadBalancingModelConfig( + tenant_id=tenant_id, + 
provider_name=provider_configuration.provider.provider, + model_type=model_type.to_origin_model_type(), + model_name=model, + name=name, + encrypted_config=json.dumps(credentials) + ) + + db.session.add(load_balancing_model_config) + db.session.commit() + + # get deleted config ids + deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids + for config_id in deleted_config_ids: + db.session.delete(current_load_balancing_configs_dict[config_id]) + db.session.commit() + + self._clear_credentials_cache(tenant_id, config_id) + + def validate_load_balancing_credentials(self, tenant_id: str, + provider: str, + model: str, + model_type: str, + credentials: dict, + config_id: Optional[str] = None) -> None: + """ + Validate load balancing credentials. + :param tenant_id: workspace id + :param provider: provider name + :param model_type: model type + :param model: model name + :param credentials: credentials + :param config_id: load balancing config id + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Convert model type to ModelType + model_type = ModelType.value_of(model_type) + + load_balancing_model_config = None + if config_id: + # Get load balancing config + load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + .filter( + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id + ).first() + + if not load_balancing_model_config: + raise ValueError(f"Load balancing config {config_id} does not exist.") + + # 
Validate custom provider config + self._custom_credentials_validate( + tenant_id=tenant_id, + provider_configuration=provider_configuration, + model_type=model_type, + model=model, + credentials=credentials, + load_balancing_model_config=load_balancing_model_config + ) + + def _custom_credentials_validate(self, tenant_id: str, + provider_configuration: ProviderConfiguration, + model_type: ModelType, + model: str, + credentials: dict, + load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, + validate: bool = True) -> dict: + """ + Validate custom credentials. + :param tenant_id: workspace id + :param provider_configuration: provider configuration + :param model_type: model type + :param model: model name + :param credentials: credentials + :param load_balancing_model_config: load balancing model config + :param validate: validate credentials + :return: + """ + # Get credential form schemas from model credential schema or provider credential schema + credential_schemas = self._get_credential_schema(provider_configuration) + + # Get provider credential secret variables + provider_credential_secret_variables = provider_configuration.extract_secret_variables( + credential_schemas.credential_form_schemas + ) + + if load_balancing_model_config: + try: + # fix origin data + if load_balancing_model_config.encrypted_config: + original_credentials = json.loads(load_balancing_model_config.encrypted_config) + else: + original_credentials = {} + except JSONDecodeError: + original_credentials = {} + + # encrypt credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == '[__HIDDEN__]' and key in original_credentials: + credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) + + if validate: + if isinstance(credential_schemas, ModelCredentialSchema): + credentials = 
model_provider_factory.model_credentials_validate( + provider=provider_configuration.provider.provider, + model_type=model_type, + model=model, + credentials=credentials + ) + else: + credentials = model_provider_factory.provider_credentials_validate( + provider=provider_configuration.provider.provider, + credentials=credentials + ) + + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + return credentials + + def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \ + -> ModelCredentialSchema | ProviderCredentialSchema: + """ + Get form schemas. + :param provider_configuration: provider configuration + :return: + """ + # Get credential form schemas from model credential schema or provider credential schema + if provider_configuration.provider.model_credential_schema: + credential_schema = provider_configuration.provider.model_credential_schema + else: + credential_schema = provider_configuration.provider.provider_credential_schema + + return credential_schema + + def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None: + """ + Clear credentials cache. 
+ :param tenant_id: workspace id + :param config_id: load balancing config id + :return: + """ + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=tenant_id, + identity_id=config_id, + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + ) + + provider_model_credentials_cache.delete() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 5a4342ae03..385af685f9 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -6,7 +6,7 @@ from typing import Optional, cast import requests from flask import current_app -from core.entities.model_entities import ModelStatus +from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -16,7 +16,6 @@ from services.entities.model_provider_entities import ( CustomConfigurationResponse, CustomConfigurationStatus, DefaultModelResponse, - ModelResponse, ModelWithProviderEntityResponse, ProviderResponse, ProviderWithModelsResponse, @@ -303,6 +302,9 @@ class ModelProviderService: if model.deprecated: continue + if model.status != ModelStatus.ACTIVE: + continue + provider_models[model.provider.provider].append(model) # convert to ProviderWithModelsResponse list @@ -313,24 +315,22 @@ class ModelProviderService: first_model = models[0] - has_active_models = any([model.status == ModelStatus.ACTIVE for model in models]) - providers_with_models.append( ProviderWithModelsResponse( provider=provider, label=first_model.provider.label, icon_small=first_model.provider.icon_small, icon_large=first_model.provider.icon_large, - status=CustomConfigurationStatus.ACTIVE - if has_active_models else CustomConfigurationStatus.NO_CONFIGURE, - models=[ModelResponse( + 
status=CustomConfigurationStatus.ACTIVE, + models=[ProviderModelWithStatusEntity( model=model.model, label=model.label, model_type=model.model_type, features=model.features, fetch_from=model.fetch_from, model_properties=model.model_properties, - status=model.status + status=model.status, + load_balancing_enabled=model.load_balancing_enabled ) for model in models] ) ) @@ -486,6 +486,54 @@ class ModelProviderService: # Switch preferred provider type provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) + def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + """ + enable model. + + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Enable model + provider_configuration.enable_model( + model=model, + model_type=ModelType.value_of(model_type) + ) + + def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: + """ + disable model. 
+ + :param tenant_id: workspace id + :param provider: provider name + :param model: model name + :param model_type: model type + :return: + """ + # Get all provider configurations of the current workspace + provider_configurations = self.provider_manager.get_configurations(tenant_id) + + # Get provider configuration + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + # Disable model + provider_configuration.disable_model( + model=model, + model_type=ModelType.value_of(model_type) + ) + def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 456ab0dcb0..6235ecf0a3 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -68,7 +68,7 @@ class WorkflowService: account: Account) -> Workflow: """ Sync draft workflow - @throws WorkflowHashNotEqualError + :raises WorkflowHashNotEqualError """ # fetch draft workflow by app_model workflow = self.get_draft_workflow(app_model=app_model) diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index d6dc970477..67cc03bdeb 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -2,7 +2,6 @@ import datetime import logging import time import uuid -from typing import cast import click from celery import shared_task @@ -11,7 +10,6 @@ from sqlalchemy import func from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db from extensions.ext_redis import 
redis_client from libs import helper @@ -59,16 +57,12 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s model=dataset.embedding_model ) - model_type_instance = embedding_model.model_type_instance - model_type_instance = cast(TextEmbeddingModel, model_type_instance) for segment in content: content = segment['content'] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) # calc embedding use tokens - tokens = model_type_instance.get_num_tokens( - model=embedding_model.model, - credentials=embedding_model.credentials, + tokens = embedding_model.get_text_embedding_num_tokens( texts=[content] ) if embedding_model else 0 max_position = db.session.query(func.max(DocumentSegment.position)).filter( diff --git a/api/tests/integration_tests/utils/test_module_import_helper.py b/api/tests/integration_tests/utils/test_module_import_helper.py index 39ac41b648..256c9a911f 100644 --- a/api/tests/integration_tests/utils/test_module_import_helper.py +++ b/api/tests/integration_tests/utils/test_module_import_helper.py @@ -1,6 +1,6 @@ import os -from core.utils.module_import_helper import import_module_from_source, load_single_subclass_from_source +from core.helper.module_import_helper import import_module_from_source, load_single_subclass_from_source from tests.integration_tests.utils.parent_class import ParentClass diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 394a3dcbd7..a150be3c00 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -92,7 +92,8 @@ def test_execute_llm(setup_openai_mock): provider=CustomProviderConfiguration( credentials=credentials ) - ) + ), + model_settings=[] ), provider_instance=provider_instance, model_type_instance=model_type_instance @@ -206,10 +207,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): 
provider=CustomProviderConfiguration( credentials=credentials ) - ) + ), + model_settings=[] ), provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 342f371eea..056c78441d 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -42,7 +42,8 @@ def get_mocked_fetch_model_config( provider=CustomProviderConfiguration( credentials=credentials ) - ) + ), + model_settings=[] ), provider_instance=provider_instance, model_type_instance=model_type_instance diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 40f5be8af9..2bcc6f4292 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -1,9 +1,10 @@ from unittest.mock import MagicMock from core.app.app_config.entities import ModelConfigEntity -from core.entities.provider_configuration import ProviderModelBundle +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.model_runtime.entities.message_entities import UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule +from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_transform import PromptTransform @@ -22,8 +23,16 @@ def test__calculate_rest_token(): large_language_model_mock = MagicMock(spec=LargeLanguageModel) 
large_language_model_mock.get_num_tokens.return_value = 6 + provider_mock = MagicMock(spec=ProviderEntity) + provider_mock.provider = 'openai' + + provider_configuration_mock = MagicMock(spec=ProviderConfiguration) + provider_configuration_mock.provider = provider_mock + provider_configuration_mock.model_settings = None + provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) provider_model_bundle_mock.model_type_instance = large_language_model_mock + provider_model_bundle_mock.configuration = provider_configuration_mock model_config_mock = MagicMock(spec=ModelConfigEntity) model_config_mock.model = 'gpt-4' diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py new file mode 100644 index 0000000000..3024a54a4d --- /dev/null +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -0,0 +1,77 @@ +from unittest.mock import MagicMock + +import pytest + +from core.entities.provider_entities import ModelLoadBalancingConfiguration +from core.model_manager import LBModelManager +from core.model_runtime.entities.model_entities import ModelType + + +@pytest.fixture +def lb_model_manager(): + load_balancing_configs = [ + ModelLoadBalancingConfiguration( + id='id1', + name='__inherit__', + credentials={} + ), + ModelLoadBalancingConfiguration( + id='id2', + name='first', + credentials={"openai_api_key": "fake_key"} + ), + ModelLoadBalancingConfiguration( + id='id3', + name='second', + credentials={"openai_api_key": "fake_key"} + ) + ] + + lb_model_manager = LBModelManager( + tenant_id='tenant_id', + provider='openai', + model_type=ModelType.LLM, + model='gpt-4', + load_balancing_configs=load_balancing_configs, + managed_credentials={"openai_api_key": "fake_key"} + ) + + lb_model_manager.cooldown = MagicMock(return_value=None) + + def is_cooldown(config: ModelLoadBalancingConfiguration): + if config.id == 'id1': + return True + + return False + + lb_model_manager.in_cooldown = MagicMock(side_effect=is_cooldown) + 
+ return lb_model_manager + + +def test_lb_model_manager_fetch_next(mocker, lb_model_manager): + assert len(lb_model_manager._load_balancing_configs) == 3 + + config1 = lb_model_manager._load_balancing_configs[0] + config2 = lb_model_manager._load_balancing_configs[1] + config3 = lb_model_manager._load_balancing_configs[2] + + assert lb_model_manager.in_cooldown(config1) is True + assert lb_model_manager.in_cooldown(config2) is False + assert lb_model_manager.in_cooldown(config3) is False + + start_index = 0 + def incr(key): + nonlocal start_index + start_index += 1 + return start_index + + mocker.patch('redis.Redis.incr', side_effect=incr) + mocker.patch('redis.Redis.set', return_value=None) + mocker.patch('redis.Redis.expire', return_value=None) + + config = lb_model_manager.fetch_next() + assert config == config2 + + config = lb_model_manager.fetch_next() + assert config == config3 diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py new file mode 100644 index 0000000000..072b6f100f --- /dev/null +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -0,0 +1,183 @@ +from core.entities.provider_entities import ModelSettings +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers import model_provider_factory +from core.provider_manager import ProviderManager +from models.provider import LoadBalancingModelConfig, ProviderModelSetting + + +def test__to_model_settings(mocker): + # Get all provider entities + provider_entities = model_provider_factory.get_providers() + + provider_entity = None + for provider in provider_entities: + if provider.provider == 'openai': + provider_entity = provider + + # Mocking the inputs + provider_model_settings = [ProviderModelSetting( + id='id', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + enabled=True, + load_balancing_enabled=True + )] + 
load_balancing_model_configs = [ + LoadBalancingModelConfig( + id='id1', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + name='__inherit__', + encrypted_config=None, + enabled=True + ), + LoadBalancingModelConfig( + id='id2', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + name='first', + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True + ) + ] + + mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings( + provider_entity, + provider_model_settings, + load_balancing_model_configs + ) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == 'gpt-4' + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 2 + assert result[0].load_balancing_configs[0].name == '__inherit__' + assert result[0].load_balancing_configs[1].name == 'first' + + +def test__to_model_settings_only_one_lb(mocker): + # Get all provider entities + provider_entities = model_provider_factory.get_providers() + + provider_entity = None + for provider in provider_entities: + if provider.provider == 'openai': + provider_entity = provider + + # Mocking the inputs + provider_model_settings = [ProviderModelSetting( + id='id', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + enabled=True, + load_balancing_enabled=True + )] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id='id1', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + name='__inherit__', + encrypted_config=None, + enabled=True + ) + ] + + 
mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings( + provider_entity, + provider_model_settings, + load_balancing_model_configs + ) + + # Asserting that the result is as expected + assert len(result) == 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == 'gpt-4' + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 + + +def test__to_model_settings_lb_disabled(mocker): + # Get all provider entities + provider_entities = model_provider_factory.get_providers() + + provider_entity = None + for provider in provider_entities: + if provider.provider == 'openai': + provider_entity = provider + + # Mocking the inputs + provider_model_settings = [ProviderModelSetting( + id='id', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + enabled=True, + load_balancing_enabled=False + )] + load_balancing_model_configs = [ + LoadBalancingModelConfig( + id='id1', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + name='__inherit__', + encrypted_config=None, + enabled=True + ), + LoadBalancingModelConfig( + id='id2', + tenant_id='tenant_id', + provider_name='openai', + model_name='gpt-4', + model_type='text-generation', + name='first', + encrypted_config='{"openai_api_key": "fake_key"}', + enabled=True + ) + ] + + mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"}) + + provider_manager = ProviderManager() + + # Running the method + result = provider_manager._to_model_settings( + provider_entity, + provider_model_settings, + load_balancing_model_configs + ) + + # Asserting that the result is as expected + assert len(result) 
== 1 + assert isinstance(result[0], ModelSettings) + assert result[0].model == 'gpt-4' + assert result[0].model_type == ModelType.LLM + assert result[0].enabled is True + assert len(result[0].load_balancing_configs) == 0 diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index b7442d0d93..c389461454 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -2,7 +2,7 @@ from textwrap import dedent import pytest -from core.utils.position_helper import get_position_map +from core.helper.position_helper import get_position_map @pytest.fixture