Add TTS to OpenAI_API_Compatible (#11071)
This commit is contained in:
parent
044e7b63c2
commit
aa135a3780
@ -14,7 +14,7 @@ from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_M
|
||||
|
||||
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||
"""
|
||||
Model class for OpenAI Speech to text model.
|
||||
Model class for OpenAI text2speech model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
|
@ -10,7 +10,7 @@ from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI
|
||||
|
||||
class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
|
||||
"""
|
||||
Model class for OpenAI Speech to text model.
|
||||
Model class for OpenAI text2speech model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
|
@ -11,7 +11,7 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI
|
||||
|
||||
class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
"""
|
||||
Model class for OpenAI Speech to text model.
|
||||
Model class for OpenAI text2speech model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
|
@ -9,6 +9,7 @@ supported_model_types:
|
||||
- text-embedding
|
||||
- speech2text
|
||||
- rerank
|
||||
- tts
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
@ -67,7 +68,7 @@ model_credential_schema:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
type: text-input
|
||||
default: '4096'
|
||||
default: "4096"
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
@ -80,7 +81,7 @@ model_credential_schema:
|
||||
- variable: __model_type
|
||||
value: text-embedding
|
||||
type: text-input
|
||||
default: '4096'
|
||||
default: "4096"
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
@ -93,7 +94,7 @@ model_credential_schema:
|
||||
- variable: __model_type
|
||||
value: rerank
|
||||
type: text-input
|
||||
default: '4096'
|
||||
default: "4096"
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
@ -104,7 +105,7 @@ model_credential_schema:
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
default: '4096'
|
||||
default: "4096"
|
||||
type: text-input
|
||||
- variable: function_calling_type
|
||||
show_on:
|
||||
@ -174,3 +175,19 @@ model_credential_schema:
|
||||
value: llm
|
||||
default: '\n\n'
|
||||
type: text-input
|
||||
- variable: voices
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
label:
|
||||
en_US: Available Voices (comma-separated)
|
||||
zh_Hans: 可用声音(用英文逗号分隔)
|
||||
type: text-input
|
||||
required: false
|
||||
default: "alloy"
|
||||
placeholder:
|
||||
en_US: "alloy,echo,fable,onyx,nova,shimmer"
|
||||
zh_Hans: "alloy,echo,fable,onyx,nova,shimmer"
|
||||
help:
|
||||
en_US: "List voice names separated by commas. First voice will be used as default."
|
||||
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"
|
||||
|
@ -0,0 +1,145 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):
    """
    Model class for OpenAI-compatible text2speech model.

    Sends text to the ``audio/speech`` route of a user-configured endpoint and
    streams the returned audio bytes back to the caller. The list of selectable
    voices is taken from the comma-separated ``voices`` credential.
    """

    # (connect, read) timeout in seconds passed to ``requests``. Without a
    # timeout a stalled server would block the calling worker forever.
    _REQUEST_TIMEOUT = (10, 300)

    def _invoke(
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> Iterable[bytes]:
        """
        Invoke TTS model

        This is a generator: no request is sent (and no error is raised) until
        the caller starts iterating.

        :param model: model name
        :param tenant_id: user tenant id (not used by this implementation)
        :param credentials: model credentials; reads ``endpoint_url`` (required)
            and ``api_key`` (optional)
        :param content_text: text content to be synthesized into speech
        :param voice: model voice/speaker
        :param user: unique user id (not used by this implementation)
        :return: audio data as bytes iterator
        :raises InvokeBadRequestError: if ``endpoint_url`` is missing or the
            server answers with a non-200 status
        """
        # Set up headers with authentication if provided
        headers = {}
        if api_key := credentials.get("api_key"):
            headers["Authorization"] = f"Bearer {api_key}"

        # Construct endpoint URL; fail with a clear message instead of an
        # AttributeError when the credential is absent or empty.
        endpoint_url = credentials.get("endpoint_url")
        if not endpoint_url:
            raise InvokeBadRequestError("endpoint_url is required to invoke the text2speech model")
        if not endpoint_url.endswith("/"):
            # urljoin would otherwise replace the last path segment
            endpoint_url += "/"
        endpoint_url = urljoin(endpoint_url, "audio/speech")

        # Get audio format from model properties
        audio_format = self._get_model_audio_type(model, credentials)

        # Split text into chunks if needed based on word limit
        word_limit = self._get_model_word_limit(model, credentials)
        sentences = self._split_text_into_sentences(content_text, word_limit)

        for sentence in sentences:
            # Prepare request payload
            payload = {"model": model, "input": sentence, "voice": voice, "response_format": audio_format}

            # Make POST request. The context manager guarantees the streamed
            # connection is released on errors and when the consumer stops
            # iterating early (generator close).
            with requests.post(
                endpoint_url,
                headers=headers,
                json=payload,
                stream=True,
                timeout=self._REQUEST_TIMEOUT,
            ) as response:
                if response.status_code != 200:
                    raise InvokeBadRequestError(response.text)

                # Stream the audio data, skipping empty keep-alive chunks
                for chunk in response.iter_content(chunk_size=4096):
                    if chunk:
                        yield chunk

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        Performs a real synthesis request for a short test text and checks
        that at least one chunk of audio comes back.

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the request fails for any
            reason or the endpoint returns no audio data
        """
        try:
            # Get default voice for validation
            voice = self._get_model_default_voice(model, credentials)

            # Test with a simple text
            audio_stream = self._invoke(
                model=model, tenant_id="validate", credentials=credentials, content_text="Test.", voice=voice
            )
            try:
                # Only the first chunk is needed to prove the endpoint works.
                first_chunk = next(audio_stream, None)
            finally:
                # Close the generator so the underlying HTTP response is
                # released instead of lingering until garbage collection.
                audio_stream.close()

            if first_chunk is None:
                raise CredentialsValidateFailedError("text2speech endpoint returned no audio data")
        except CredentialsValidateFailedError:
            raise
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex)) from ex

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Get customizable model schema

        Builds the TTS model entity from the credentials: ``voices``
        (comma-separated, default ``alloy``), ``audio_type`` (default ``mp3``)
        and ``word_limit`` (default ``4096``). Missing, ``None`` or empty
        values fall back to the defaults. The first voice becomes the default
        voice.

        :param model: model name
        :param credentials: model credentials
        :return: model entity describing the customizable TTS model
        """
        # Parse voices from comma-separated string; `or` also covers a key
        # that is present but None/empty.
        raw_voices = credentials.get("voices") or "alloy"

        # Use en-US for all voices; blank entries (e.g. "a,,b") are dropped
        voices = [
            {
                "name": voice,
                "mode": voice,
                "language": "en-US",
            }
            for voice in (part.strip() for part in str(raw_voices).split(","))
            if voice
        ]

        # If no voices provided or all voices were empty strings, use 'alloy' as default
        if not voices:
            voices = [{"name": "Alloy", "mode": "alloy", "language": "en-US"}]

        return AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={
                ModelPropertyKey.AUDIO_TYPE: credentials.get("audio_type") or "mp3",
                # An empty string from the form would make int() raise
                ModelPropertyKey.WORD_LIMIT: int(credentials.get("word_limit") or 4096),
                ModelPropertyKey.DEFAULT_VOICE: voices[0]["mode"],
                ModelPropertyKey.VOICES: voices,
            },
        )

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Override base get_tts_model_voices to handle customizable voices

        :param model: model name
        :param credentials: model credentials
        :param language: accepted for interface compatibility; ignored, all
            configured voices are returned
        :return: list of ``{"name": ..., "value": ...}`` dicts, one per voice
        :raises ValueError: if the schema carries no voice list
        """
        model_schema = self.get_customizable_model_schema(model, credentials)

        if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
            raise ValueError("this model does not support voice")

        voices = model_schema.model_properties[ModelPropertyKey.VOICES]

        # Always return all voices regardless of language
        return [{"name": d["name"], "value": d["mode"]} for d in voices]
|
Loading…
Reference in New Issue
Block a user