fix: GPUStack LLM and text_embedding model URL path is wrong after editing

This commit is contained in:
alexcodelf 2025-01-06 20:07:09 +08:00 committed by crazywoola
parent 409cc7d9b0
commit 5ba875f96b
2 changed files with 19 additions and 14 deletions

View File

@ -1,7 +1,5 @@
from collections.abc import Generator
from yarl import URL
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
@ -24,9 +22,10 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(
model,
credentials,
compatible_credentials,
prompt_messages,
model_parameters,
tools,
@ -36,10 +35,15 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
)
# Validate credentials for the GPUStack language model by delegating to the
# parent class's validate_credentials.
# NOTE(review): this span comes from a rendered diff (hunk "@ -36,10 +35,15 @")
# whose +/- markers and indentation were stripped. The first two statements look
# like the removed lines and the last two like their replacements - confirm
# against the real file before relying on the statement order shown here.
def validate_credentials(self, model: str, credentials: dict) -> None:
# Passes the caller's own dict to _add_custom_parameters, then validates with it.
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
# Builds a separate dict via _get_compatible_credentials and validates with that.
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)
def _get_compatible_credentials(self, credentials: dict) -> dict:
    """Return a copy of ``credentials`` whose endpoint ends in ``/v1-openai``.

    The caller's dict is left untouched; only the returned shallow copy
    carries the rewritten ``endpoint_url``.

    Args:
        credentials: Provider credentials; must contain ``"endpoint_url"``.

    Returns:
        A shallow copy of ``credentials`` whose ``"endpoint_url"`` ends with
        exactly one ``/v1-openai`` segment and has no trailing slash.

    Raises:
        KeyError: If ``"endpoint_url"`` is missing from ``credentials``.
    """
    suffix = "/v1-openai"
    compatible = credentials.copy()
    # Surrounding whitespace (e.g. from a pasted URL) would otherwise hide an
    # existing suffix and produce ".../v1-openai /v1-openai".
    base_url = compatible["endpoint_url"].strip().rstrip("/")
    # Remove every trailing occurrence, not just one, so an endpoint that was
    # already saved with a repeated suffix is normalised as well.
    while base_url.endswith(suffix):
        base_url = base_url[: -len(suffix)].rstrip("/")
    compatible["endpoint_url"] = f"{base_url}{suffix}"
    return compatible
# Static helper that rewrites entries of the given credentials dict in place
# and returns nothing.
# NOTE(review): this span comes from a rendered diff with +/- markers stripped,
# and the yarl import appears in a "-1,7 +1,5" hunk above - some statements here
# may be removed lines; confirm against the real file.
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
# Wraps the endpoint in URL, joins "v1-openai" with "/", and stores the string back.
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
# Sets the "mode" entry to "chat".
credentials["mode"] = "chat"

View File

@ -1,7 +1,5 @@
from typing import Optional
from yarl import URL
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import (
TextEmbeddingResult,
@ -24,12 +22,15 @@ class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
return super()._invoke(model, credentials, texts, user, input_type)
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(model, compatible_credentials, texts, user, input_type)
# Validate credentials for the GPUStack text-embedding model by delegating to
# the parent class's validate_credentials.
# NOTE(review): this span comes from a rendered diff (hunk "@ -24,12 +22,15 @")
# whose +/- markers and indentation were stripped. The first two statements look
# like the removed lines and the last two like their replacements - confirm
# against the real file before relying on the statement order shown here.
def validate_credentials(self, model: str, credentials: dict) -> None:
# Passes the caller's own dict to _add_custom_parameters, then validates with it.
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
# Builds a separate dict via _get_compatible_credentials and validates with that.
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)
# Static helper that rewrites the "endpoint_url" entry of the given credentials
# dict in place and returns nothing.
# NOTE(review): this span comes from a rendered diff with +/- markers stripped,
# and the yarl import appears in a "-1,7 +1,5" hunk above - this helper may be a
# removed definition; confirm against the real file.
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
# Wraps the endpoint in URL, joins "v1-openai" with "/", and stores the string back.
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
def _get_compatible_credentials(self, credentials: dict) -> dict:
    """Return a copy of ``credentials`` whose endpoint ends in ``/v1-openai``.

    The caller's dict is left untouched; only the returned shallow copy
    carries the rewritten ``endpoint_url``.

    Args:
        credentials: Provider credentials; must contain ``"endpoint_url"``.

    Returns:
        A shallow copy of ``credentials`` whose ``"endpoint_url"`` ends with
        exactly one ``/v1-openai`` segment and has no trailing slash.

    Raises:
        KeyError: If ``"endpoint_url"`` is missing from ``credentials``.
    """
    suffix = "/v1-openai"
    compatible = credentials.copy()
    # Surrounding whitespace (e.g. from a pasted URL) would otherwise hide an
    # existing suffix and produce ".../v1-openai /v1-openai".
    base_url = compatible["endpoint_url"].strip().rstrip("/")
    # Remove every trailing occurrence, not just one, so an endpoint that was
    # already saved with a repeated suffix is normalised as well.
    while base_url.endswith(suffix):
        base_url = base_url[: -len(suffix)].rstrip("/")
    compatible["endpoint_url"] = f"{base_url}{suffix}"
    return compatible