feat: backwards invoke model
This commit is contained in:
parent
4c28034224
commit
0ad9dbea63
@ -47,7 +47,11 @@ class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@get_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
||||
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
    """
    Handle a plugin request to invoke a text-embedding model.

    Forwards the call to ``PluginModelBackwardsInvocation.invoke_text_embedding``
    and returns its result unchanged.

    :param user_id: identifier passed through as ``user_id``
    :param tenant_model: tenant passed through as ``tenant``
    :param payload: parsed text-embedding request
    :return: the result of the backwards invocation
    """
    # The former placeholder ``pass`` was a dead statement ahead of the real
    # return and has been dropped.
    return PluginModelBackwardsInvocation.invoke_text_embedding(
        user_id=user_id,
        tenant=tenant_model,
        payload=payload,
    )
|
||||
|
||||
|
||||
class PluginInvokeRerankApi(Resource):
|
||||
|
@ -310,7 +310,9 @@ class ModelInstance:
|
||||
user=user,
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
|
||||
def invoke_tts(
|
||||
self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None
|
||||
) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Invoke large language tts model
|
||||
|
||||
|
@ -1,9 +1,18 @@
|
||||
import tempfile
|
||||
from binascii import hexlify, unhexlify
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from core.plugin.entities.request import RequestInvokeLLM
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeModeration,
|
||||
RequestInvokeRerank,
|
||||
RequestInvokeSpeech2Text,
|
||||
RequestInvokeTextEmbedding,
|
||||
RequestInvokeTTS,
|
||||
)
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from models.account import Tenant
|
||||
|
||||
@ -49,4 +58,120 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
return response
|
||||
|
||||
@classmethod
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
    """
    Run a text-embedding model on behalf of a plugin.

    The model instance is looked up for the tenant using the provider, model
    type and model name carried by the payload, then ``payload.texts`` is
    embedded.

    :param user_id: forwarded to the model call as ``user``
    :param tenant: tenant whose ``id`` selects the model configuration
    :param payload: request naming the model and the texts to embed
    :return: the value returned by ``model_instance.invoke_text_embedding``
    """
    manager = ModelManager()
    embedding_model = manager.get_model_instance(
        tenant_id=tenant.id,
        provider=payload.provider,
        model_type=payload.model_type,
        model=payload.model,
    )

    # Hand the texts straight to the resolved model and return its answer as-is.
    return embedding_model.invoke_text_embedding(texts=payload.texts, user=user_id)
|
||||
|
||||
@classmethod
def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
    """
    Run a rerank model on behalf of a plugin.

    The model instance is looked up for the tenant using the provider, model
    type and model name carried by the payload; the query, documents, score
    threshold and ``top_n`` from the payload are passed to it unchanged.

    :param user_id: forwarded to the model call as ``user``
    :param tenant: tenant whose ``id`` selects the model configuration
    :param payload: request naming the model and the rerank inputs
    :return: the value returned by ``model_instance.invoke_rerank``
    """
    manager = ModelManager()
    rerank_model = manager.get_model_instance(
        tenant_id=tenant.id,
        provider=payload.provider,
        model_type=payload.model_type,
        model=payload.model,
    )

    # Delegate to the resolved model and return its answer as-is.
    result = rerank_model.invoke_rerank(
        query=payload.query,
        docs=payload.docs,
        score_threshold=payload.score_threshold,
        top_n=payload.top_n,
        user=user_id,
    )
    return result
|
||||
|
||||
@classmethod
def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
    """
    Run a text-to-speech model on behalf of a plugin.

    The model instance is looked up for the tenant using the provider, model
    type and model name carried by the payload. Each chunk produced by the
    model is hex-encoded and wrapped in a ``{"result": ...}`` dict.

    :param user_id: forwarded to the model call as ``user``
    :param tenant: tenant whose ``id`` selects the model configuration and is
        also passed to the model as ``tenant_id``
    :param payload: request naming the model, the text and the voice
    :return: a generator of ``{"result": <hex string>}`` dicts, one per chunk
    """
    manager = ModelManager()
    tts_model = manager.get_model_instance(
        tenant_id=tenant.id,
        provider=payload.provider,
        model_type=payload.model_type,
        model=payload.model,
    )

    audio_chunks = tts_model.invoke_tts(
        content_text=payload.content_text,
        tenant_id=tenant.id,
        voice=payload.voice,
        user=user_id,
    )

    def _hex_encoded() -> Generator[dict, None, None]:
        # Chunks are hex-encoded into text before being yielded.
        for piece in audio_chunks:
            encoded = hexlify(piece).decode("utf-8")
            yield {"result": encoded}

    return _hex_encoded()
|
||||
|
||||
@classmethod
def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
    """
    Run a speech-to-text model on behalf of a plugin.

    The model instance is looked up for the tenant using the provider, model
    type and model name carried by the payload. The audio is written to a
    temporary ``.mp3`` file, rewound, and handed to the model as a file object.

    :param user_id: forwarded to the model call as ``user``
    :param tenant: tenant whose ``id`` selects the model configuration
    :param payload: request naming the model and carrying the audio in ``file``
    :return: ``{"result": <model response>}``
    """
    model_instance = ModelManager().get_model_instance(
        tenant_id=tenant.id,
        provider=payload.provider,
        model_type=payload.model_type,
        model=payload.model,
    )

    # The file must be opened read/write ("w+b"): it is written here and then
    # passed on for the model to read. A write-only handle ("wb") raises
    # io.UnsupportedOperation on the first read.
    with tempfile.NamedTemporaryFile(suffix=".mp3", mode="w+b", delete=True) as temp:
        # NOTE(review): payload.file is annotated as bytes; if it has already
        # been hex-decoded upstream, this unhexlify would decode it a second
        # time. Confirm the wire format.
        temp.write(unhexlify(payload.file))
        temp.flush()
        # Rewind so the consumer reads from the start of the audio.
        temp.seek(0)

        response = model_instance.invoke_speech2text(
            file=temp,
            user=user_id,
        )

        return {
            "result": response,
        }
|
||||
|
||||
@classmethod
def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
    """
    Run a moderation model on behalf of a plugin.

    The model instance is looked up for the tenant using the provider, model
    type and model name carried by the payload, then ``payload.text`` is
    passed to it.

    :param user_id: forwarded to the model call as ``user``
    :param tenant: tenant whose ``id`` selects the model configuration
    :param payload: request naming the model and the text to check
    :return: ``{"result": <model response>}``
    """
    manager = ModelManager()
    moderation_model = manager.get_model_instance(
        tenant_id=tenant.id,
        provider=payload.provider,
        model_type=payload.model_type,
        model=payload.model,
    )

    verdict = moderation_model.invoke_moderation(text=payload.text, user=user_id)

    # Wrap the raw model answer under the "result" key.
    return {"result": verdict}
|
||||
|
@ -74,35 +74,63 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
||||
return v
|
||||
|
||||
|
||||
class RequestInvokeTextEmbedding(BaseModel):
|
||||
class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
    """
    Payload for a plugin's text-embedding invocation.
    """

    # Defaults to the text-embedding model type.
    model_type: ModelType = ModelType.TEXT_EMBEDDING
    # Texts to embed.
    texts: list[str]
|
||||
|
||||
class RequestInvokeRerank(BaseModel):
|
||||
|
||||
class RequestInvokeRerank(BaseRequestInvokeModel):
    """
    Payload for a plugin's rerank invocation.
    """

    # Defaults to the rerank model type.
    model_type: ModelType = ModelType.RERANK
    # Query the documents are ranked against.
    query: str
    # Candidate documents to rerank.
    docs: list[str]
    # Passed to the model call as ``score_threshold``.
    score_threshold: float
    # Passed to the model call as ``top_n``.
    top_n: int
|
||||
|
||||
class RequestInvokeTTS(BaseModel):
|
||||
|
||||
class RequestInvokeTTS(BaseRequestInvokeModel):
    """
    Payload for a plugin's text-to-speech invocation.
    """

    # Defaults to the TTS model type.
    model_type: ModelType = ModelType.TTS
    # Text to synthesize.
    content_text: str
    # Voice passed to the model call.
    voice: str
|
||||
|
||||
class RequestInvokeSpeech2Text(BaseModel):
|
||||
|
||||
class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
    """
    Payload for a plugin's speech-to-text invocation.
    """

    # Defaults to the speech-to-text model type.
    model_type: ModelType = ModelType.SPEECH2TEXT
    # Audio content as bytes.
    file: bytes
|
||||
|
||||
class RequestInvokeModeration(BaseModel):
|
||||
@field_validator("file", mode="before")
@classmethod
def convert_file(cls, v):
    """
    Turn the incoming ``file`` value from a hex string into bytes.

    :param v: raw value supplied for ``file`` before field validation
    :return: the bytes decoded from the hex string
    :raises ValueError: if ``v`` is not a ``str`` (``bytes.fromhex`` also
        raises ``ValueError`` for a malformed hex string)
    """
    # Only hex strings are accepted; anything else is rejected up front.
    if not isinstance(v, str):
        raise ValueError("file must be a hex string")

    return bytes.fromhex(v)
|
||||
|
||||
|
||||
class RequestInvokeModeration(BaseRequestInvokeModel):
    """
    Payload for a plugin's moderation invocation.
    """

    # Defaults to the moderation model type.
    model_type: ModelType = ModelType.MODERATION
    # Text to be checked by the moderation model.
    text: str
|
||||
|
||||
|
||||
class RequestInvokeParameterExtractorNode(BaseModel):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user