refactor: optimize provider configuration queries with provider name … (#15491)
This commit is contained in:
parent
b730f243dc
commit
a6bc642721
@ -7,7 +7,6 @@ from json import JSONDecodeError
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from sqlalchemy import or_
|
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||||
@ -180,37 +179,35 @@ class ProviderConfiguration(BaseModel):
|
|||||||
else [],
|
else [],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_custom_provider_credentials(self) -> Provider | None:
|
||||||
|
"""
|
||||||
|
Get custom provider credentials.
|
||||||
|
"""
|
||||||
|
# get provider
|
||||||
|
model_provider_id = ModelProviderID(self.provider.provider)
|
||||||
|
provider_names = [self.provider.provider]
|
||||||
|
if model_provider_id.is_langgenius():
|
||||||
|
provider_names.append(model_provider_id.provider_name)
|
||||||
|
|
||||||
|
provider_record = (
|
||||||
|
db.session.query(Provider)
|
||||||
|
.filter(
|
||||||
|
Provider.tenant_id == self.tenant_id,
|
||||||
|
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||||
|
Provider.provider_name.in_(provider_names),
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
return provider_record
|
||||||
|
|
||||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
||||||
"""
|
"""
|
||||||
Validate custom credentials.
|
Validate custom credentials.
|
||||||
:param credentials: provider credentials
|
:param credentials: provider credentials
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# get provider
|
provider_record = self._get_custom_provider_credentials()
|
||||||
model_provider_id = ModelProviderID(self.provider.provider)
|
|
||||||
if model_provider_id.is_langgenius():
|
|
||||||
provider_record = (
|
|
||||||
db.session.query(Provider)
|
|
||||||
.filter(
|
|
||||||
Provider.tenant_id == self.tenant_id,
|
|
||||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
|
||||||
or_(
|
|
||||||
Provider.provider_name == model_provider_id.provider_name,
|
|
||||||
Provider.provider_name == self.provider.provider,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
provider_record = (
|
|
||||||
db.session.query(Provider)
|
|
||||||
.filter(
|
|
||||||
Provider.tenant_id == self.tenant_id,
|
|
||||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
|
||||||
Provider.provider_name == self.provider.provider,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
provider_credential_secret_variables = self.extract_secret_variables(
|
provider_credential_secret_variables = self.extract_secret_variables(
|
||||||
@ -291,18 +288,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# get provider
|
# get provider
|
||||||
provider_record = (
|
provider_record = self._get_custom_provider_credentials()
|
||||||
db.session.query(Provider)
|
|
||||||
.filter(
|
|
||||||
Provider.tenant_id == self.tenant_id,
|
|
||||||
or_(
|
|
||||||
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
|
|
||||||
Provider.provider_name == self.provider.provider,
|
|
||||||
),
|
|
||||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
# delete provider
|
# delete provider
|
||||||
if provider_record:
|
if provider_record:
|
||||||
@ -349,6 +335,33 @@ class ProviderConfiguration(BaseModel):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _get_custom_model_credentials(
|
||||||
|
self,
|
||||||
|
model_type: ModelType,
|
||||||
|
model: str,
|
||||||
|
) -> ProviderModel | None:
|
||||||
|
"""
|
||||||
|
Get custom model credentials.
|
||||||
|
"""
|
||||||
|
# get provider model
|
||||||
|
model_provider_id = ModelProviderID(self.provider.provider)
|
||||||
|
provider_names = [self.provider.provider]
|
||||||
|
if model_provider_id.is_langgenius():
|
||||||
|
provider_names.append(model_provider_id.provider_name)
|
||||||
|
|
||||||
|
provider_model_record = (
|
||||||
|
db.session.query(ProviderModel)
|
||||||
|
.filter(
|
||||||
|
ProviderModel.tenant_id == self.tenant_id,
|
||||||
|
ProviderModel.provider_name.in_(provider_names),
|
||||||
|
ProviderModel.model_name == model,
|
||||||
|
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
return provider_model_record
|
||||||
|
|
||||||
def custom_model_credentials_validate(
|
def custom_model_credentials_validate(
|
||||||
self, model_type: ModelType, model: str, credentials: dict
|
self, model_type: ModelType, model: str, credentials: dict
|
||||||
) -> tuple[ProviderModel | None, dict]:
|
) -> tuple[ProviderModel | None, dict]:
|
||||||
@ -361,16 +374,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# get provider model
|
# get provider model
|
||||||
provider_model_record = (
|
provider_model_record = self._get_custom_model_credentials(model_type, model)
|
||||||
db.session.query(ProviderModel)
|
|
||||||
.filter(
|
|
||||||
ProviderModel.tenant_id == self.tenant_id,
|
|
||||||
ProviderModel.provider_name == self.provider.provider,
|
|
||||||
ProviderModel.model_name == model,
|
|
||||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
provider_credential_secret_variables = self.extract_secret_variables(
|
provider_credential_secret_variables = self.extract_secret_variables(
|
||||||
@ -451,16 +455,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# get provider model
|
# get provider model
|
||||||
provider_model_record = (
|
provider_model_record = self._get_custom_model_credentials(model_type, model)
|
||||||
db.session.query(ProviderModel)
|
|
||||||
.filter(
|
|
||||||
ProviderModel.tenant_id == self.tenant_id,
|
|
||||||
ProviderModel.provider_name == self.provider.provider,
|
|
||||||
ProviderModel.model_name == model,
|
|
||||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
# delete provider model
|
# delete provider model
|
||||||
if provider_model_record:
|
if provider_model_record:
|
||||||
@ -475,6 +470,26 @@ class ProviderConfiguration(BaseModel):
|
|||||||
|
|
||||||
provider_model_credentials_cache.delete()
|
provider_model_credentials_cache.delete()
|
||||||
|
|
||||||
|
def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
|
||||||
|
"""
|
||||||
|
Get provider model setting.
|
||||||
|
"""
|
||||||
|
model_provider_id = ModelProviderID(self.provider.provider)
|
||||||
|
provider_names = [self.provider.provider]
|
||||||
|
if model_provider_id.is_langgenius():
|
||||||
|
provider_names.append(model_provider_id.provider_name)
|
||||||
|
|
||||||
|
return (
|
||||||
|
db.session.query(ProviderModelSetting)
|
||||||
|
.filter(
|
||||||
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||||
|
ProviderModelSetting.provider_name.in_(provider_names),
|
||||||
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||||
|
ProviderModelSetting.model_name == model,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
|
||||||
"""
|
"""
|
||||||
Enable model.
|
Enable model.
|
||||||
@ -482,16 +497,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:param model: model name
|
:param model: model name
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
model_setting = (
|
model_setting = self._get_provider_model_setting(model_type, model)
|
||||||
db.session.query(ProviderModelSetting)
|
|
||||||
.filter(
|
|
||||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
||||||
ProviderModelSetting.provider_name == self.provider.provider,
|
|
||||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
||||||
ProviderModelSetting.model_name == model,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_setting:
|
if model_setting:
|
||||||
model_setting.enabled = True
|
model_setting.enabled = True
|
||||||
@ -516,16 +522,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:param model: model name
|
:param model: model name
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
model_setting = (
|
model_setting = self._get_provider_model_setting(model_type, model)
|
||||||
db.session.query(ProviderModelSetting)
|
|
||||||
.filter(
|
|
||||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
||||||
ProviderModelSetting.provider_name == self.provider.provider,
|
|
||||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
||||||
ProviderModelSetting.model_name == model,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_setting:
|
if model_setting:
|
||||||
model_setting.enabled = False
|
model_setting.enabled = False
|
||||||
@ -550,13 +547,24 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:param model: model name
|
:param model: model name
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
return self._get_provider_model_setting(model_type, model)
|
||||||
|
|
||||||
|
def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]:
|
||||||
|
"""
|
||||||
|
Get load balancing config.
|
||||||
|
"""
|
||||||
|
model_provider_id = ModelProviderID(self.provider.provider)
|
||||||
|
provider_names = [self.provider.provider]
|
||||||
|
if model_provider_id.is_langgenius():
|
||||||
|
provider_names.append(model_provider_id.provider_name)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
db.session.query(ProviderModelSetting)
|
db.session.query(LoadBalancingModelConfig)
|
||||||
.filter(
|
.filter(
|
||||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||||
ProviderModelSetting.provider_name == self.provider.provider,
|
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||||
ProviderModelSetting.model_name == model,
|
LoadBalancingModelConfig.model_name == model,
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
@ -568,11 +576,16 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:param model: model name
|
:param model: model name
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
model_provider_id = ModelProviderID(self.provider.provider)
|
||||||
|
provider_names = [self.provider.provider]
|
||||||
|
if model_provider_id.is_langgenius():
|
||||||
|
provider_names.append(model_provider_id.provider_name)
|
||||||
|
|
||||||
load_balancing_config_count = (
|
load_balancing_config_count = (
|
||||||
db.session.query(LoadBalancingModelConfig)
|
db.session.query(LoadBalancingModelConfig)
|
||||||
.filter(
|
.filter(
|
||||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||||
LoadBalancingModelConfig.model_name == model,
|
LoadBalancingModelConfig.model_name == model,
|
||||||
)
|
)
|
||||||
@ -582,16 +595,7 @@ class ProviderConfiguration(BaseModel):
|
|||||||
if load_balancing_config_count <= 1:
|
if load_balancing_config_count <= 1:
|
||||||
raise ValueError("Model load balancing configuration must be more than 1.")
|
raise ValueError("Model load balancing configuration must be more than 1.")
|
||||||
|
|
||||||
model_setting = (
|
model_setting = self._get_provider_model_setting(model_type, model)
|
||||||
db.session.query(ProviderModelSetting)
|
|
||||||
.filter(
|
|
||||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
||||||
ProviderModelSetting.provider_name == self.provider.provider,
|
|
||||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
||||||
ProviderModelSetting.model_name == model,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_setting:
|
if model_setting:
|
||||||
model_setting.load_balancing_enabled = True
|
model_setting.load_balancing_enabled = True
|
||||||
@ -616,11 +620,16 @@ class ProviderConfiguration(BaseModel):
|
|||||||
:param model: model name
|
:param model: model name
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
model_provider_id = ModelProviderID(self.provider.provider)
|
||||||
|
provider_names = [self.provider.provider]
|
||||||
|
if model_provider_id.is_langgenius():
|
||||||
|
provider_names.append(model_provider_id.provider_name)
|
||||||
|
|
||||||
model_setting = (
|
model_setting = (
|
||||||
db.session.query(ProviderModelSetting)
|
db.session.query(ProviderModelSetting)
|
||||||
.filter(
|
.filter(
|
||||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||||
ProviderModelSetting.provider_name == self.provider.provider,
|
ProviderModelSetting.provider_name.in_(provider_names),
|
||||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||||
ProviderModelSetting.model_name == model,
|
ProviderModelSetting.model_name == model,
|
||||||
)
|
)
|
||||||
@ -677,11 +686,16 @@ class ProviderConfiguration(BaseModel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# get preferred provider
|
# get preferred provider
|
||||||
|
model_provider_id = ModelProviderID(self.provider.provider)
|
||||||
|
provider_names = [self.provider.provider]
|
||||||
|
if model_provider_id.is_langgenius():
|
||||||
|
provider_names.append(model_provider_id.provider_name)
|
||||||
|
|
||||||
preferred_model_provider = (
|
preferred_model_provider = (
|
||||||
db.session.query(TenantPreferredModelProvider)
|
db.session.query(TenantPreferredModelProvider)
|
||||||
.filter(
|
.filter(
|
||||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||||
TenantPreferredModelProvider.provider_name == self.provider.provider,
|
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user