Support TTS and Speech2Text for Model Provider GPUStack (#12381)
This commit is contained in:
parent
409cc7d9b0
commit
2bb521b135
@ -9,6 +9,8 @@ supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
- speech2text
|
||||
- tts
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
@ -118,3 +120,19 @@ model_credential_schema:
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
- variable: voices
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
label:
|
||||
en_US: Available Voices (comma-separated)
|
||||
zh_Hans: 可用声音(用英文逗号分隔)
|
||||
type: text-input
|
||||
required: false
|
||||
default: "Chinese Female"
|
||||
placeholder:
|
||||
en_US: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
|
||||
zh_Hans: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
|
||||
help:
|
||||
en_US: "List voice names separated by commas. First voice will be used as default."
|
||||
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"
|
||||
|
@ -1,7 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
@ -24,9 +22,10 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
return super()._invoke(
|
||||
model,
|
||||
credentials,
|
||||
compatible_credentials,
|
||||
prompt_messages,
|
||||
model_parameters,
|
||||
tools,
|
||||
@ -36,10 +35,15 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
super().validate_credentials(model, compatible_credentials)
|
||||
|
||||
def _get_compatible_credentials(self, credentials: dict) -> dict:
|
||||
credentials = credentials.copy()
|
||||
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
|
||||
credentials["endpoint_url"] = f"{base_url}/v1-openai"
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
|
||||
credentials["mode"] = "chat"
|
||||
|
@ -0,0 +1,43 @@
|
||||
from typing import IO, Optional
|
||||
|
||||
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel
|
||||
|
||||
|
||||
class GPUStackSpeech2TextModel(OAICompatSpeech2TextModel):
    """
    Model class for GPUStack Speech to text model.

    Each call is delegated to ``OAICompatSpeech2TextModel`` after the configured
    ``endpoint_url`` has been rewritten to end with ``/v1-openai``.
    """

    def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        compatible_credentials = self._get_compatible_credentials(credentials)
        # Forward ``user`` too; it was accepted by this method but dropped on delegation.
        return super()._invoke(model, compatible_credentials, file, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        """
        compatible_credentials = self._get_compatible_credentials(credentials)
        super().validate_credentials(model, compatible_credentials)

    def _get_compatible_credentials(self, credentials: dict) -> dict:
        """
        Get compatible credentials

        Returns a shallow copy so the caller's dict is not mutated. Surrounding
        whitespace, trailing slashes and an existing ``/v1-openai`` suffix are
        removed before the suffix is appended exactly once.

        :param credentials: model credentials; must contain a string ``endpoint_url``
        :return: compatible credentials
        :raises KeyError: if ``endpoint_url`` is missing
        """
        compatible_credentials = credentials.copy()
        # strip() first: a stray space after the URL would otherwise defeat both
        # rstrip("/") and removesuffix("/v1-openai").
        base_url = credentials["endpoint_url"].strip().rstrip("/").removesuffix("/v1-openai")
        compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai"
        return compatible_credentials
|
@ -1,7 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.model_runtime.entities.text_embedding_entities import (
|
||||
TextEmbeddingResult,
|
||||
@ -24,12 +22,15 @@ class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
return super()._invoke(model, credentials, texts, user, input_type)
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
return super()._invoke(model, compatible_credentials, texts, user, input_type)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
super().validate_credentials(model, compatible_credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
|
||||
def _get_compatible_credentials(self, credentials: dict) -> dict:
|
||||
credentials = credentials.copy()
|
||||
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
|
||||
credentials["endpoint_url"] = f"{base_url}/v1-openai"
|
||||
return credentials
|
||||
|
57
api/core/model_runtime/model_providers/gpustack/tts/tts.py
Normal file
57
api/core/model_runtime/model_providers/gpustack/tts/tts.py
Normal file
@ -0,0 +1,57 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.model_runtime.model_providers.openai_api_compatible.tts.tts import OAICompatText2SpeechModel
|
||||
|
||||
|
||||
class GPUStackText2SpeechModel(OAICompatText2SpeechModel):
    """
    Model class for GPUStack Text to Speech model.

    Each call is delegated to ``OAICompatText2SpeechModel`` after the configured
    ``endpoint_url`` has been rewritten to end with ``/v1-openai``.
    """

    def _invoke(
        self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
    ) -> Any:
        """
        Invoke text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be converted to speech
        :param voice: model timbre
        :param user: unique user id
        :return: whatever the parent ``_invoke`` returns for the synthesised audio
        """
        compatible_credentials = self._get_compatible_credentials(credentials)
        return super()._invoke(
            model=model,
            tenant_id=tenant_id,
            credentials=compatible_credentials,
            content_text=content_text,
            voice=voice,
            user=user,
        )

    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :param user: unique user id
        """
        compatible_credentials = self._get_compatible_credentials(credentials)
        # NOTE(review): ``user`` is not forwarded here; confirm whether the parent's
        # validate_credentials accepts it before passing it through.
        super().validate_credentials(model, compatible_credentials)

    def _get_compatible_credentials(self, credentials: dict) -> dict:
        """
        Get compatible credentials

        Returns a shallow copy so the caller's dict is not mutated. Surrounding
        whitespace, trailing slashes and an existing ``/v1-openai`` suffix are
        removed before the suffix is appended exactly once.

        :param credentials: model credentials; must contain a string ``endpoint_url``
        :return: compatible credentials
        :raises KeyError: if ``endpoint_url`` is missing
        """
        compatible_credentials = credentials.copy()
        # strip() first: a stray space after the URL would otherwise defeat both
        # rstrip("/") and removesuffix("/v1-openai").
        base_url = credentials["endpoint_url"].strip().rstrip("/").removesuffix("/v1-openai")
        compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai"
        return compatible_credentials
|
@ -0,0 +1,55 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel
|
||||
|
||||
|
||||
def test_validate_credentials():
    """Reject a bogus endpoint/key pair, then accept the credentials taken from the environment."""
    stt_model = GPUStackSpeech2TextModel()
    model_name = "faster-whisper-medium"

    bad_credentials = {
        "endpoint_url": "invalid_url",
        "api_key": "invalid_api_key",
    }
    with pytest.raises(CredentialsValidateFailedError):
        stt_model.validate_credentials(model=model_name, credentials=bad_credentials)

    # Real server settings are read from GPUSTACK_SERVER_URL / GPUSTACK_API_KEY.
    env_credentials = {
        "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
        "api_key": os.environ.get("GPUSTACK_API_KEY"),
    }
    stt_model.validate_credentials(model=model_name, credentials=env_credentials)
|
||||
|
||||
|
||||
def test_invoke_model():
    """Transcribe the ``assets/audio.mp3`` fixture and compare against the expected text."""
    stt_model = GPUStackSpeech2TextModel()

    # The assets directory sits next to the directory that contains this test file.
    parent_of_test_dir = Path(os.path.abspath(__file__)).parent.parent
    audio_bytes = (parent_of_test_dir / "assets" / "audio.mp3").read_bytes()

    env_credentials = {
        "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
        "api_key": os.environ.get("GPUSTACK_API_KEY"),
    }
    transcript = stt_model.invoke(
        model="faster-whisper-medium",
        credentials=env_credentials,
        file=audio_bytes,
    )

    assert isinstance(transcript, str)
    assert transcript == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
|
@ -0,0 +1,24 @@
|
||||
import os
|
||||
|
||||
from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel
|
||||
|
||||
|
||||
def test_invoke_model():
    """Synthesise a short phrase and check that the streamed audio is not empty."""
    tts_model = GPUStackText2SpeechModel()

    env_credentials = {
        "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
        "api_key": os.environ.get("GPUSTACK_API_KEY"),
    }
    audio_chunks = tts_model.invoke(
        model="cosyvoice-300m-sft",
        tenant_id="test",
        credentials=env_credentials,
        content_text="Hello world",
        voice="Chinese Female",
    )

    # Drain the iterable of chunks into a single bytes object.
    audio = b"".join(audio_chunks)

    assert audio != b""
|
Loading…
Reference in New Issue
Block a user