From 5ba875f96bd3ce1e493842cb53caadaeb02f212e Mon Sep 17 00:00:00 2001 From: alexcodelf Date: Mon, 6 Jan 2025 20:07:09 +0800 Subject: [PATCH] fix: gpustack llm and text_embedding model url path wrong after edited --- .../model_providers/gpustack/llm/llm.py | 16 ++++++++++------ .../gpustack/text_embedding/text_embedding.py | 17 +++++++++-------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/api/core/model_runtime/model_providers/gpustack/llm/llm.py b/api/core/model_runtime/model_providers/gpustack/llm/llm.py index ce6780b6a7..429c761837 100644 --- a/api/core/model_runtime/model_providers/gpustack/llm/llm.py +++ b/api/core/model_runtime/model_providers/gpustack/llm/llm.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from yarl import URL - from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import ( PromptMessage, @@ -24,9 +22,10 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel): stream: bool = True, user: str | None = None, ) -> LLMResult | Generator: + compatible_credentials = self._get_compatible_credentials(credentials) return super()._invoke( model, - credentials, + compatible_credentials, prompt_messages, model_parameters, tools, @@ -36,10 +35,15 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel): ) def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials) - super().validate_credentials(model, credentials) + compatible_credentials = self._get_compatible_credentials(credentials) + super().validate_credentials(model, compatible_credentials) + + def _get_compatible_credentials(self, credentials: dict) -> dict: + credentials = credentials.copy() + base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai") + credentials["endpoint_url"] = f"{base_url}/v1-openai" + return credentials @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") credentials["mode"] = "chat" diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py index eb324491a2..35b499e51a 100644 --- a/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py @@ -1,7 +1,5 @@ from typing import Optional -from yarl import URL - from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.text_embedding_entities import ( TextEmbeddingResult, @@ -24,12 +22,15 @@ class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel): user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> TextEmbeddingResult: - return super()._invoke(model, credentials, texts, user, input_type) + compatible_credentials = self._get_compatible_credentials(credentials) + return super()._invoke(model, compatible_credentials, texts, user, input_type) def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials) - super().validate_credentials(model, credentials) + compatible_credentials = self._get_compatible_credentials(credentials) + super().validate_credentials(model, compatible_credentials) - @staticmethod - def _add_custom_parameters(credentials: dict) -> None: - credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") + def _get_compatible_credentials(self, credentials: dict) -> dict: + credentials = credentials.copy() + base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai") + credentials["endpoint_url"] = f"{base_url}/v1-openai" + return credentials