diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 4b8214019c..b3affc91a6 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -7,7 +7,6 @@ from json import JSONDecodeError from typing import Optional from pydantic import BaseModel, ConfigDict, Field -from sqlalchemy import or_ from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity @@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel): else [], ) + def _get_custom_provider_credentials(self) -> Provider | None: + """ + Get custom provider credentials. + """ + # get provider + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + + provider_record = ( + db.session.query(Provider) + .filter( + Provider.tenant_id == self.tenant_id, + Provider.provider_type == ProviderType.CUSTOM.value, + Provider.provider_name.in_(provider_names), + ) + .first() + ) + + return provider_record + def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]: """ Validate custom credentials. :param credentials: provider credentials :return: """ - # get provider - model_provider_id = ModelProviderID(self.provider.provider) - if model_provider_id.is_langgenius(): - provider_record = ( - db.session.query(Provider) - .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_type == ProviderType.CUSTOM.value, - or_( - Provider.provider_name == model_provider_id.provider_name, - Provider.provider_name == self.provider.provider, - ), - ) - .first() - ) - else: - provider_record = ( - db.session.query(Provider) - .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_type == ProviderType.CUSTOM.value, - Provider.provider_name == self.provider.provider, - ) - .first() - ) + provider_record = self._get_custom_provider_credentials() # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( @@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider - provider_record = ( - db.session.query(Provider) - .filter( - Provider.tenant_id == self.tenant_id, - or_( - Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name, - Provider.provider_name == self.provider.provider, - ), - Provider.provider_type == ProviderType.CUSTOM.value, - ) - .first() - ) + provider_record = self._get_custom_provider_credentials() # delete provider if provider_record: @@ -349,6 +335,33 @@ class ProviderConfiguration(BaseModel): return None + def _get_custom_model_credentials( + self, + model_type: ModelType, + model: str, + ) -> ProviderModel | None: + """ + Get custom model credentials. + """ + # get provider model + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + + provider_model_record = ( + db.session.query(ProviderModel) + .filter( + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name.in_(provider_names), + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) + + return provider_model_record + def custom_model_credentials_validate( self, model_type: ModelType, model: str, credentials: dict ) -> tuple[ProviderModel | None, dict]: @@ -361,16 +374,7 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider model - provider_model_record = ( - db.session.query(ProviderModel) - .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type(), - ) - .first() - ) + provider_model_record = self._get_custom_model_credentials(model_type, model) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( @@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider model - provider_model_record = ( - db.session.query(ProviderModel) - .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type(), - ) - .first() - ) + provider_model_record = self._get_custom_model_credentials(model_type, model) # delete provider model if provider_model_record: @@ -475,6 +470,26 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache.delete() + def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: + """ + Get provider model setting. + """ + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + + return ( + db.session.query(ProviderModelSetting) + .filter( + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name.in_(provider_names), + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) + def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ Enable model. @@ -482,16 +497,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = ( - db.session.query(ProviderModelSetting) - .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() - ) + model_setting = self._get_provider_model_setting(model_type, model) if model_setting: model_setting.enabled = True @@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = ( - db.session.query(ProviderModelSetting) - .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() - ) + model_setting = self._get_provider_model_setting(model_type, model) if model_setting: model_setting.enabled = False @@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ + return self._get_provider_model_setting(model_type, model) + + def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]: + """ + Get load balancing config. + """ + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + return ( - db.session.query(ProviderModelSetting) + db.session.query(LoadBalancingModelConfig) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name.in_(provider_names), + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, ) .first() ) @@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + load_balancing_config_count = ( db.session.query(LoadBalancingModelConfig) .filter( LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) @@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel): if load_balancing_config_count <= 1: raise ValueError("Model load balancing configuration must be more than 1.") - model_setting = ( - db.session.query(ProviderModelSetting) - .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model, - ) - .first() - ) + model_setting = self._get_provider_model_setting(model_type, model) if model_setting: model_setting.load_balancing_enabled = True @@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + model_setting = ( db.session.query(ProviderModelSetting) .filter( ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_name == model, ) @@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel): return # get preferred provider + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + preferred_model_provider = ( db.session.query(TenantPreferredModelProvider) .filter( TenantPreferredModelProvider.tenant_id == self.tenant_id, - TenantPreferredModelProvider.provider_name == self.provider.provider, + TenantPreferredModelProvider.provider_name.in_(provider_names), ) .first() )