Add TTS to OpenAI_API_Compatible (#11071)
This commit is contained in:
parent
044e7b63c2
commit
aa135a3780
@ -14,7 +14,7 @@ from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_M
|
|||||||
|
|
||||||
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
||||||
"""
|
"""
|
||||||
Model class for OpenAI Speech to text model.
|
Model class for OpenAI text2speech model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
|
@ -10,7 +10,7 @@ from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI
|
|||||||
|
|
||||||
class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
|
class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
|
||||||
"""
|
"""
|
||||||
Model class for OpenAI Speech to text model.
|
Model class for OpenAI text2speech model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
|
@ -11,7 +11,7 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI
|
|||||||
|
|
||||||
class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||||
"""
|
"""
|
||||||
Model class for OpenAI Speech to text model.
|
Model class for OpenAI text2speech model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
|
@ -9,6 +9,7 @@ supported_model_types:
|
|||||||
- text-embedding
|
- text-embedding
|
||||||
- speech2text
|
- speech2text
|
||||||
- rerank
|
- rerank
|
||||||
|
- tts
|
||||||
configurate_methods:
|
configurate_methods:
|
||||||
- customizable-model
|
- customizable-model
|
||||||
model_credential_schema:
|
model_credential_schema:
|
||||||
@ -67,7 +68,7 @@ model_credential_schema:
|
|||||||
- variable: __model_type
|
- variable: __model_type
|
||||||
value: llm
|
value: llm
|
||||||
type: text-input
|
type: text-input
|
||||||
default: '4096'
|
default: "4096"
|
||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入您的模型上下文长度
|
zh_Hans: 在此输入您的模型上下文长度
|
||||||
en_US: Enter your Model context size
|
en_US: Enter your Model context size
|
||||||
@ -80,7 +81,7 @@ model_credential_schema:
|
|||||||
- variable: __model_type
|
- variable: __model_type
|
||||||
value: text-embedding
|
value: text-embedding
|
||||||
type: text-input
|
type: text-input
|
||||||
default: '4096'
|
default: "4096"
|
||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入您的模型上下文长度
|
zh_Hans: 在此输入您的模型上下文长度
|
||||||
en_US: Enter your Model context size
|
en_US: Enter your Model context size
|
||||||
@ -93,7 +94,7 @@ model_credential_schema:
|
|||||||
- variable: __model_type
|
- variable: __model_type
|
||||||
value: rerank
|
value: rerank
|
||||||
type: text-input
|
type: text-input
|
||||||
default: '4096'
|
default: "4096"
|
||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入您的模型上下文长度
|
zh_Hans: 在此输入您的模型上下文长度
|
||||||
en_US: Enter your Model context size
|
en_US: Enter your Model context size
|
||||||
@ -104,7 +105,7 @@ model_credential_schema:
|
|||||||
show_on:
|
show_on:
|
||||||
- variable: __model_type
|
- variable: __model_type
|
||||||
value: llm
|
value: llm
|
||||||
default: '4096'
|
default: "4096"
|
||||||
type: text-input
|
type: text-input
|
||||||
- variable: function_calling_type
|
- variable: function_calling_type
|
||||||
show_on:
|
show_on:
|
||||||
@ -174,3 +175,19 @@ model_credential_schema:
|
|||||||
value: llm
|
value: llm
|
||||||
default: '\n\n'
|
default: '\n\n'
|
||||||
type: text-input
|
type: text-input
|
||||||
|
- variable: voices
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: tts
|
||||||
|
label:
|
||||||
|
en_US: Available Voices (comma-separated)
|
||||||
|
zh_Hans: 可用声音(用英文逗号分隔)
|
||||||
|
type: text-input
|
||||||
|
required: false
|
||||||
|
default: "alloy"
|
||||||
|
placeholder:
|
||||||
|
en_US: "alloy,echo,fable,onyx,nova,shimmer"
|
||||||
|
zh_Hans: "alloy,echo,fable,onyx,nova,shimmer"
|
||||||
|
help:
|
||||||
|
en_US: "List voice names separated by commas. First voice will be used as default."
|
||||||
|
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"
|
||||||
|
@ -0,0 +1,145 @@
|
|||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||||
|
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||||
|
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||||
|
|
||||||
|
|
||||||
|
class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):
|
||||||
|
"""
|
||||||
|
Model class for OpenAI-compatible text2speech model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
tenant_id: str,
|
||||||
|
credentials: dict,
|
||||||
|
content_text: str,
|
||||||
|
voice: str,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Iterable[bytes]:
|
||||||
|
"""
|
||||||
|
Invoke TTS model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param tenant_id: user tenant id
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param content_text: text content to be translated
|
||||||
|
:param voice: model voice/speaker
|
||||||
|
:param user: unique user id
|
||||||
|
:return: audio data as bytes iterator
|
||||||
|
"""
|
||||||
|
# Set up headers with authentication if provided
|
||||||
|
headers = {}
|
||||||
|
if api_key := credentials.get("api_key"):
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
# Construct endpoint URL
|
||||||
|
endpoint_url = credentials.get("endpoint_url")
|
||||||
|
if not endpoint_url.endswith("/"):
|
||||||
|
endpoint_url += "/"
|
||||||
|
endpoint_url = urljoin(endpoint_url, "audio/speech")
|
||||||
|
|
||||||
|
# Get audio format from model properties
|
||||||
|
audio_format = self._get_model_audio_type(model, credentials)
|
||||||
|
|
||||||
|
# Split text into chunks if needed based on word limit
|
||||||
|
word_limit = self._get_model_word_limit(model, credentials)
|
||||||
|
sentences = self._split_text_into_sentences(content_text, word_limit)
|
||||||
|
|
||||||
|
for sentence in sentences:
|
||||||
|
# Prepare request payload
|
||||||
|
payload = {"model": model, "input": sentence, "voice": voice, "response_format": audio_format}
|
||||||
|
|
||||||
|
# Make POST request
|
||||||
|
response = requests.post(endpoint_url, headers=headers, json=payload, stream=True)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise InvokeBadRequestError(response.text)
|
||||||
|
|
||||||
|
# Stream the audio data
|
||||||
|
for chunk in response.iter_content(chunk_size=4096):
|
||||||
|
if chunk:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get default voice for validation
|
||||||
|
voice = self._get_model_default_voice(model, credentials)
|
||||||
|
|
||||||
|
# Test with a simple text
|
||||||
|
next(
|
||||||
|
self._invoke(
|
||||||
|
model=model, tenant_id="validate", credentials=credentials, content_text="Test.", voice=voice
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||||
|
"""
|
||||||
|
Get customizable model schema
|
||||||
|
"""
|
||||||
|
# Parse voices from comma-separated string
|
||||||
|
voice_names = credentials.get("voices", "alloy").strip().split(",")
|
||||||
|
voices = []
|
||||||
|
|
||||||
|
for voice in voice_names:
|
||||||
|
voice = voice.strip()
|
||||||
|
if not voice:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Use en-US for all voices
|
||||||
|
voices.append(
|
||||||
|
{
|
||||||
|
"name": voice,
|
||||||
|
"mode": voice,
|
||||||
|
"language": "en-US",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no voices provided or all voices were empty strings, use 'alloy' as default
|
||||||
|
if not voices:
|
||||||
|
voices = [{"name": "Alloy", "mode": "alloy", "language": "en-US"}]
|
||||||
|
|
||||||
|
return AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(en_US=model),
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_type=ModelType.TTS,
|
||||||
|
model_properties={
|
||||||
|
ModelPropertyKey.AUDIO_TYPE: credentials.get("audio_type", "mp3"),
|
||||||
|
ModelPropertyKey.WORD_LIMIT: int(credentials.get("word_limit", 4096)),
|
||||||
|
ModelPropertyKey.DEFAULT_VOICE: voices[0]["mode"],
|
||||||
|
ModelPropertyKey.VOICES: voices,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
|
||||||
|
"""
|
||||||
|
Override base get_tts_model_voices to handle customizable voices
|
||||||
|
"""
|
||||||
|
model_schema = self.get_customizable_model_schema(model, credentials)
|
||||||
|
|
||||||
|
if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
|
||||||
|
raise ValueError("this model does not support voice")
|
||||||
|
|
||||||
|
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
|
||||||
|
|
||||||
|
# Always return all voices regardless of language
|
||||||
|
return [{"name": d["name"], "value": d["mode"]} for d in voices]
|
Loading…
Reference in New Issue
Block a user