diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 6ed065d925..e032b0fa4a 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -7,6 +7,7 @@ from json import JSONDecodeError from typing import Optional from pydantic import BaseModel, ConfigDict +from sqlalchemy import or_ from constants import HIDDEN_VALUE from core.entities import DEFAULT_PLUGIN_ID @@ -28,6 +29,7 @@ from core.model_runtime.entities.provider_entities import ( ) from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.entities.plugin import ModelProviderID from extensions.ext_database import db from models.provider import ( LoadBalancingModelConfig, @@ -190,8 +192,11 @@ class ProviderConfiguration(BaseModel): db.session.query(Provider) .filter( Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, Provider.provider_type == ProviderType.CUSTOM.value, + or_( + Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name, + Provider.provider_name == self.provider.provider, + ), ) .first() ) @@ -279,7 +284,10 @@ class ProviderConfiguration(BaseModel): db.session.query(Provider) .filter( Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, + or_( + Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name, + Provider.provider_name == self.provider.provider, + ), Provider.provider_type == ProviderType.CUSTOM.value, ) .first()