refactor: optimize provider configuration queries with provider name … (#15491)
This commit is contained in:
parent
b730f243dc
commit
a6bc642721
@ -7,7 +7,6 @@ from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import or_
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def _get_custom_provider_credentials(self) -> Provider | None:
|
||||
"""
|
||||
Get custom provider credentials.
|
||||
"""
|
||||
# get provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name.in_(provider_names),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return provider_record
|
||||
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
or_(
|
||||
Provider.provider_name == model_provider_id.provider_name,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_record = self._get_custom_provider_credentials()
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
or_(
|
||||
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
),
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_record = self._get_custom_provider_credentials()
|
||||
|
||||
# delete provider
|
||||
if provider_record:
|
||||
@ -349,6 +335,33 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return None
|
||||
|
||||
def _get_custom_model_credentials(
|
||||
self,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
) -> ProviderModel | None:
|
||||
"""
|
||||
Get custom model credentials.
|
||||
"""
|
||||
# get provider model
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
provider_model_record = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name.in_(provider_names),
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return provider_model_record
|
||||
|
||||
def custom_model_credentials_validate(
|
||||
self, model_type: ModelType, model: str, credentials: dict
|
||||
) -> tuple[ProviderModel | None, dict]:
|
||||
@ -361,16 +374,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model_record = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_model_record = self._get_custom_model_credentials(model_type, model)
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model_record = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
provider_model_record = self._get_custom_model_credentials(model_type, model)
|
||||
|
||||
# delete provider model
|
||||
if provider_model_record:
|
||||
@ -475,6 +470,26 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
|
||||
"""
|
||||
Get provider model setting.
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
return (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
||||
"""
|
||||
Enable model.
|
||||
@ -482,16 +497,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
model_setting = self._get_provider_model_setting(model_type, model)
|
||||
|
||||
if model_setting:
|
||||
model_setting.enabled = True
|
||||
@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
model_setting = self._get_provider_model_setting(model_type, model)
|
||||
|
||||
if model_setting:
|
||||
model_setting.enabled = False
|
||||
@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
return self._get_provider_model_setting(model_type, model)
|
||||
|
||||
def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]:
|
||||
"""
|
||||
Get load balancing config.
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
return (
|
||||
db.session.query(ProviderModelSetting)
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
load_balancing_config_count = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel):
|
||||
if load_balancing_config_count <= 1:
|
||||
raise ValueError("Model load balancing configuration must be more than 1.")
|
||||
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
model_setting = self._get_provider_model_setting(model_type, model)
|
||||
|
||||
if model_setting:
|
||||
model_setting.load_balancing_enabled = True
|
||||
@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
model_setting = (
|
||||
db.session.query(ProviderModelSetting)
|
||||
.filter(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name == self.provider.provider,
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel):
|
||||
return
|
||||
|
||||
# get preferred provider
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
|
||||
preferred_model_provider = (
|
||||
db.session.query(TenantPreferredModelProvider)
|
||||
.filter(
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name == self.provider.provider,
|
||||
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user