feat: backend model load balancing support (#4927)

This commit is contained in:
takatost 2024-06-05 00:13:04 +08:00 committed by GitHub
parent 52ec152dd3
commit d1dbbc1e33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
47 changed files with 2191 additions and 256 deletions

View File

@ -70,6 +70,7 @@ DEFAULTS = {
'INVITE_EXPIRY_HOURS': 72, 'INVITE_EXPIRY_HOURS': 72,
'BILLING_ENABLED': 'False', 'BILLING_ENABLED': 'False',
'CAN_REPLACE_LOGO': 'False', 'CAN_REPLACE_LOGO': 'False',
'MODEL_LB_ENABLED': 'False',
'ETL_TYPE': 'dify', 'ETL_TYPE': 'dify',
'KEYWORD_STORE': 'jieba', 'KEYWORD_STORE': 'jieba',
'BATCH_UPLOAD_LIMIT': 20, 'BATCH_UPLOAD_LIMIT': 20,
@ -123,6 +124,7 @@ class Config:
self.LOG_FILE = get_env('LOG_FILE') self.LOG_FILE = get_env('LOG_FILE')
self.LOG_FORMAT = get_env('LOG_FORMAT') self.LOG_FORMAT = get_env('LOG_FORMAT')
self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT') self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT')
self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
# The backend URL prefix of the console API. # The backend URL prefix of the console API.
# used to concatenate the login authorization callback or notion integration callback. # used to concatenate the login authorization callback or notion integration callback.
@ -210,27 +212,41 @@ class Config:
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://') self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
# ------------------------
# Code Execution Sandbox Configurations.
# ------------------------
self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
# ------------------------ # ------------------------
# File Storage Configurations. # File Storage Configurations.
# ------------------------ # ------------------------
self.STORAGE_TYPE = get_env('STORAGE_TYPE') self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH') self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
# S3 Storage settings
self.S3_ENDPOINT = get_env('S3_ENDPOINT') self.S3_ENDPOINT = get_env('S3_ENDPOINT')
self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME') self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY') self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY') self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
self.S3_REGION = get_env('S3_REGION') self.S3_REGION = get_env('S3_REGION')
self.S3_ADDRESS_STYLE = get_env('S3_ADDRESS_STYLE') self.S3_ADDRESS_STYLE = get_env('S3_ADDRESS_STYLE')
# Azure Blob Storage settings
self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME') self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME')
self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY') self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME') self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL') self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')
# Aliyun Storage settings
self.ALIYUN_OSS_BUCKET_NAME = get_env('ALIYUN_OSS_BUCKET_NAME') self.ALIYUN_OSS_BUCKET_NAME = get_env('ALIYUN_OSS_BUCKET_NAME')
self.ALIYUN_OSS_ACCESS_KEY = get_env('ALIYUN_OSS_ACCESS_KEY') self.ALIYUN_OSS_ACCESS_KEY = get_env('ALIYUN_OSS_ACCESS_KEY')
self.ALIYUN_OSS_SECRET_KEY = get_env('ALIYUN_OSS_SECRET_KEY') self.ALIYUN_OSS_SECRET_KEY = get_env('ALIYUN_OSS_SECRET_KEY')
self.ALIYUN_OSS_ENDPOINT = get_env('ALIYUN_OSS_ENDPOINT') self.ALIYUN_OSS_ENDPOINT = get_env('ALIYUN_OSS_ENDPOINT')
self.ALIYUN_OSS_REGION = get_env('ALIYUN_OSS_REGION') self.ALIYUN_OSS_REGION = get_env('ALIYUN_OSS_REGION')
self.ALIYUN_OSS_AUTH_VERSION = get_env('ALIYUN_OSS_AUTH_VERSION') self.ALIYUN_OSS_AUTH_VERSION = get_env('ALIYUN_OSS_AUTH_VERSION')
# Google Cloud Storage settings
self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME') self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME')
self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64') self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')
@ -240,6 +256,7 @@ class Config:
# ------------------------ # ------------------------
self.VECTOR_STORE = get_env('VECTOR_STORE') self.VECTOR_STORE = get_env('VECTOR_STORE')
self.KEYWORD_STORE = get_env('KEYWORD_STORE') self.KEYWORD_STORE = get_env('KEYWORD_STORE')
# qdrant settings # qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL') self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
@ -323,6 +340,19 @@ class Config:
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT')) self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT')) self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT')) self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
# RAG ETL Configurations.
self.ETL_TYPE = get_env('ETL_TYPE')
self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY')
self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
# Indexing Configurations.
self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH')
# Tool Configurations.
self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
self.WORKFLOW_MAX_EXECUTION_STEPS = int(get_env('WORKFLOW_MAX_EXECUTION_STEPS')) self.WORKFLOW_MAX_EXECUTION_STEPS = int(get_env('WORKFLOW_MAX_EXECUTION_STEPS'))
self.WORKFLOW_MAX_EXECUTION_TIME = int(get_env('WORKFLOW_MAX_EXECUTION_TIME')) self.WORKFLOW_MAX_EXECUTION_TIME = int(get_env('WORKFLOW_MAX_EXECUTION_TIME'))
@ -378,24 +408,15 @@ class Config:
self.HOSTED_FETCH_APP_TEMPLATES_MODE = get_env('HOSTED_FETCH_APP_TEMPLATES_MODE') self.HOSTED_FETCH_APP_TEMPLATES_MODE = get_env('HOSTED_FETCH_APP_TEMPLATES_MODE')
self.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = get_env('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN') self.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = get_env('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN')
self.ETL_TYPE = get_env('ETL_TYPE') # Model Load Balancing Configurations.
self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL') self.MODEL_LB_ENABLED = get_bool_env('MODEL_LB_ENABLED')
self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY')
# Platform Billing Configurations.
self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED') self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT') # ------------------------
# Enterprise feature Configurations.
self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT') # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY') # ------------------------
self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED') self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
# ------------------------
# Indexing Configurations.
# ------------------------
self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH')

View File

@ -54,4 +54,4 @@ from .explore import (
from .tag import tags from .tag import tags
# Import workspace controllers # Import workspace controllers
from .workspace import account, members, model_providers, models, tool_providers, workspace from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace

View File

@ -1,14 +1,19 @@
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource from flask_restful import Resource
from libs.login import login_required
from services.feature_service import FeatureService from services.feature_service import FeatureService
from . import api from . import api
from .wraps import cloud_utm_record from .setup import setup_required
from .wraps import account_initialization_required, cloud_utm_record
class FeatureApi(Resource): class FeatureApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_utm_record @cloud_utm_record
def get(self): def get(self):
return FeatureService.get_features(current_user.current_tenant_id).dict() return FeatureService.get_features(current_user.current_tenant_id).dict()

View File

@ -17,13 +17,19 @@ class VersionApi(Resource):
args = parser.parse_args() args = parser.parse_args()
check_update_url = current_app.config['CHECK_UPDATE_URL'] check_update_url = current_app.config['CHECK_UPDATE_URL']
if not check_update_url: result = {
return { 'version': current_app.config['CURRENT_VERSION'],
'version': '0.0.0',
'release_date': '', 'release_date': '',
'release_notes': '', 'release_notes': '',
'can_auto_update': False 'can_auto_update': False,
'features': {
'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'],
'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED']
} }
}
if not check_update_url:
return result
try: try:
response = requests.get(check_update_url, { response = requests.get(check_update_url, {
@ -31,20 +37,15 @@ class VersionApi(Resource):
}) })
except Exception as error: except Exception as error:
logging.warning("Check update version error: {}.".format(str(error))) logging.warning("Check update version error: {}.".format(str(error)))
return { result['version'] = args.get('current_version')
'version': args.get('current_version'), return result
'release_date': '',
'release_notes': '',
'can_auto_update': False
}
content = json.loads(response.content) content = json.loads(response.content)
return { result['version'] = content['version']
'version': content['version'], result['release_date'] = content['releaseDate']
'release_date': content['releaseDate'], result['release_notes'] = content['releaseNotes']
'release_notes': content['releaseNotes'], result['can_auto_update'] = content['canAutoUpdate']
'can_auto_update': content['canAutoUpdate'] return result
}
api.add_resource(VersionApi, '/version') api.add_resource(VersionApi, '/version')

View File

@ -0,0 +1,106 @@
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required
from models.account import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
class LoadBalancingCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
# validate model load balancing credentials
model_load_balancing_service = ModelLoadBalancingService()
result = True
error = None
try:
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
# validate model load balancing config credentials
model_load_balancing_service = ModelLoadBalancingService()
result = True
error = None
try:
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials'],
config_id=config_id,
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
# Load Balancing Config
api.add_resource(LoadBalancingCredentialsValidateApi,
'/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate')
api.add_resource(LoadBalancingConfigCredentialsValidateApi,
'/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate')

View File

@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required from libs.login import login_required
from models.account import TenantAccountRole from models.account import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
@ -104,9 +105,44 @@ class ModelProviderModelApi(Resource):
parser.add_argument('model', type=str, required=True, nullable=False, location='json') parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False, parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json') choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json')
parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json')
parser.add_argument('config_from', type=str, required=False, nullable=True, location='json')
args = parser.parse_args() args = parser.parse_args()
model_load_balancing_service = ModelLoadBalancingService()
if ('load_balancing' in args and args['load_balancing'] and
'enabled' in args['load_balancing'] and args['load_balancing']['enabled']):
if 'configs' not in args['load_balancing']:
raise ValueError('invalid load balancing configs')
# save load balancing configs
model_load_balancing_service.update_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
configs=args['load_balancing']['configs']
)
# enable load balancing
model_load_balancing_service.enable_model_load_balancing(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
else:
# disable load balancing
model_load_balancing_service.disable_model_load_balancing(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
if args.get('config_from', '') != 'predefined-model':
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
try: try:
@ -170,9 +206,71 @@ class ModelProviderModelCredentialApi(Resource):
model=args['model'] model=args['model']
) )
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
return { return {
"credentials": credentials "credentials": credentials,
"load_balancing": {
"enabled": is_load_balancing_enabled,
"configs": load_balancing_configs
} }
}
class ModelProviderModelEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.enable_model(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
return {'result': 'success'}
class ModelProviderModelDisableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.disable_model(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
return {'result': 'success'}
class ModelProviderModelValidateApi(Resource): class ModelProviderModelValidateApi(Resource):
@ -259,6 +357,10 @@ class ModelProviderAvailableModelApi(Resource):
api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models') api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable',
endpoint='model-provider-model-enable')
api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable',
endpoint='model-provider-model-disable')
api.add_resource(ModelProviderModelCredentialApi, api.add_resource(ModelProviderModelCredentialApi,
'/workspaces/current/model-providers/<string:provider>/models/credentials') '/workspaces/current/model-providers/<string:provider>/models/credentials')
api.add_resource(ModelProviderModelValidateApi, api.add_resource(ModelProviderModelValidateApi,

View File

@ -1,6 +1,6 @@
import time import time
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -16,11 +16,11 @@ from core.app.features.hosting_moderation.hosting_moderation import HostingModer
from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.input_moderation import InputModeration from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@ -45,8 +45,11 @@ class AppRunner:
:param query: query :param query: query
:return: :return:
""" """
model_type_instance = model_config.provider_model_bundle.model_type_instance # Invoke model
model_type_instance = cast(LargeLanguageModel, model_type_instance) model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
@ -73,9 +76,7 @@ class AppRunner:
query=query query=query
) )
prompt_tokens = model_type_instance.get_num_tokens( prompt_tokens = model_instance.get_llm_num_tokens(
model_config.model,
model_config.credentials,
prompt_messages prompt_messages
) )
@ -89,8 +90,10 @@ class AppRunner:
def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage]): prompt_messages: list[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance model_instance = ModelInstance(
model_type_instance = cast(LargeLanguageModel, model_type_instance) provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
@ -107,9 +110,7 @@ class AppRunner:
if max_tokens is None: if max_tokens is None:
max_tokens = 0 max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens( prompt_tokens = model_instance.get_llm_num_tokens(
model_config.model,
model_config.credentials,
prompt_messages prompt_messages
) )

View File

@ -37,6 +37,7 @@ from core.app.entities.task_entities import (
) )
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
@ -317,29 +318,30 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
""" """
model_config = self._model_config model_config = self._model_config
model = model_config.model model = model_config.model
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance) model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
)
# calculate num tokens # calculate num tokens
prompt_tokens = 0 prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_type_instance.get_num_tokens( prompt_tokens = model_instance.get_llm_num_tokens(
model,
model_config.credentials,
self._task_state.llm_result.prompt_messages self._task_state.llm_result.prompt_messages
) )
completion_tokens = 0 completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_type_instance.get_num_tokens( completion_tokens = model_instance.get_llm_num_tokens(
model,
model_config.credentials,
[self._task_state.llm_result.message] [self._task_state.llm_result.message]
) )
credentials = model_config.credentials credentials = model_config.credentials
# transform usage # transform usage
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage( self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model, model,
credentials, credentials,

View File

@ -16,6 +16,7 @@ class ModelStatus(Enum):
NO_CONFIGURE = "no-configure" NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded" QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission" NO_PERMISSION = "no-permission"
DISABLED = "disabled"
class SimpleModelProviderEntity(BaseModel): class SimpleModelProviderEntity(BaseModel):
@ -43,12 +44,19 @@ class SimpleModelProviderEntity(BaseModel):
) )
class ModelWithProviderEntity(ProviderModel): class ProviderModelWithStatusEntity(ProviderModel):
"""
Model class for model response.
"""
status: ModelStatus
load_balancing_enabled: bool = False
class ModelWithProviderEntity(ProviderModelWithStatusEntity):
""" """
Model with provider entity. Model with provider entity.
""" """
provider: SimpleModelProviderEntity provider: SimpleModelProviderEntity
status: ModelStatus
class DefaultModelProviderEntity(BaseModel): class DefaultModelProviderEntity(BaseModel):

View File

@ -1,6 +1,7 @@
import datetime import datetime
import json import json
import logging import logging
from collections import defaultdict
from collections.abc import Iterator from collections.abc import Iterator
from json import JSONDecodeError from json import JSONDecodeError
from typing import Optional from typing import Optional
@ -8,7 +9,12 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus from core.entities.provider_entities import (
CustomConfiguration,
ModelSettings,
SystemConfiguration,
SystemConfigurationStatus,
)
from core.helper import encrypter from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType from core.model_runtime.entities.model_entities import FetchFrom, ModelType
@ -22,7 +28,14 @@ from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from extensions.ext_database import db from extensions.ext_database import db
from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider from models.provider import (
LoadBalancingModelConfig,
Provider,
ProviderModel,
ProviderModelSetting,
ProviderType,
TenantPreferredModelProvider,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,6 +52,7 @@ class ProviderConfiguration(BaseModel):
using_provider_type: ProviderType using_provider_type: ProviderType
system_configuration: SystemConfiguration system_configuration: SystemConfiguration
custom_configuration: CustomConfiguration custom_configuration: CustomConfiguration
model_settings: list[ModelSettings]
def __init__(self, **data): def __init__(self, **data):
super().__init__(**data) super().__init__(**data)
@ -62,6 +76,14 @@ class ProviderConfiguration(BaseModel):
:param model: model name :param model: model name
:return: :return:
""" """
if self.model_settings:
# check if model is disabled by admin
for model_setting in self.model_settings:
if (model_setting.model_type == model_type
and model_setting.model == model):
if not model_setting.enabled:
raise ValueError(f'Model {model} is disabled.')
if self.using_provider_type == ProviderType.SYSTEM: if self.using_provider_type == ProviderType.SYSTEM:
restrict_models = [] restrict_models = []
for quota_configuration in self.system_configuration.quota_configurations: for quota_configuration in self.system_configuration.quota_configurations:
@ -80,15 +102,17 @@ class ProviderConfiguration(BaseModel):
return copy_credentials return copy_credentials
else: else:
credentials = None
if self.custom_configuration.models: if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models: for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model: if model_configuration.model_type == model_type and model_configuration.model == model:
return model_configuration.credentials credentials = model_configuration.credentials
break
if self.custom_configuration.provider: if self.custom_configuration.provider:
return self.custom_configuration.provider.credentials credentials = self.custom_configuration.provider.credentials
else:
return None return credentials
def get_system_configuration_status(self) -> SystemConfigurationStatus: def get_system_configuration_status(self) -> SystemConfigurationStatus:
""" """
@ -130,7 +154,7 @@ class ProviderConfiguration(BaseModel):
return credentials return credentials
# Obfuscate credentials # Obfuscate credentials
return self._obfuscated_credentials( return self.obfuscated_credentials(
credentials=credentials, credentials=credentials,
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else [] if self.provider.provider_credential_schema else []
@ -151,7 +175,7 @@ class ProviderConfiguration(BaseModel):
).first() ).first()
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else [] if self.provider.provider_credential_schema else []
) )
@ -274,7 +298,7 @@ class ProviderConfiguration(BaseModel):
return credentials return credentials
# Obfuscate credentials # Obfuscate credentials
return self._obfuscated_credentials( return self.obfuscated_credentials(
credentials=credentials, credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else [] if self.provider.model_credential_schema else []
@ -302,7 +326,7 @@ class ProviderConfiguration(BaseModel):
).first() ).first()
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else [] if self.provider.model_credential_schema else []
) )
@ -402,6 +426,160 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache.delete() provider_model_credentials_cache.delete()
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
    """
    Enable model.

    Flips the tenant-scoped setting row for this model to enabled,
    creating the row on first use.

    :param model_type: model type
    :param model: model name
    :return: the persisted provider model setting
    """
    # Reuse the shared lookup so the filter criteria live in one place
    # (was a duplicated inline query identical to get_provider_model_setting).
    model_setting = self.get_provider_model_setting(
        model_type=model_type,
        model=model
    )

    if model_setting:
        model_setting.enabled = True
        # naive UTC timestamp, consistent with the rest of the ORM models
        model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()
    else:
        model_setting = ProviderModelSetting(
            tenant_id=self.tenant_id,
            provider_name=self.provider.provider,
            model_type=model_type.to_origin_model_type(),
            model_name=model,
            enabled=True
        )
        db.session.add(model_setting)
        db.session.commit()

    return model_setting
def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
    """
    Disable model.

    Flips the tenant-scoped setting row for this model to disabled,
    creating the row (already disabled) if it does not exist yet.

    :param model_type: model type
    :param model: model name
    :return: the persisted provider model setting
    """
    # Reuse the shared lookup so the filter criteria live in one place
    # (was a duplicated inline query identical to get_provider_model_setting).
    model_setting = self.get_provider_model_setting(
        model_type=model_type,
        model=model
    )

    if model_setting:
        model_setting.enabled = False
        # naive UTC timestamp, consistent with the rest of the ORM models
        model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()
    else:
        model_setting = ProviderModelSetting(
            tenant_id=self.tenant_id,
            provider_name=self.provider.provider,
            model_type=model_type.to_origin_model_type(),
            model_name=model,
            enabled=False
        )
        db.session.add(model_setting)
        db.session.commit()

    return model_setting
def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
    """
    Fetch the tenant-scoped setting row for one provider model, if any.

    :param model_type: model type
    :param model: model name
    :return: the matching ProviderModelSetting, or None when no row exists
    """
    query = db.session.query(ProviderModelSetting).filter(
        ProviderModelSetting.tenant_id == self.tenant_id,
        ProviderModelSetting.provider_name == self.provider.provider,
        ProviderModelSetting.model_type == model_type.to_origin_model_type(),
        ProviderModelSetting.model_name == model,
    )
    return query.first()
def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
    """
    Enable model load balancing.

    Requires at least two load balancing configurations to exist for the
    model; otherwise there is nothing to balance over and a ValueError is
    raised.

    :param model_type: model type
    :param model: model name
    :return: the persisted provider model setting
    :raises ValueError: when fewer than two load balancing configs exist
    """
    # balancing over a single credential set is meaningless — reject it
    load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
        .filter(
        LoadBalancingModelConfig.tenant_id == self.tenant_id,
        LoadBalancingModelConfig.provider_name == self.provider.provider,
        LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
        LoadBalancingModelConfig.model_name == model
    ).count()

    if load_balancing_config_count <= 1:
        raise ValueError('Model load balancing configuration must be more than 1.')

    # Reuse the shared lookup so the filter criteria live in one place
    # (was a duplicated inline query identical to get_provider_model_setting).
    model_setting = self.get_provider_model_setting(
        model_type=model_type,
        model=model
    )

    if model_setting:
        model_setting.load_balancing_enabled = True
        # naive UTC timestamp, consistent with the rest of the ORM models
        model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()
    else:
        model_setting = ProviderModelSetting(
            tenant_id=self.tenant_id,
            provider_name=self.provider.provider,
            model_type=model_type.to_origin_model_type(),
            model_name=model,
            load_balancing_enabled=True
        )
        db.session.add(model_setting)
        db.session.commit()

    return model_setting
def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
    """
    Disable model load balancing.

    Flips the load_balancing_enabled flag off for this model, creating the
    setting row (with balancing already off) if it does not exist yet.

    :param model_type: model type
    :param model: model name
    :return: the persisted provider model setting
    """
    # Reuse the shared lookup so the filter criteria live in one place
    # (was a duplicated inline query identical to get_provider_model_setting).
    model_setting = self.get_provider_model_setting(
        model_type=model_type,
        model=model
    )

    if model_setting:
        model_setting.load_balancing_enabled = False
        # naive UTC timestamp, consistent with the rest of the ORM models
        model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()
    else:
        model_setting = ProviderModelSetting(
            tenant_id=self.tenant_id,
            provider_name=self.provider.provider,
            model_type=model_type.to_origin_model_type(),
            model_name=model,
            load_balancing_enabled=False
        )
        db.session.add(model_setting)
        db.session.commit()

    return model_setting
def get_provider_instance(self) -> ModelProvider: def get_provider_instance(self) -> ModelProvider:
""" """
Get provider instance. Get provider instance.
@ -453,7 +631,7 @@ class ProviderConfiguration(BaseModel):
db.session.commit() db.session.commit()
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
""" """
Extract secret input form variables. Extract secret input form variables.
@ -467,7 +645,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables return secret_input_form_variables
def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
""" """
Obfuscated credentials. Obfuscated credentials.
@ -476,7 +654,7 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
# Get provider credential secret variables # Get provider credential secret variables
credential_secret_variables = self._extract_secret_variables( credential_secret_variables = self.extract_secret_variables(
credential_form_schemas credential_form_schemas
) )
@ -522,15 +700,22 @@ class ProviderConfiguration(BaseModel):
else: else:
model_types = provider_instance.get_provider_schema().supported_model_types model_types = provider_instance.get_provider_schema().supported_model_types
# Group model settings by model type and model
model_setting_map = defaultdict(dict)
for model_setting in self.model_settings:
model_setting_map[model_setting.model_type][model_setting.model] = model_setting
if self.using_provider_type == ProviderType.SYSTEM: if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models( provider_models = self._get_system_provider_models(
model_types=model_types, model_types=model_types,
provider_instance=provider_instance provider_instance=provider_instance,
model_setting_map=model_setting_map
) )
else: else:
provider_models = self._get_custom_provider_models( provider_models = self._get_custom_provider_models(
model_types=model_types, model_types=model_types,
provider_instance=provider_instance provider_instance=provider_instance,
model_setting_map=model_setting_map
) )
if only_active: if only_active:
@ -541,18 +726,27 @@ class ProviderConfiguration(BaseModel):
def _get_system_provider_models(self, def _get_system_provider_models(self,
model_types: list[ModelType], model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-> list[ModelWithProviderEntity]:
""" """
Get system provider models. Get system provider models.
:param model_types: model types :param model_types: model types
:param provider_instance: provider instance :param provider_instance: provider instance
:param model_setting_map: model setting map
:return: :return:
""" """
provider_models = [] provider_models = []
for model_type in model_types: for model_type in model_types:
provider_models.extend( for m in provider_instance.models(model_type):
[ status = ModelStatus.ACTIVE
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
model_setting = model_setting_map[m.model_type][m.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
provider_models.append(
ModelWithProviderEntity( ModelWithProviderEntity(
model=m.model, model=m.model,
label=m.label, label=m.label,
@ -562,10 +756,8 @@ class ProviderConfiguration(BaseModel):
model_properties=m.model_properties, model_properties=m.model_properties,
deprecated=m.deprecated, deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE status=status
) )
for m in provider_instance.models(model_type)
]
) )
if self.provider.provider not in original_provider_configurate_methods: if self.provider.provider not in original_provider_configurate_methods:
@ -586,7 +778,8 @@ class ProviderConfiguration(BaseModel):
break break
if should_use_custom_model: if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: if original_provider_configurate_methods[self.provider.provider] == [
ConfigurateMethod.CUSTOMIZABLE_MODEL]:
# only customizable model # only customizable model
for restrict_model in restrict_models: for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy() copy_credentials = self.system_configuration.credentials.copy()
@ -611,6 +804,13 @@ class ProviderConfiguration(BaseModel):
if custom_model_schema.model_type not in model_types: if custom_model_schema.model_type not in model_types:
continue continue
status = ModelStatus.ACTIVE
if (custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
provider_models.append( provider_models.append(
ModelWithProviderEntity( ModelWithProviderEntity(
model=custom_model_schema.model, model=custom_model_schema.model,
@ -621,7 +821,7 @@ class ProviderConfiguration(BaseModel):
model_properties=custom_model_schema.model_properties, model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated, deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE status=status
) )
) )
@ -632,16 +832,20 @@ class ProviderConfiguration(BaseModel):
m.status = ModelStatus.NO_PERMISSION m.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid: elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED m.status = ModelStatus.QUOTA_EXCEEDED
return provider_models return provider_models
def _get_custom_provider_models(self, def _get_custom_provider_models(self,
model_types: list[ModelType], model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-> list[ModelWithProviderEntity]:
""" """
Get custom provider models. Get custom provider models.
:param model_types: model types :param model_types: model types
:param provider_instance: provider instance :param provider_instance: provider instance
:param model_setting_map: model setting map
:return: :return:
""" """
provider_models = [] provider_models = []
@ -656,6 +860,16 @@ class ProviderConfiguration(BaseModel):
models = provider_instance.models(model_type) models = provider_instance.models(model_type)
for m in models: for m in models:
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
load_balancing_enabled = False
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
model_setting = model_setting_map[m.model_type][m.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
if len(model_setting.load_balancing_configs) > 1:
load_balancing_enabled = True
provider_models.append( provider_models.append(
ModelWithProviderEntity( ModelWithProviderEntity(
model=m.model, model=m.model,
@ -666,7 +880,8 @@ class ProviderConfiguration(BaseModel):
model_properties=m.model_properties, model_properties=m.model_properties,
deprecated=m.deprecated, deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE status=status,
load_balancing_enabled=load_balancing_enabled
) )
) )
@ -690,6 +905,17 @@ class ProviderConfiguration(BaseModel):
if not custom_model_schema: if not custom_model_schema:
continue continue
status = ModelStatus.ACTIVE
load_balancing_enabled = False
if (custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
if len(model_setting.load_balancing_configs) > 1:
load_balancing_enabled = True
provider_models.append( provider_models.append(
ModelWithProviderEntity( ModelWithProviderEntity(
model=custom_model_schema.model, model=custom_model_schema.model,
@ -700,7 +926,8 @@ class ProviderConfiguration(BaseModel):
model_properties=custom_model_schema.model_properties, model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated, deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider), provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE status=status,
load_balancing_enabled=load_balancing_enabled
) )
) )

View File

@ -72,3 +72,22 @@ class CustomConfiguration(BaseModel):
""" """
provider: Optional[CustomProviderConfiguration] = None provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = [] models: list[CustomModelConfiguration] = []
class ModelLoadBalancingConfiguration(BaseModel):
    """
    Class for model load balancing configuration.

    One entry in a model's pool of load-balanced credential sets; the
    manager picks one entry per invocation.
    """
    # unique id of this load balancing config record
    id: str
    # human-readable name of this credential set
    name: str
    # provider credentials payload used when this entry is selected
    credentials: dict
class ModelSettings(BaseModel):
    """
    Model class for model settings.

    Per-model admin settings: whether the model is enabled and its optional
    load balancing configurations.
    """
    # model name
    model: str
    # model type this setting applies to
    model_type: ModelType
    # False when an admin has disabled this model
    enabled: bool = True
    # load balancing entries; balancing is effective only with more than one
    # (pydantic copies this default per instance, so the [] is safe here)
    load_balancing_configs: list[ModelLoadBalancingConfiguration] = []

View File

@ -7,7 +7,7 @@ from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.utils.position_helper import sort_to_dict_by_position_map from core.helper.position_helper import sort_to_dict_by_position_map
class ExtensionModule(enum.Enum): class ExtensionModule(enum.Enum):

View File

@ -9,6 +9,7 @@ from extensions.ext_redis import redis_client
class ProviderCredentialsCacheType(Enum): class ProviderCredentialsCacheType(Enum):
PROVIDER = "provider" PROVIDER = "provider"
MODEL = "provider_model" MODEL = "provider_model"
LOAD_BALANCING_MODEL = "load_balancing_provider_model"
class ProviderCredentialsCache: class ProviderCredentialsCache:

View File

@ -286,11 +286,7 @@ class IndexingRunner:
if len(preview_texts) < 5: if len(preview_texts) < 5:
preview_texts.append(document.page_content) preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model_instance: if indexing_technique == 'high_quality' or embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance tokens += embedding_model_instance.get_text_embedding_num_tokens(
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
tokens += embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
texts=[self.filter_string(document.page_content)] texts=[self.filter_string(document.page_content)]
) )
@ -658,10 +654,6 @@ class IndexingRunner:
tokens = 0 tokens = 0
chunk_size = 10 chunk_size = 10
embedding_model_type_instance = None
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
# create keyword index # create keyword index
create_keyword_thread = threading.Thread(target=self._process_keyword_index, create_keyword_thread = threading.Thread(target=self._process_keyword_index,
args=(current_app._get_current_object(), args=(current_app._get_current_object(),
@ -674,8 +666,7 @@ class IndexingRunner:
chunk_documents = documents[i:i + chunk_size] chunk_documents = documents[i:i + chunk_size]
futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor, futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
chunk_documents, dataset, chunk_documents, dataset,
dataset_document, embedding_model_instance, dataset_document, embedding_model_instance))
embedding_model_type_instance))
for future in futures: for future in futures:
tokens += future.result() tokens += future.result()
@ -716,7 +707,7 @@ class IndexingRunner:
db.session.commit() db.session.commit()
def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document, def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
embedding_model_instance, embedding_model_type_instance): embedding_model_instance):
with flask_app.app_context(): with flask_app.app_context():
# check document is paused # check document is paused
self._check_document_paused_status(dataset_document.id) self._check_document_paused_status(dataset_document.id)
@ -724,9 +715,7 @@ class IndexingRunner:
tokens = 0 tokens = 0
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance: if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
tokens += sum( tokens += sum(
embedding_model_type_instance.get_num_tokens( embedding_model_instance.get_text_embedding_num_tokens(
embedding_model_instance.model,
embedding_model_instance.credentials,
[document.page_content] [document.page_content]
) )
for document in chunk_documents for document in chunk_documents

View File

@ -9,8 +9,6 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import model_provider_factory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import AppMode, Conversation, Message from models.model import AppMode, Conversation, Message
@ -78,12 +76,7 @@ class TokenBufferMemory:
return [] return []
# prune the chat message if it exceeds the max token limit # prune the chat message if it exceeds the max token limit
provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider) curr_message_tokens = self.model_instance.get_llm_num_tokens(
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
curr_message_tokens = model_type_instance.get_num_tokens(
self.model_instance.model,
self.model_instance.credentials,
prompt_messages prompt_messages
) )
@ -91,9 +84,7 @@ class TokenBufferMemory:
pruned_memory = [] pruned_memory = []
while curr_message_tokens > max_token_limit and prompt_messages: while curr_message_tokens > max_token_limit and prompt_messages:
pruned_memory.append(prompt_messages.pop(0)) pruned_memory.append(prompt_messages.pop(0))
curr_message_tokens = model_type_instance.get_num_tokens( curr_message_tokens = self.model_instance.get_llm_num_tokens(
self.model_instance.model,
self.model_instance.credentials,
prompt_messages prompt_messages
) )

View File

@ -1,7 +1,10 @@
import logging
import os
from collections.abc import Generator from collections.abc import Generator
from typing import IO, Optional, Union, cast from typing import IO, Optional, Union, cast
from core.entities.provider_configuration import ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMResult
@ -9,6 +12,7 @@ from core.model_runtime.entities.message_entities import PromptMessage, PromptMe
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.__base.rerank_model import RerankModel from core.model_runtime.model_providers.__base.rerank_model import RerankModel
@ -16,6 +20,10 @@ from core.model_runtime.model_providers.__base.speech2text_model import Speech2T
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client
from models.provider import ProviderType
logger = logging.getLogger(__name__)
class ModelInstance: class ModelInstance:
@ -29,6 +37,12 @@ class ModelInstance:
self.provider = provider_model_bundle.configuration.provider.provider self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.model_type_instance = self.provider_model_bundle.model_type_instance self.model_type_instance = self.provider_model_bundle.model_type_instance
self.load_balancing_manager = self._get_load_balancing_manager(
configuration=provider_model_bundle.configuration,
model_type=provider_model_bundle.model_type_instance.model_type,
model=model,
credentials=self.credentials
)
def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
""" """
@ -37,8 +51,10 @@ class ModelInstance:
:param model: model name :param model: model name
:return: :return:
""" """
credentials = provider_model_bundle.configuration.get_current_credentials( configuration = provider_model_bundle.configuration
model_type=provider_model_bundle.model_type_instance.model_type, model_type = provider_model_bundle.model_type_instance.model_type
credentials = configuration.get_current_credentials(
model_type=model_type,
model=model model=model
) )
@ -47,6 +63,43 @@ class ModelInstance:
return credentials return credentials
def _get_load_balancing_manager(self, configuration: ProviderConfiguration,
                                model_type: ModelType,
                                model: str,
                                credentials: dict) -> Optional["LBModelManager"]:
    """
    Get load balancing model credentials

    Builds an LBModelManager when the custom provider has more than zero
    load balancing configs for this model; otherwise returns None and the
    caller falls back to the plain credentials.

    :param configuration: provider configuration
    :param model_type: model type
    :param model: model name
    :param credentials: model credentials
    :return: manager instance, or None when load balancing does not apply
    """
    # Load balancing only applies to custom providers that carry settings.
    if not configuration.model_settings or configuration.using_provider_type != ProviderType.CUSTOM:
        return None

    # locate the admin setting for this exact model, if one exists
    matched_setting = next(
        (setting for setting in configuration.model_settings
         if setting.model_type == model_type and setting.model == model),
        None
    )

    # no setting, or no balancing configs → no manager
    if not matched_setting or not matched_setting.load_balancing_configs:
        return None

    # use load balancing proxy to choose credentials
    return LBModelManager(
        tenant_id=configuration.tenant_id,
        provider=configuration.provider.provider,
        model_type=model_type,
        model=model,
        load_balancing_configs=matched_setting.load_balancing_configs,
        managed_credentials=credentials if configuration.custom_configuration.provider else None
    )
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
@ -67,7 +120,8 @@ class ModelInstance:
raise Exception("Model type instance is not LargeLanguageModel") raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self.model_type_instance.invoke( return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
@ -79,6 +133,27 @@ class ModelInstance:
callbacks=callbacks callbacks=callbacks
) )
def get_llm_num_tokens(self, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
    """
    Get number of tokens for llm

    :param prompt_messages: prompt messages
    :param tools: tools for tool calling
    :return: token count reported by the underlying LLM implementation
    """
    llm = self.model_type_instance
    if not isinstance(llm, LargeLanguageModel):
        raise Exception("Model type instance is not LargeLanguageModel")

    # cast is a runtime no-op; it just narrows the stored attribute's type
    self.model_type_instance = cast(LargeLanguageModel, llm)

    # route through the round-robin wrapper so load-balanced credentials apply
    return self._round_robin_invoke(
        function=llm.get_num_tokens,
        model=self.model,
        credentials=self.credentials,
        prompt_messages=prompt_messages,
        tools=tools
    )
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult: -> TextEmbeddingResult:
""" """
@ -92,13 +167,32 @@ class ModelInstance:
raise Exception("Model type instance is not TextEmbeddingModel") raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self.model_type_instance.invoke( return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
texts=texts, texts=texts,
user=user user=user
) )
def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
    """
    Get number of tokens for text embedding

    :param texts: texts to embed
    :return: token count reported by the underlying embedding implementation
    """
    embedding_model = self.model_type_instance
    if not isinstance(embedding_model, TextEmbeddingModel):
        raise Exception("Model type instance is not TextEmbeddingModel")

    # cast is a runtime no-op; it just narrows the stored attribute's type
    self.model_type_instance = cast(TextEmbeddingModel, embedding_model)

    # route through the round-robin wrapper so load-balanced credentials apply
    return self._round_robin_invoke(
        function=embedding_model.get_num_tokens,
        model=self.model,
        credentials=self.credentials,
        texts=texts
    )
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \ user: Optional[str] = None) \
@ -117,7 +211,8 @@ class ModelInstance:
raise Exception("Model type instance is not RerankModel") raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance) self.model_type_instance = cast(RerankModel, self.model_type_instance)
return self.model_type_instance.invoke( return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
query=query, query=query,
@ -140,7 +235,8 @@ class ModelInstance:
raise Exception("Model type instance is not ModerationModel") raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance) self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return self.model_type_instance.invoke( return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
text=text, text=text,
@ -160,7 +256,8 @@ class ModelInstance:
raise Exception("Model type instance is not Speech2TextModel") raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return self.model_type_instance.invoke( return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
file=file, file=file,
@ -183,7 +280,8 @@ class ModelInstance:
raise Exception("Model type instance is not TTSModel") raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance) self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.invoke( return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
content_text=content_text, content_text=content_text,
@ -193,6 +291,43 @@ class ModelInstance:
streaming=streaming streaming=streaming
) )
def _round_robin_invoke(self, function: callable, *args, **kwargs):
    """
    Invoke `function`, rotating through load-balancing credentials when enabled.

    Without a load balancing manager the call is forwarded unchanged (the
    instance's own credentials are already in kwargs). With one, configs are
    fetched round-robin and their credentials are substituted per attempt:
    rate-limit errors cool the config down for 60s, auth/connection errors
    for 10s, and the next config is tried. When every config is exhausted,
    the last seen error is re-raised (or ProviderTokenNotInitError if no
    attempt was ever made).

    :param function: model-type-instance method to invoke
    :param args: positional args forwarded to `function`
    :param kwargs: keyword args forwarded to `function`
    :return: whatever `function` returns
    """
    if not self.load_balancing_manager:
        return function(*args, **kwargs)

    last_exception = None
    while True:
        lb_config = self.load_balancing_manager.fetch_next()
        if not lb_config:
            # every config is in cooldown (or none exist)
            if not last_exception:
                raise ProviderTokenNotInitError("Model credentials is not initialized.")
            raise last_exception

        try:
            # swap in this config's credentials for the caller-supplied ones
            kwargs.pop('credentials', None)
            return function(*args, **kwargs, credentials=lb_config.credentials)
        except InvokeRateLimitError as e:
            # rate limited: cool this config down for 60 seconds
            self.load_balancing_manager.cooldown(lb_config, expire=60)
            last_exception = e
        except (InvokeAuthorizationError, InvokeConnectionError) as e:
            # auth/connection failure: cool down for 10 seconds and retry next
            self.load_balancing_manager.cooldown(lb_config, expire=10)
            last_exception = e
        # any other exception propagates unchanged (the original
        # `except Exception as e: raise e` was a no-op and is removed)
def get_tts_voices(self, language: str) -> list: def get_tts_voices(self, language: str) -> list:
""" """
Invoke large language tts model voices Invoke large language tts model voices
@ -226,6 +361,7 @@ class ModelManager:
""" """
if not provider: if not provider:
return self.get_default_model_instance(tenant_id, model_type) return self.get_default_model_instance(tenant_id, model_type)
provider_model_bundle = self._provider_manager.get_provider_model_bundle( provider_model_bundle = self._provider_manager.get_provider_model_bundle(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
@ -255,3 +391,141 @@ class ModelManager:
model_type=model_type, model_type=model_type,
model=default_model_entity.model model=default_model_entity.model
) )
class LBModelManager:
    """Round-robin load balancer over a model's credential configurations."""

    def __init__(self, tenant_id: str,
                 provider: str,
                 model_type: ModelType,
                 model: str,
                 load_balancing_configs: list[ModelLoadBalancingConfiguration],
                 managed_credentials: Optional[dict] = None) -> None:
        """
        Load balancing model manager.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :param load_balancing_configs: all load balancing configurations
        :param managed_credentials: credentials if load balancing configuration name is __inherit__
        """
        self._tenant_id = tenant_id
        self._provider = provider
        self._model_type = model_type
        self._model = model
        self._load_balancing_configs = load_balancing_configs

        # Iterate over a copy: removing from the list being iterated skips
        # the element following each removal (bug in the original loop).
        for load_balancing_config in list(self._load_balancing_configs):
            if load_balancing_config.name == "__inherit__":
                if not managed_credentials:
                    # remove __inherit__ if managed credentials is not provided
                    self._load_balancing_configs.remove(load_balancing_config)
                else:
                    load_balancing_config.credentials = managed_credentials

    def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]:
        """
        Get next model load balancing config.

        Strategy: Round Robin via a shared Redis counter, skipping configs
        currently in cooldown.

        :return: next usable config, or None when all configs are cooling down
        """
        # Guard: modulo below would raise ZeroDivisionError on an empty list.
        if not self._load_balancing_configs:
            return None

        cache_key = "model_lb_index:{}:{}:{}:{}".format(
            self._tenant_id,
            self._provider,
            self._model_type.value,
            self._model
        )

        cooldown_config_ids = set()  # ids, so a revisited config is not double-counted
        max_index = len(self._load_balancing_configs)

        while True:
            current_index = redis_client.incr(cache_key)
            if current_index >= 10000000:
                # keep the shared counter bounded; reset it and refresh its TTL
                current_index = 1
                redis_client.set(cache_key, current_index)
                redis_client.expire(cache_key, 3600)

            # Counter is 1-based; map onto a valid list index.
            real_index = (current_index - 1) % max_index

            config = self._load_balancing_configs[real_index]

            if self.in_cooldown(config):
                cooldown_config_ids.add(config.id)
                if len(cooldown_config_ids) >= max_index:
                    # all configs are in cooldown
                    return None

                continue

            if os.environ.get("DEBUG", 'False').lower() == 'true':
                logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n"
                            f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
                            f"model_type: {self._model_type.value}\nmodel: {self._model}")

            return config

    def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
        """
        Cooldown model load balancing config.

        :param config: model load balancing config
        :param expire: cooldown time in seconds
        :return:
        """
        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
            self._tenant_id,
            self._provider,
            self._model_type.value,
            self._model,
            config.id
        )

        redis_client.setex(cooldown_cache_key, expire, 'true')

    def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
        """
        Check if model load balancing config is in cooldown.

        :param config: model load balancing config
        :return: True while the cooldown key exists in Redis
        """
        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
            self._tenant_id,
            self._provider,
            self._model_type.value,
            self._model,
            config.id
        )

        return bool(redis_client.exists(cooldown_cache_key))

    @classmethod
    def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
                                       provider: str,
                                       model_type: ModelType,
                                       model: str,
                                       config_id: str) -> tuple[bool, int]:
        """
        Get model load balancing config is in cooldown and ttl.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :param config_id: model load balancing config id
        :return: (in_cooldown, remaining ttl in seconds)
        """
        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
            tenant_id,
            provider,
            model_type.value,
            model,
            config_id
        )

        ttl = redis_client.ttl(cooldown_cache_key)
        if ttl == -2:
            # key does not exist: not in cooldown
            return False, 0

        # ttl == -1 (no expiry) is treated as in cooldown, as before
        return True, ttl

View File

@ -3,6 +3,7 @@ import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from core.helper.position_helper import get_position_map, sort_by_position_map
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
@ -17,7 +18,6 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.tools.utils.yaml_utils import load_yaml_file from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.position_helper import get_position_map, sort_by_position_map
class AIModel(ABC): class AIModel(ABC):

View File

@ -1,11 +1,11 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.tools.utils.yaml_utils import load_yaml_file from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
class ModelProvider(ABC): class ModelProvider(ABC):

View File

@ -4,13 +4,13 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.utils.module_import_helper import load_single_subclass_from_source
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,10 +1,10 @@
from typing import Optional, cast from typing import Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@ -25,12 +25,12 @@ class PromptTransform:
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens: if model_context_tokens:
model_type_instance = model_config.provider_model_bundle.model_type_instance model_instance = ModelInstance(
model_type_instance = cast(LargeLanguageModel, model_type_instance) provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
)
curr_message_tokens = model_type_instance.get_num_tokens( curr_message_tokens = model_instance.get_llm_num_tokens(
model_config.model,
model_config.credentials,
prompt_messages prompt_messages
) )

View File

@ -11,6 +11,8 @@ from core.entities.provider_entities import (
CustomConfiguration, CustomConfiguration,
CustomModelConfiguration, CustomModelConfiguration,
CustomProviderConfiguration, CustomProviderConfiguration,
ModelLoadBalancingConfiguration,
ModelSettings,
QuotaConfiguration, QuotaConfiguration,
SystemConfiguration, SystemConfiguration,
) )
@ -26,13 +28,16 @@ from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider from extensions import ext_hosting_provider
from extensions.ext_database import db from extensions.ext_database import db
from models.provider import ( from models.provider import (
LoadBalancingModelConfig,
Provider, Provider,
ProviderModel, ProviderModel,
ProviderModelSetting,
ProviderQuotaType, ProviderQuotaType,
ProviderType, ProviderType,
TenantDefaultModel, TenantDefaultModel,
TenantPreferredModelProvider, TenantPreferredModelProvider,
) )
from services.feature_service import FeatureService
class ProviderManager: class ProviderManager:
@ -98,6 +103,13 @@ class ProviderManager:
# Get All preferred provider types of the workspace # Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
# Get All provider model settings
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
# Get All load balancing configs
provider_name_to_provider_load_balancing_model_configs_dict \
= self._get_all_provider_load_balancing_configs(tenant_id)
provider_configurations = ProviderConfigurations( provider_configurations = ProviderConfigurations(
tenant_id=tenant_id tenant_id=tenant_id
) )
@ -147,13 +159,28 @@ class ProviderManager:
if system_configuration.enabled and has_valid_quota: if system_configuration.enabled and has_valid_quota:
using_provider_type = ProviderType.SYSTEM using_provider_type = ProviderType.SYSTEM
# Get provider load balancing configs
provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name)
# Get provider load balancing configs
provider_load_balancing_configs \
= provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name)
# Convert to model settings
model_settings = self._to_model_settings(
provider_entity=provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=provider_load_balancing_configs
)
provider_configuration = ProviderConfiguration( provider_configuration = ProviderConfiguration(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider_entity, provider=provider_entity,
preferred_provider_type=preferred_provider_type, preferred_provider_type=preferred_provider_type,
using_provider_type=using_provider_type, using_provider_type=using_provider_type,
system_configuration=system_configuration, system_configuration=system_configuration,
custom_configuration=custom_configuration custom_configuration=custom_configuration,
model_settings=model_settings
) )
provider_configurations[provider_name] = provider_configuration provider_configurations[provider_name] = provider_configuration
@ -338,7 +365,7 @@ class ProviderManager:
""" """
Get All preferred provider types of the workspace. Get All preferred provider types of the workspace.
:param tenant_id: :param tenant_id: workspace id
:return: :return:
""" """
preferred_provider_types = db.session.query(TenantPreferredModelProvider) \ preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
@ -353,6 +380,48 @@ class ProviderManager:
return provider_name_to_preferred_provider_type_records_dict return provider_name_to_preferred_provider_type_records_dict
def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]:
    """
    Get all provider model settings of the workspace, grouped by provider name.

    :param tenant_id: workspace id
    :return: mapping of provider name -> ProviderModelSetting records
    """
    settings = (
        db.session.query(ProviderModelSetting)
        .filter(ProviderModelSetting.tenant_id == tenant_id)
        .all()
    )

    grouped = defaultdict(list)
    for setting in settings:
        grouped[setting.provider_name].append(setting)

    return grouped
def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
    """
    Get all provider load balancing configs of the workspace, grouped by provider name.

    Returns an empty mapping when the model-load-balancing feature is
    disabled for this workspace.

    :param tenant_id: workspace id
    :return: mapping of provider name -> LoadBalancingModelConfig records
    """
    if not FeatureService.get_features(tenant_id).model_load_balancing_enabled:
        return {}

    configs = (
        db.session.query(LoadBalancingModelConfig)
        .filter(LoadBalancingModelConfig.tenant_id == tenant_id)
        .all()
    )

    grouped = defaultdict(list)
    for config in configs:
        grouped[config.provider_name].append(config)

    return grouped
def _init_trial_provider_records(self, tenant_id: str, def _init_trial_provider_records(self, tenant_id: str,
provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]:
""" """
@ -726,3 +795,97 @@ class ProviderManager:
secret_input_form_variables.append(credential_form_schema.variable) secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables return secret_input_form_variables
def _to_model_settings(self, provider_entity: ProviderEntity,
                       provider_model_settings: Optional[list[ProviderModelSetting]] = None,
                       load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \
        -> list[ModelSettings]:
    """
    Convert to model settings.

    For each provider model setting, collect its enabled load balancing
    configs, decrypting their stored credentials (with a per-config cache).
    Load balancing only takes effect with more than one usable config;
    otherwise the resulting ModelSettings carries an empty config list.

    :param provider_entity: provider entity (source of the credential schema)
    :param provider_model_settings: provider model settings include enabled, load balancing enabled
    :param load_balancing_model_configs: load balancing model configs
    :return: list of ModelSettings, one per provider model setting
    """
    # Get provider model credential secret variables
    # (the credential keys that were stored encrypted and must be decrypted below)
    model_credential_secret_variables = self._extract_secret_variables(
        provider_entity.model_credential_schema.credential_form_schemas
        if provider_entity.model_credential_schema else []
    )

    model_settings = []
    if not provider_model_settings:
        return model_settings

    for provider_model_setting in provider_model_settings:
        load_balancing_configs = []
        if provider_model_setting.load_balancing_enabled and load_balancing_model_configs:
            for load_balancing_model_config in load_balancing_model_configs:
                # only configs for this exact model (name + type) apply
                if (load_balancing_model_config.model_name == provider_model_setting.model_name
                        and load_balancing_model_config.model_type == provider_model_setting.model_type):
                    if not load_balancing_model_config.enabled:
                        continue

                    if not load_balancing_model_config.encrypted_config:
                        # "__inherit__" means: use the provider's managed credentials,
                        # filled in later (see LBModelManager); others without config are skipped
                        if load_balancing_model_config.name == "__inherit__":
                            load_balancing_configs.append(ModelLoadBalancingConfiguration(
                                id=load_balancing_model_config.id,
                                name=load_balancing_model_config.name,
                                credentials={}
                            ))

                        continue

                    provider_model_credentials_cache = ProviderCredentialsCache(
                        tenant_id=load_balancing_model_config.tenant_id,
                        identity_id=load_balancing_model_config.id,
                        cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
                    )

                    # Get cached provider model credentials
                    cached_provider_model_credentials = provider_model_credentials_cache.get()

                    if not cached_provider_model_credentials:
                        try:
                            provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
                        except JSONDecodeError:
                            # unparseable stored config: skip this LB config entirely
                            continue

                        # Get decoding rsa key and cipher for decrypting credentials
                        # (lazily initialized once and reused across configs)
                        if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
                            self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(
                                load_balancing_model_config.tenant_id)

                        for variable in model_credential_secret_variables:
                            if variable in provider_model_credentials:
                                try:
                                    provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
                                        provider_model_credentials.get(variable),
                                        self.decoding_rsa_key,
                                        self.decoding_cipher_rsa
                                    )
                                except ValueError:
                                    # value not decryptable: leave it as stored
                                    pass

                        # cache provider model credentials
                        provider_model_credentials_cache.set(
                            credentials=provider_model_credentials
                        )
                    else:
                        provider_model_credentials = cached_provider_model_credentials

                    load_balancing_configs.append(ModelLoadBalancingConfiguration(
                        id=load_balancing_model_config.id,
                        name=load_balancing_model_config.name,
                        credentials=provider_model_credentials
                    ))

        model_settings.append(
            ModelSettings(
                model=provider_model_setting.model_name,
                model_type=ModelType.value_of(provider_model_setting.model_type),
                enabled=provider_model_setting.enabled,
                # a single config cannot be balanced across, so require > 1
                load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else []
            )
        )

    return model_settings

View File

@ -1,11 +1,10 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Optional, cast from typing import Any, Optional
from sqlalchemy import func from sqlalchemy import func
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
@ -95,11 +94,7 @@ class DatasetDocumentStore:
# calc embedding use tokens # calc embedding use tokens
if embedding_model: if embedding_model:
model_type_instance = embedding_model.model_type_instance tokens = embedding_model.get_text_embedding_num_tokens(
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[doc.page_content] texts=[doc.page_content]
) )
else: else:

View File

@ -1,10 +1,9 @@
"""Functionality for splitting text.""" """Functionality for splitting text."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Optional, cast from typing import Any, Optional
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.rag.splitter.text_splitter import ( from core.rag.splitter.text_splitter import (
TS, TS,
@ -35,11 +34,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
return 0 return 0
if embedding_model_instance: if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance return embedding_model_instance.get_text_embedding_num_tokens(
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
return embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
texts=[text] texts=[text]
) )
else: else:

View File

@ -1,7 +1,7 @@
import os.path import os.path
from core.helper.position_helper import get_position_map, sort_by_position_map
from core.tools.entities.api_entities import UserToolProvider from core.tools.entities.api_entities import UserToolProvider
from core.utils.position_helper import get_position_map, sort_by_position_map
class BuiltinToolProviderSort: class BuiltinToolProviderSort:

View File

@ -2,6 +2,7 @@ from abc import abstractmethod
from os import listdir, path from os import listdir, path
from typing import Any from typing import Any
from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import ( from core.tools.errors import (
@ -14,7 +15,6 @@ from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.tools.utils.yaml_utils import load_yaml_file from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import load_single_subclass_from_source
class BuiltinToolProviderController(ToolProviderController): class BuiltinToolProviderController(ToolProviderController):

View File

@ -10,6 +10,7 @@ from flask import current_app
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -31,7 +32,6 @@ from core.tools.utils.configuration import (
ToolParameterConfigurationManager, ToolParameterConfigurationManager,
) )
from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.utils.module_import_helper import load_single_subclass_from_source
from core.workflow.nodes.tool.entities import ToolEntity from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider

View File

@ -73,10 +73,8 @@ class ModelInvocationUtils:
if not model_instance: if not model_instance:
raise InvokeModelError('Model not found') raise InvokeModelError('Model not found')
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
# get tokens # get tokens
tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages) tokens = model_instance.get_llm_num_tokens(prompt_messages)
return tokens return tokens
@ -108,13 +106,8 @@ class ModelInvocationUtils:
tenant_id=tenant_id, model_type=ModelType.LLM, tenant_id=tenant_id, model_type=ModelType.LLM,
) )
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
# get model credentials
model_credentials = model_instance.credentials
# get prompt tokens # get prompt tokens
prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages) prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
model_parameters = { model_parameters = {
'temperature': 0.8, 'temperature': 0.8,
@ -144,9 +137,7 @@ class ModelInvocationUtils:
db.session.commit() db.session.commit()
try: try:
response: LLMResult = llm_model.invoke( response: LLMResult = model_instance.invoke_llm(
model=model_instance.model,
credentials=model_credentials,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=model_parameters, model_parameters=model_parameters,
tools=[], stop=[], stream=False, user=user_id, callbacks=[] tools=[], stop=[], stream=False, user=user_id, callbacks=[]

View File

@ -4,9 +4,9 @@ from typing import Optional, Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@ -200,12 +200,12 @@ class QuestionClassifierNode(LLMNode):
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens: if model_context_tokens:
model_type_instance = model_config.provider_model_bundle.model_type_instance model_instance = ModelInstance(
model_type_instance = cast(LargeLanguageModel, model_type_instance) provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
)
curr_message_tokens = model_type_instance.get_num_tokens( curr_message_tokens = model_instance.get_llm_num_tokens(
model_config.model,
model_config.credentials,
prompt_messages prompt_messages
) )

View File

@ -0,0 +1,126 @@
"""add load balancing
Revision ID: 4e99a8df00ff
Revises: 47cc7df8c4f3
Create Date: 2024-05-10 12:08:09.812736
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '4e99a8df00ff'
down_revision = '64a70a7aab8b'
branch_labels = None
depends_on = None
def upgrade():
    """Create the load-balancing tables and widen ``provider_name`` columns to 255 chars."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        'load_balancing_model_configs',
        sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.StringUUID(), nullable=False),
        sa.Column('provider_name', sa.String(length=255), nullable=False),
        sa.Column('model_name', sa.String(length=255), nullable=False),
        sa.Column('model_type', sa.String(length=40), nullable=False),
        sa.Column('name', sa.String(length=255), nullable=False),
        sa.Column('encrypted_config', sa.Text(), nullable=True),
        sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'),
    )
    with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
        batch_op.create_index('load_balancing_model_config_tenant_provider_model_idx',
                              ['tenant_id', 'provider_name', 'model_type'], unique=False)

    op.create_table(
        'provider_model_settings',
        sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.StringUUID(), nullable=False),
        sa.Column('provider_name', sa.String(length=255), nullable=False),
        sa.Column('model_name', sa.String(length=255), nullable=False),
        sa.Column('model_type', sa.String(length=40), nullable=False),
        sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
        sa.Column('load_balancing_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'),
    )
    with op.batch_alter_table('provider_model_settings', schema=None) as batch_op:
        batch_op.create_index('provider_model_setting_tenant_provider_model_idx',
                              ['tenant_id', 'provider_name', 'model_type'], unique=False)

    # Widen provider_name from VARCHAR(40) to VARCHAR(255) on every provider-related table.
    for table_name in ('provider_models', 'provider_orders', 'providers',
                       'tenant_default_models', 'tenant_preferred_model_providers'):
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.alter_column('provider_name',
                                  existing_type=sa.VARCHAR(length=40),
                                  type_=sa.String(length=255),
                                  existing_nullable=False)
    # ### end Alembic commands ###
def downgrade():
    """Revert the load-balancing migration: shrink ``provider_name`` columns and drop the new tables."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Restore provider_name to VARCHAR(40), in the reverse order of upgrade().
    for table_name in ('tenant_preferred_model_providers', 'tenant_default_models',
                       'providers', 'provider_orders', 'provider_models'):
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.alter_column('provider_name',
                                  existing_type=sa.String(length=255),
                                  type_=sa.VARCHAR(length=40),
                                  existing_nullable=False)

    # Drop each new table's index before dropping the table itself.
    with op.batch_alter_table('provider_model_settings', schema=None) as batch_op:
        batch_op.drop_index('provider_model_setting_tenant_provider_model_idx')
    op.drop_table('provider_model_settings')

    with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
        batch_op.drop_index('load_balancing_model_config_tenant_provider_model_idx')
    op.drop_table('load_balancing_model_configs')
    # ### end Alembic commands ###

View File

@ -47,7 +47,7 @@ class Provider(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(255), nullable=False)
provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
encrypted_config = db.Column(db.Text, nullable=True) encrypted_config = db.Column(db.Text, nullable=True)
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@ -94,7 +94,7 @@ class ProviderModel(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type = db.Column(db.String(40), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True) encrypted_config = db.Column(db.Text, nullable=True)
@ -112,7 +112,7 @@ class TenantDefaultModel(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -128,7 +128,7 @@ class TenantPreferredModelProvider(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(255), nullable=False)
preferred_provider_type = db.Column(db.String(40), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@ -143,7 +143,7 @@ class ProviderOrder(db.Model):
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False) provider_name = db.Column(db.String(255), nullable=False)
account_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False)
payment_product_id = db.Column(db.String(191), nullable=False) payment_product_id = db.Column(db.String(191), nullable=False)
payment_id = db.Column(db.String(191)) payment_id = db.Column(db.String(191))
@ -157,3 +157,46 @@ class ProviderOrder(db.Model):
refunded_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
class ProviderModelSetting(db.Model):
    """
    Per-tenant setting for a provider's model, recording whether the model is
    enabled and whether load balancing is enabled for it.
    """
    __tablename__ = 'provider_model_settings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'),
        # Lookups are per workspace/provider/model-type; model_name is not indexed.
        db.Index('provider_model_setting_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'),
    )

    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    # Workspace (tenant) this setting belongs to.
    tenant_id = db.Column(StringUUID, nullable=False)
    provider_name = db.Column(db.String(255), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    # Model type in its origin string form (e.g. stored as a short type string, max 40 chars).
    model_type = db.Column(db.String(40), nullable=False)
    # Whether the model is enabled at all; defaults to enabled.
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true'))
    # Whether load balancing is turned on for this model; defaults to off.
    load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
class LoadBalancingModelConfig(db.Model):
    """
    One credential entry in a model's load balancing pool.

    A row named '__inherit__' represents the provider/model custom credentials
    configured outside load balancing (see ModelLoadBalancingService).
    """
    __tablename__ = 'load_balancing_model_configs'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'),
        # Lookups are per workspace/provider/model-type; model_name is not indexed.
        db.Index('load_balancing_model_config_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'),
    )

    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    # Workspace (tenant) this config belongs to.
    tenant_id = db.Column(StringUUID, nullable=False)
    provider_name = db.Column(db.String(255), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    # Model type in its origin string form.
    model_type = db.Column(db.String(40), nullable=False)
    # Display name of the config; '__inherit__' is reserved (see service layer).
    name = db.Column(db.String(255), nullable=False)
    # JSON-encoded credentials with secret fields encrypted; may be NULL.
    encrypted_config = db.Column(db.Text, nullable=True)
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true'))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@ -4,7 +4,7 @@ import logging
import random import random
import time import time
import uuid import uuid
from typing import Optional, cast from typing import Optional
from flask import current_app from flask import current_app
from flask_login import current_user from flask_login import current_user
@ -13,7 +13,6 @@ from sqlalchemy import func
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.models.document import Document as RAGDocument from core.rag.models.document import Document as RAGDocument
from events.dataset_event import dataset_was_deleted from events.dataset_event import dataset_was_deleted
@ -1144,10 +1143,7 @@ class SegmentService:
model=dataset.embedding_model model=dataset.embedding_model
) )
# calc embedding use tokens # calc embedding use tokens
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance) tokens = embedding_model.get_text_embedding_num_tokens(
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content] texts=[content]
) )
lock_name = 'add_segment_lock_document_id_{}'.format(document.id) lock_name = 'add_segment_lock_document_id_{}'.format(document.id)
@ -1215,10 +1211,7 @@ class SegmentService:
tokens = 0 tokens = 0
if dataset.indexing_technique == 'high_quality' and embedding_model: if dataset.indexing_technique == 'high_quality' and embedding_model:
# calc embedding use tokens # calc embedding use tokens
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance) tokens = embedding_model.get_text_embedding_num_tokens(
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content] texts=[content]
) )
segment_document = DocumentSegment( segment_document = DocumentSegment(
@ -1321,10 +1314,7 @@ class SegmentService:
) )
# calc embedding use tokens # calc embedding use tokens
model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance) tokens = embedding_model.get_text_embedding_num_tokens(
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content] texts=[content]
) )
segment.content = content segment.content = content

View File

@ -4,10 +4,10 @@ from typing import Optional
from flask import current_app from flask import current_app
from pydantic import BaseModel from pydantic import BaseModel
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.entities.provider_entities import QuotaConfiguration from core.entities.provider_entities import QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType, ProviderModel from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ( from core.model_runtime.entities.provider_entities import (
ConfigurateMethod, ConfigurateMethod,
ModelCredentialSchema, ModelCredentialSchema,
@ -79,13 +79,6 @@ class ProviderResponse(BaseModel):
) )
class ModelResponse(ProviderModel):
"""
Model class for model response.
"""
status: ModelStatus
class ProviderWithModelsResponse(BaseModel): class ProviderWithModelsResponse(BaseModel):
""" """
Model class for provider with models response. Model class for provider with models response.
@ -95,7 +88,7 @@ class ProviderWithModelsResponse(BaseModel):
icon_small: Optional[I18nObject] = None icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None
status: CustomConfigurationStatus status: CustomConfigurationStatus
models: list[ModelResponse] models: list[ProviderModelWithStatusEntity]
def __init__(self, **data) -> None: def __init__(self, **data) -> None:
super().__init__(**data) super().__init__(**data)

View File

@ -29,6 +29,7 @@ class FeatureModel(BaseModel):
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
docs_processing: str = 'standard' docs_processing: str = 'standard'
can_replace_logo: bool = False can_replace_logo: bool = False
model_load_balancing_enabled: bool = False
class SystemFeatureModel(BaseModel): class SystemFeatureModel(BaseModel):
@ -63,6 +64,7 @@ class FeatureService:
@classmethod @classmethod
def _fulfill_params_from_env(cls, features: FeatureModel): def _fulfill_params_from_env(cls, features: FeatureModel):
features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO'] features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO']
features.model_load_balancing_enabled = current_app.config['MODEL_LB_ENABLED']
@classmethod @classmethod
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
@ -72,24 +74,35 @@ class FeatureService:
features.billing.subscription.plan = billing_info['subscription']['plan'] features.billing.subscription.plan = billing_info['subscription']['plan']
features.billing.subscription.interval = billing_info['subscription']['interval'] features.billing.subscription.interval = billing_info['subscription']['interval']
if 'members' in billing_info:
features.members.size = billing_info['members']['size'] features.members.size = billing_info['members']['size']
features.members.limit = billing_info['members']['limit'] features.members.limit = billing_info['members']['limit']
if 'apps' in billing_info:
features.apps.size = billing_info['apps']['size'] features.apps.size = billing_info['apps']['size']
features.apps.limit = billing_info['apps']['limit'] features.apps.limit = billing_info['apps']['limit']
if 'vector_space' in billing_info:
features.vector_space.size = billing_info['vector_space']['size'] features.vector_space.size = billing_info['vector_space']['size']
features.vector_space.limit = billing_info['vector_space']['limit'] features.vector_space.limit = billing_info['vector_space']['limit']
if 'documents_upload_quota' in billing_info:
features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] features.documents_upload_quota.size = billing_info['documents_upload_quota']['size']
features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit']
if 'annotation_quota_limit' in billing_info:
features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
if 'docs_processing' in billing_info:
features.docs_processing = billing_info['docs_processing'] features.docs_processing = billing_info['docs_processing']
if 'can_replace_logo' in billing_info:
features.can_replace_logo = billing_info['can_replace_logo'] features.can_replace_logo = billing_info['can_replace_logo']
if 'model_load_balancing_enabled' in billing_info:
features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled']
@classmethod @classmethod
def _fulfill_params_from_enterprise(cls, features): def _fulfill_params_from_enterprise(cls, features):
enterprise_info = EnterpriseService.get_info() enterprise_info = EnterpriseService.get_info()

View File

@ -0,0 +1,565 @@
import datetime
import json
import logging
from json import JSONDecodeError
from typing import Optional
from core.entities.provider_configuration import ProviderConfiguration
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_manager import LBModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
ModelCredentialSchema,
ProviderCredentialSchema,
)
from core.model_runtime.model_providers import model_provider_factory
from core.provider_manager import ProviderManager
from extensions.ext_database import db
from models.provider import LoadBalancingModelConfig
logger = logging.getLogger(__name__)
class ModelLoadBalancingService:
    def __init__(self) -> None:
        # Provider manager used by every method to resolve per-workspace provider configurations.
        self.provider_manager = ProviderManager()
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
"""
enable model load balancing.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Enable model load balancing
provider_configuration.enable_model_load_balancing(
model=model,
model_type=ModelType.value_of(model_type)
)
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
"""
disable model load balancing.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# disable model load balancing
provider_configuration.disable_model_load_balancing(
model=model,
model_type=ModelType.value_of(model_type)
)
    def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \
            -> tuple[bool, list[dict]]:
        """
        Get load balancing configurations for a model.

        Returns whether load balancing is enabled for the model, plus one dict per
        config with keys: id, name, credentials (obfuscated), enabled, in_cooldown, ttl.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type (origin string form)
        :return: (is_load_balancing_enabled, list of config dicts)
        :raises ValueError: if the provider is not configured in this workspace
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)

        # Get provider configuration
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")

        # Convert model type to ModelType (rebinds `model_type` from str to the enum from here on)
        model_type = ModelType.value_of(model_type)

        # Get provider model setting
        provider_model_setting = provider_configuration.get_provider_model_setting(
            model_type=model_type,
            model=model,
        )

        is_load_balancing_enabled = False
        if provider_model_setting and provider_model_setting.load_balancing_enabled:
            is_load_balancing_enabled = True

        # Get load balancing configurations, oldest first.
        # Note: rows are filtered by the canonical provider name from the configuration.
        load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model
            ).order_by(LoadBalancingModelConfig.created_at).all()

        if provider_configuration.custom_configuration.provider:
            # check if the inherit configuration exists,
            # inherit is represented for the provider or model custom credentials
            inherit_config_exists = False
            for load_balancing_config in load_balancing_configs:
                if load_balancing_config.name == '__inherit__':
                    inherit_config_exists = True
                    break

            if not inherit_config_exists:
                # Initialize the inherit configuration (persists a new row)
                inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type)

                # prepend the inherit configuration
                load_balancing_configs.insert(0, inherit_config)
            else:
                # move the inherit configuration to the first
                for i, load_balancing_config in enumerate(load_balancing_configs):
                    if load_balancing_config.name == '__inherit__':
                        inherit_config = load_balancing_configs.pop(i)
                        load_balancing_configs.insert(0, inherit_config)

        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)

        # Get decoding rsa key and cipher for decrypting credentials (fetched once for all configs)
        decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)

        # fetch status and ttl for each config
        datas = []
        for load_balancing_config in load_balancing_configs:
            # Cooldown state is tracked per config id by the LB manager.
            in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
                tenant_id=tenant_id,
                provider=provider,
                model=model,
                model_type=model_type,
                config_id=load_balancing_config.id
            )

            # Malformed or empty stored config degrades to an empty credentials dict.
            try:
                if load_balancing_config.encrypted_config:
                    credentials = json.loads(load_balancing_config.encrypted_config)
                else:
                    credentials = {}
            except JSONDecodeError:
                credentials = {}

            # Get provider credential secret variables
            credential_secret_variables = provider_configuration.extract_secret_variables(
                credential_schemas.credential_form_schemas
            )

            # decrypt credentials; undecryptable values are left as stored
            for variable in credential_secret_variables:
                if variable in credentials:
                    try:
                        credentials[variable] = encrypter.decrypt_token_with_decoding(
                            credentials.get(variable),
                            decoding_rsa_key,
                            decoding_cipher_rsa
                        )
                    except ValueError:
                        pass

            # Obfuscate credentials before returning them to the caller
            credentials = provider_configuration.obfuscated_credentials(
                credentials=credentials,
                credential_form_schemas=credential_schemas.credential_form_schemas
            )

            datas.append({
                'id': load_balancing_config.id,
                'name': load_balancing_config.name,
                'credentials': credentials,
                'enabled': load_balancing_config.enabled,
                'in_cooldown': in_cooldown,
                'ttl': ttl
            })

        return is_load_balancing_enabled, datas
    def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \
            -> Optional[dict]:
        """
        Get a single load balancing configuration by id.

        Returns a dict with keys id, name, credentials (obfuscated), enabled,
        or None if no matching row exists.

        Note: unlike get_load_balancing_configs, the stored credential values are
        obfuscated without being decrypted first.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type (origin string form)
        :param config_id: load balancing config id
        :return: config dict or None
        :raises ValueError: if the provider is not configured in this workspace
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)

        # Get provider configuration
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")

        # Convert model type to ModelType (rebinds `model_type` from str to the enum)
        model_type = ModelType.value_of(model_type)

        # Get load balancing configurations
        load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
                LoadBalancingModelConfig.id == config_id
            ).first()

        if not load_balancing_model_config:
            return None

        # Malformed or empty stored config degrades to an empty credentials dict.
        try:
            if load_balancing_model_config.encrypted_config:
                credentials = json.loads(load_balancing_model_config.encrypted_config)
            else:
                credentials = {}
        except JSONDecodeError:
            credentials = {}

        # Get credential form schemas from model credential schema or provider credential schema
        credential_schemas = self._get_credential_schema(provider_configuration)

        # Obfuscate credentials before returning them to the caller
        credentials = provider_configuration.obfuscated_credentials(
            credentials=credentials,
            credential_form_schemas=credential_schemas.credential_form_schemas
        )

        return {
            'id': load_balancing_model_config.id,
            'name': load_balancing_model_config.name,
            'credentials': credentials,
            'enabled': load_balancing_model_config.enabled
        }
def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \
-> LoadBalancingModelConfig:
"""
Initialize the inherit configuration.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
# Initialize the inherit configuration
inherit_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
name='__inherit__'
)
db.session.add(inherit_config)
db.session.commit()
return inherit_config
    def update_load_balancing_configs(self, tenant_id: str,
                                      provider: str,
                                      model: str,
                                      model_type: str,
                                      configs: list[dict]) -> None:
        """
        Batch-reconcile the load balancing configurations of a model.

        Entries in ``configs`` with an ``id`` update the matching existing row,
        entries without an ``id`` create a new row, and existing rows whose ids
        do not appear in ``configs`` are deleted.

        :param tenant_id: workspace id
        :param provider: provider name
        :param model: model name
        :param model_type: model type (origin string form)
        :param configs: list of dicts with keys id (optional), name, credentials (optional
                        on update, required on create), enabled
        :return:
        :raises ValueError: on unknown provider, malformed configs, duplicate names,
                            unknown config ids, or use of the reserved '__inherit__' name
        """
        # Get all provider configurations of the current workspace
        provider_configurations = self.provider_manager.get_configurations(tenant_id)

        # Get provider configuration
        provider_configuration = provider_configurations.get(provider)
        if not provider_configuration:
            raise ValueError(f"Provider {provider} does not exist.")

        # Convert model type to ModelType (rebinds `model_type` from str to the enum)
        model_type = ModelType.value_of(model_type)

        if not isinstance(configs, list):
            raise ValueError('Invalid load balancing configs')

        # Existing rows for this model, queried by the canonical provider name.
        current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model
            ).all()

        # id as key, config as value
        current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
        updated_config_ids = set()

        for config in configs:
            if not isinstance(config, dict):
                raise ValueError('Invalid load balancing config')

            config_id = config.get('id')
            name = config.get('name')
            credentials = config.get('credentials')
            enabled = config.get('enabled')

            if not name:
                raise ValueError('Invalid load balancing config name')

            if enabled is None:
                raise ValueError('Invalid load balancing config enabled')

            # is config exists
            if config_id:
                config_id = str(config_id)

                if config_id not in current_load_balancing_configs_dict:
                    raise ValueError('Invalid load balancing config id: {}'.format(config_id))

                updated_config_ids.add(config_id)

                load_balancing_config = current_load_balancing_configs_dict[config_id]

                # check duplicate name (against all other rows of this model)
                for current_load_balancing_config in current_load_balancing_configs:
                    if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
                        raise ValueError('Load balancing config name {} already exists'.format(name))

                if credentials:
                    if not isinstance(credentials, dict):
                        raise ValueError('Invalid load balancing config credentials')

                    # validate custom provider config
                    # (validate=False: credentials are encrypted/normalized but not round-tripped
                    #  against the provider)
                    credentials = self._custom_credentials_validate(
                        tenant_id=tenant_id,
                        provider_configuration=provider_configuration,
                        model_type=model_type,
                        model=model,
                        credentials=credentials,
                        load_balancing_model_config=load_balancing_config,
                        validate=False
                    )

                    # update load balancing config
                    load_balancing_config.encrypted_config = json.dumps(credentials)

                # name/enabled are updated even when credentials are omitted
                load_balancing_config.name = name
                load_balancing_config.enabled = enabled
                load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.commit()

                # drop any cached decrypted credentials for this config
                self._clear_credentials_cache(tenant_id, config_id)
            else:
                # create load balancing config
                if name == '__inherit__':
                    # reserved name, see _init_inherit_config
                    raise ValueError('Invalid load balancing config name')

                # check duplicate name
                for current_load_balancing_config in current_load_balancing_configs:
                    if current_load_balancing_config.name == name:
                        raise ValueError('Load balancing config name {} already exists'.format(name))

                if not credentials:
                    raise ValueError('Invalid load balancing config credentials')

                if not isinstance(credentials, dict):
                    raise ValueError('Invalid load balancing config credentials')

                # validate custom provider config
                credentials = self._custom_credentials_validate(
                    tenant_id=tenant_id,
                    provider_configuration=provider_configuration,
                    model_type=model_type,
                    model=model,
                    credentials=credentials,
                    validate=False
                )

                # create load balancing config (stored under the canonical provider name)
                load_balancing_model_config = LoadBalancingModelConfig(
                    tenant_id=tenant_id,
                    provider_name=provider_configuration.provider.provider,
                    model_type=model_type.to_origin_model_type(),
                    model_name=model,
                    name=name,
                    encrypted_config=json.dumps(credentials)
                )

                db.session.add(load_balancing_model_config)
                db.session.commit()

        # get deleted config ids (existing rows not referenced by the incoming configs)
        deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids
        for config_id in deleted_config_ids:
            db.session.delete(current_load_balancing_configs_dict[config_id])
            db.session.commit()

            self._clear_credentials_cache(tenant_id, config_id)
def validate_load_balancing_credentials(self, tenant_id: str,
                                        provider: str,
                                        model: str,
                                        model_type: str,
                                        credentials: dict,
                                        config_id: Optional[str] = None) -> None:
    """
    Validate credentials intended for a load balancing configuration.

    :param tenant_id: workspace id
    :param provider: provider name
    :param model: model name
    :param model_type: model type
    :param credentials: credentials to validate
    :param config_id: existing load balancing config id, if editing one
    :return: None
    :raises ValueError: if the provider or the referenced config does not exist
    """
    # Resolve the provider configuration for this workspace
    configurations = self.provider_manager.get_configurations(tenant_id)
    configuration = configurations.get(provider)
    if not configuration:
        raise ValueError(f"Provider {provider} does not exist.")

    # Normalize the string model type to its enum form
    model_type_enum = ModelType.value_of(model_type)

    existing_config = None
    if config_id:
        # Look up the referenced load balancing config so hidden secret
        # placeholders can be resolved against its stored credentials.
        existing_config = db.session.query(LoadBalancingModelConfig) \
            .filter(
                LoadBalancingModelConfig.tenant_id == tenant_id,
                LoadBalancingModelConfig.provider_name == provider,
                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
                LoadBalancingModelConfig.id == config_id
            ).first()

        if not existing_config:
            raise ValueError(f"Load balancing config {config_id} does not exist.")

    # Delegate the actual credential validation to the shared validator
    self._custom_credentials_validate(
        tenant_id=tenant_id,
        provider_configuration=configuration,
        model_type=model_type_enum,
        model=model,
        credentials=credentials,
        load_balancing_model_config=existing_config
    )
def _custom_credentials_validate(self, tenant_id: str,
                                 provider_configuration: ProviderConfiguration,
                                 model_type: ModelType,
                                 model: str,
                                 credentials: dict,
                                 load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
                                 validate: bool = True) -> dict:
    """
    Validate custom credentials and encrypt their secret fields.

    Secret inputs that come back as the placeholder '[__HIDDEN__]' are
    restored from the stored (encrypted) config before validation, so the
    caller never has to round-trip real secret values.

    :param tenant_id: workspace id
    :param provider_configuration: provider configuration
    :param model_type: model type
    :param model: model name
    :param credentials: credentials; secret fields may be mutated in place
    :param load_balancing_model_config: existing load balancing model config
        used as the source of original secret values (optional)
    :param validate: whether to run provider/model credential validation
    :return: credentials with every secret variable encrypted
    """
    # Get credential form schemas from model credential schema or provider credential schema
    credential_schemas = self._get_credential_schema(provider_configuration)

    # Names of form fields that hold secrets and therefore need en/decryption
    provider_credential_secret_variables = provider_configuration.extract_secret_variables(
        credential_schemas.credential_form_schemas
    )

    if load_balancing_model_config:
        # Load the previously stored credentials; tolerate empty or corrupt JSON
        try:
            if load_balancing_model_config.encrypted_config:
                original_credentials = json.loads(load_balancing_model_config.encrypted_config)
            else:
                original_credentials = {}
        except JSONDecodeError:
            original_credentials = {}

        # Restore hidden secrets: '[__HIDDEN__]' means "keep the stored value"
        for key, value in credentials.items():
            if key in provider_credential_secret_variables:
                # if send [__HIDDEN__] in secret input, it will be same as original value
                if value == '[__HIDDEN__]' and key in original_credentials:
                    credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])

    if validate:
        # NOTE(review): the validators may return a normalized copy, so from
        # here on `credentials` can be a different dict object — confirm
        # callers rely on the return value, not on in-place mutation.
        if isinstance(credential_schemas, ModelCredentialSchema):
            credentials = model_provider_factory.model_credentials_validate(
                provider=provider_configuration.provider.provider,
                model_type=model_type,
                model=model,
                credentials=credentials
            )
        else:
            credentials = model_provider_factory.provider_credentials_validate(
                provider=provider_configuration.provider.provider,
                credentials=credentials
            )

    # Encrypt every secret variable before the result is persisted
    for key, value in credentials.items():
        if key in provider_credential_secret_variables:
            credentials[key] = encrypter.encrypt_token(tenant_id, value)

    return credentials
def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \
        -> ModelCredentialSchema | ProviderCredentialSchema:
    """
    Resolve which credential form schema applies to this provider.

    The model-level credential schema takes precedence over the
    provider-level one when both are declared.

    :param provider_configuration: provider configuration
    :return: the applicable credential schema
    """
    provider_entity = provider_configuration.provider
    if provider_entity.model_credential_schema:
        return provider_entity.model_credential_schema
    return provider_entity.provider_credential_schema
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
    """
    Evict the cached decrypted credentials of a load balancing config.

    :param tenant_id: workspace id
    :param config_id: load balancing config id
    :return: None
    """
    cache = ProviderCredentialsCache(
        tenant_id=tenant_id,
        identity_id=config_id,
        cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
    )
    cache.delete()

View File

@ -6,7 +6,7 @@ from typing import Optional, cast
import requests import requests
from flask import current_app from flask import current_app
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -16,7 +16,6 @@ from services.entities.model_provider_entities import (
CustomConfigurationResponse, CustomConfigurationResponse,
CustomConfigurationStatus, CustomConfigurationStatus,
DefaultModelResponse, DefaultModelResponse,
ModelResponse,
ModelWithProviderEntityResponse, ModelWithProviderEntityResponse,
ProviderResponse, ProviderResponse,
ProviderWithModelsResponse, ProviderWithModelsResponse,
@ -303,6 +302,9 @@ class ModelProviderService:
if model.deprecated: if model.deprecated:
continue continue
if model.status != ModelStatus.ACTIVE:
continue
provider_models[model.provider.provider].append(model) provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list # convert to ProviderWithModelsResponse list
@ -313,24 +315,22 @@ class ModelProviderService:
first_model = models[0] first_model = models[0]
has_active_models = any([model.status == ModelStatus.ACTIVE for model in models])
providers_with_models.append( providers_with_models.append(
ProviderWithModelsResponse( ProviderWithModelsResponse(
provider=provider, provider=provider,
label=first_model.provider.label, label=first_model.provider.label,
icon_small=first_model.provider.icon_small, icon_small=first_model.provider.icon_small,
icon_large=first_model.provider.icon_large, icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE status=CustomConfigurationStatus.ACTIVE,
if has_active_models else CustomConfigurationStatus.NO_CONFIGURE, models=[ProviderModelWithStatusEntity(
models=[ModelResponse(
model=model.model, model=model.model,
label=model.label, label=model.label,
model_type=model.model_type, model_type=model.model_type,
features=model.features, features=model.features,
fetch_from=model.fetch_from, fetch_from=model.fetch_from,
model_properties=model.model_properties, model_properties=model.model_properties,
status=model.status status=model.status,
load_balancing_enabled=model.load_balancing_enabled
) for model in models] ) for model in models]
) )
) )
@ -486,6 +486,54 @@ class ModelProviderService:
# Switch preferred provider type # Switch preferred provider type
provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
"""
enable model.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Enable model
provider_configuration.enable_model(
model=model,
model_type=ModelType.value_of(model_type)
)
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
"""
disable model.
:param tenant_id: workspace id
:param provider: provider name
:param model: model name
:param model_type: model type
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Enable model
provider_configuration.disable_model(
model=model,
model_type=ModelType.value_of(model_type)
)
def free_quota_submit(self, tenant_id: str, provider: str): def free_quota_submit(self, tenant_id: str, provider: str):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")

View File

@ -68,7 +68,7 @@ class WorkflowService:
account: Account) -> Workflow: account: Account) -> Workflow:
""" """
Sync draft workflow Sync draft workflow
@throws WorkflowHashNotEqualError :raises WorkflowHashNotEqualError
""" """
# fetch draft workflow by app_model # fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model) workflow = self.get_draft_workflow(app_model=app_model)

View File

@ -2,7 +2,6 @@ import datetime
import logging import logging
import time import time
import uuid import uuid
from typing import cast
import click import click
from celery import shared_task from celery import shared_task
@ -11,7 +10,6 @@ from sqlalchemy import func
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs import helper from libs import helper
@ -59,16 +57,12 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s
model=dataset.embedding_model model=dataset.embedding_model
) )
model_type_instance = embedding_model.model_type_instance
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
for segment in content: for segment in content:
content = segment['content'] content = segment['content']
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content) segment_hash = helper.generate_text_hash(content)
# calc embedding use tokens # calc embedding use tokens
tokens = model_type_instance.get_num_tokens( tokens = embedding_model.get_text_embedding_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
texts=[content] texts=[content]
) if embedding_model else 0 ) if embedding_model else 0
max_position = db.session.query(func.max(DocumentSegment.position)).filter( max_position = db.session.query(func.max(DocumentSegment.position)).filter(

View File

@ -1,6 +1,6 @@
import os import os
from core.utils.module_import_helper import import_module_from_source, load_single_subclass_from_source from core.helper.module_import_helper import import_module_from_source, load_single_subclass_from_source
from tests.integration_tests.utils.parent_class import ParentClass from tests.integration_tests.utils.parent_class import ParentClass

View File

@ -92,7 +92,8 @@ def test_execute_llm(setup_openai_mock):
provider=CustomProviderConfiguration( provider=CustomProviderConfiguration(
credentials=credentials credentials=credentials
) )
) ),
model_settings=[]
), ),
provider_instance=provider_instance, provider_instance=provider_instance,
model_type_instance=model_type_instance model_type_instance=model_type_instance
@ -206,10 +207,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
provider=CustomProviderConfiguration( provider=CustomProviderConfiguration(
credentials=credentials credentials=credentials
) )
) ),
model_settings=[]
), ),
provider_instance=provider_instance, provider_instance=provider_instance,
model_type_instance=model_type_instance model_type_instance=model_type_instance,
) )
model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo') model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')

View File

@ -42,7 +42,8 @@ def get_mocked_fetch_model_config(
provider=CustomProviderConfiguration( provider=CustomProviderConfiguration(
credentials=credentials credentials=credentials
) )
) ),
model_settings=[]
), ),
provider_instance=provider_instance, provider_instance=provider_instance,
model_type_instance=model_type_instance model_type_instance=model_type_instance

View File

@ -1,9 +1,10 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from core.app.app_config.entities import ModelConfigEntity from core.app.app_config.entities import ModelConfigEntity
from core.entities.provider_configuration import ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.model_runtime.entities.message_entities import UserPromptMessage from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform from core.prompt.prompt_transform import PromptTransform
@ -22,8 +23,16 @@ def test__calculate_rest_token():
large_language_model_mock = MagicMock(spec=LargeLanguageModel) large_language_model_mock = MagicMock(spec=LargeLanguageModel)
large_language_model_mock.get_num_tokens.return_value = 6 large_language_model_mock.get_num_tokens.return_value = 6
provider_mock = MagicMock(spec=ProviderEntity)
provider_mock.provider = 'openai'
provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
provider_configuration_mock.provider = provider_mock
provider_configuration_mock.model_settings = None
provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
provider_model_bundle_mock.model_type_instance = large_language_model_mock provider_model_bundle_mock.model_type_instance = large_language_model_mock
provider_model_bundle_mock.configuration = provider_configuration_mock
model_config_mock = MagicMock(spec=ModelConfigEntity) model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.model = 'gpt-4' model_config_mock.model = 'gpt-4'

View File

@ -0,0 +1,77 @@
from unittest.mock import MagicMock
import pytest
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
from core.model_runtime.entities.model_entities import ModelType
@pytest.fixture
def lb_model_manager():
    """LBModelManager wired with three configs; cooldown is mocked so that
    only the first ('__inherit__') config reports as cooling down."""
    configs = [
        ModelLoadBalancingConfiguration(
            id='id1',
            name='__inherit__',
            credentials={}
        ),
        ModelLoadBalancingConfiguration(
            id='id2',
            name='first',
            credentials={"openai_api_key": "fake_key"}
        ),
        ModelLoadBalancingConfiguration(
            id='id3',
            name='second',
            credentials={"openai_api_key": "fake_key"}
        ),
    ]

    manager = LBModelManager(
        tenant_id='tenant_id',
        provider='openai',
        model_type=ModelType.LLM,
        model='gpt-4',
        load_balancing_configs=configs,
        managed_credentials={"openai_api_key": "fake_key"}
    )

    # Cooldown writes become no-ops; only config 'id1' counts as in cooldown.
    manager.cooldown = MagicMock(return_value=None)
    manager.in_cooldown = MagicMock(
        side_effect=lambda config: config.id == 'id1'
    )

    return manager
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
    """fetch_next skips configs in cooldown and round-robins over the rest."""
    configs = lb_model_manager._load_balancing_configs
    assert len(configs) == 3
    first, second, third = configs

    # Only the inherit config is cooling down (per the fixture's mock)
    assert lb_model_manager.in_cooldown(first) is True
    assert lb_model_manager.in_cooldown(second) is False
    assert lb_model_manager.in_cooldown(third) is False

    counter = 0

    def fake_incr(key):
        # Simulate redis INCR: monotonically increasing round-robin index
        nonlocal counter
        counter += 1
        return counter

    mocker.patch('redis.Redis.incr', side_effect=fake_incr)
    mocker.patch('redis.Redis.set', return_value=None)
    mocker.patch('redis.Redis.expire', return_value=None)

    # The round-robin should land on the two non-cooldown configs in turn
    assert lb_model_manager.fetch_next() == second
    assert lb_model_manager.fetch_next() == third

View File

@ -0,0 +1,183 @@
from core.entities.provider_entities import ModelSettings
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.provider_manager import ProviderManager
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
def test__to_model_settings(mocker):
    """_to_model_settings maps a LB-enabled model setting plus two LB rows
    into one ModelSettings with both load balancing configs attached."""
    # Locate the OpenAI provider entity among all registered providers
    provider_entity = None
    for candidate in model_provider_factory.get_providers():
        if candidate.provider == 'openai':
            provider_entity = candidate

    # One enabled model setting with load balancing switched on
    provider_model_settings = [
        ProviderModelSetting(
            id='id',
            tenant_id='tenant_id',
            provider_name='openai',
            model_name='gpt-4',
            model_type='text-generation',
            enabled=True,
            load_balancing_enabled=True
        )
    ]

    # Two LB rows: the inherit placeholder plus one explicit credential set
    load_balancing_model_configs = [
        LoadBalancingModelConfig(
            id='id1',
            tenant_id='tenant_id',
            provider_name='openai',
            model_name='gpt-4',
            model_type='text-generation',
            name='__inherit__',
            encrypted_config=None,
            enabled=True
        ),
        LoadBalancingModelConfig(
            id='id2',
            tenant_id='tenant_id',
            provider_name='openai',
            model_name='gpt-4',
            model_type='text-generation',
            name='first',
            encrypted_config='{"openai_api_key": "fake_key"}',
            enabled=True
        ),
    ]

    # Avoid touching real credential storage during the conversion
    mocker.patch(
        'core.helper.model_provider_cache.ProviderCredentialsCache.get',
        return_value={"openai_api_key": "fake_key"}
    )

    result = ProviderManager()._to_model_settings(
        provider_entity,
        provider_model_settings,
        load_balancing_model_configs
    )

    assert len(result) == 1
    setting = result[0]
    assert isinstance(setting, ModelSettings)
    assert setting.model == 'gpt-4'
    assert setting.model_type == ModelType.LLM
    assert setting.enabled is True
    assert len(setting.load_balancing_configs) == 2
    assert [config.name for config in setting.load_balancing_configs] == ['__inherit__', 'first']
def test__to_model_settings_only_one_lb(mocker):
    """With only the '__inherit__' placeholder present, the resulting
    ModelSettings exposes no load balancing configs at all."""
    # Locate the OpenAI provider entity among all registered providers
    provider_entity = None
    for candidate in model_provider_factory.get_providers():
        if candidate.provider == 'openai':
            provider_entity = candidate

    provider_model_settings = [
        ProviderModelSetting(
            id='id',
            tenant_id='tenant_id',
            provider_name='openai',
            model_name='gpt-4',
            model_type='text-generation',
            enabled=True,
            load_balancing_enabled=True
        )
    ]

    # Only the inherit placeholder exists — no explicit credential sets
    load_balancing_model_configs = [
        LoadBalancingModelConfig(
            id='id1',
            tenant_id='tenant_id',
            provider_name='openai',
            model_name='gpt-4',
            model_type='text-generation',
            name='__inherit__',
            encrypted_config=None,
            enabled=True
        )
    ]

    # Avoid touching real credential storage during the conversion
    mocker.patch(
        'core.helper.model_provider_cache.ProviderCredentialsCache.get',
        return_value={"openai_api_key": "fake_key"}
    )

    result = ProviderManager()._to_model_settings(
        provider_entity,
        provider_model_settings,
        load_balancing_model_configs
    )

    assert len(result) == 1
    setting = result[0]
    assert isinstance(setting, ModelSettings)
    assert setting.model == 'gpt-4'
    assert setting.model_type == ModelType.LLM
    assert setting.enabled is True
    assert len(setting.load_balancing_configs) == 0
def test__to_model_settings_lb_disabled(mocker):
    """When load_balancing_enabled is False, LB rows are ignored and the
    resulting ModelSettings has no load balancing configs."""
    # Locate the OpenAI provider entity among all registered providers
    provider_entity = None
    for candidate in model_provider_factory.get_providers():
        if candidate.provider == 'openai':
            provider_entity = candidate

    # Model setting exists but load balancing is explicitly disabled
    provider_model_settings = [
        ProviderModelSetting(
            id='id',
            tenant_id='tenant_id',
            provider_name='openai',
            model_name='gpt-4',
            model_type='text-generation',
            enabled=True,
            load_balancing_enabled=False
        )
    ]

    # Two LB rows are present, but they must be ignored while LB is off
    load_balancing_model_configs = [
        LoadBalancingModelConfig(
            id='id1',
            tenant_id='tenant_id',
            provider_name='openai',
            model_name='gpt-4',
            model_type='text-generation',
            name='__inherit__',
            encrypted_config=None,
            enabled=True
        ),
        LoadBalancingModelConfig(
            id='id2',
            tenant_id='tenant_id',
            provider_name='openai',
            model_name='gpt-4',
            model_type='text-generation',
            name='first',
            encrypted_config='{"openai_api_key": "fake_key"}',
            enabled=True
        ),
    ]

    # Avoid touching real credential storage during the conversion
    mocker.patch(
        'core.helper.model_provider_cache.ProviderCredentialsCache.get',
        return_value={"openai_api_key": "fake_key"}
    )

    result = ProviderManager()._to_model_settings(
        provider_entity,
        provider_model_settings,
        load_balancing_model_configs
    )

    assert len(result) == 1
    setting = result[0]
    assert isinstance(setting, ModelSettings)
    assert setting.model == 'gpt-4'
    assert setting.model_type == ModelType.LLM
    assert setting.enabled is True
    assert len(setting.load_balancing_configs) == 0

View File

@ -2,7 +2,7 @@ from textwrap import dedent
import pytest import pytest
from core.utils.position_helper import get_position_map from core.helper.position_helper import get_position_map
@pytest.fixture @pytest.fixture