feat: backwards invoke model

This commit is contained in:
Yeuoly 2024-09-26 15:38:22 +08:00
parent 4c28034224
commit 0ad9dbea63
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
4 changed files with 169 additions and 10 deletions

View File

@ -47,7 +47,11 @@ class PluginInvokeTextEmbeddingApi(Resource):
@get_tenant @get_tenant
@plugin_data(payload_type=RequestInvokeTextEmbedding) @plugin_data(payload_type=RequestInvokeTextEmbedding)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
pass return PluginModelBackwardsInvocation.invoke_text_embedding(
user_id=user_id,
tenant=tenant_model,
payload=payload,
)
class PluginInvokeRerankApi(Resource): class PluginInvokeRerankApi(Resource):

View File

@ -310,7 +310,9 @@ class ModelInstance:
user=user, user=user,
) )
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str: def invoke_tts(
self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None
) -> Generator[bytes, None, None]:
""" """
Invoke large language tts model Invoke large language tts model

View File

@ -1,9 +1,18 @@
import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator from collections.abc import Generator
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import RequestInvokeLLM from core.plugin.entities.request import (
RequestInvokeLLM,
RequestInvokeModeration,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
RequestInvokeTextEmbedding,
RequestInvokeTTS,
)
from core.workflow.nodes.llm.llm_node import LLMNode from core.workflow.nodes.llm.llm_node import LLMNode
from models.account import Tenant from models.account import Tenant
@ -48,5 +57,121 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
if response.usage: if response.usage:
LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
return response return response
@classmethod
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
"""
invoke text embedding
"""
model_instance = ModelManager().get_model_instance(
tenant_id=tenant.id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_text_embedding(
texts=payload.texts,
user=user_id,
)
return response
@classmethod
def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
"""
invoke rerank
"""
model_instance = ModelManager().get_model_instance(
tenant_id=tenant.id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_rerank(
query=payload.query,
docs=payload.docs,
score_threshold=payload.score_threshold,
top_n=payload.top_n,
user=user_id,
)
return response
@classmethod
def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
"""
invoke tts
"""
model_instance = ModelManager().get_model_instance(
tenant_id=tenant.id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_tts(
content_text=payload.content_text,
tenant_id=tenant.id,
voice=payload.voice,
user=user_id,
)
def handle() -> Generator[dict, None, None]:
for chunk in response:
yield {"result": hexlify(chunk).decode("utf-8")}
return handle()
@classmethod
def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
"""
invoke speech2text
"""
model_instance = ModelManager().get_model_instance(
tenant_id=tenant.id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
temp.write(unhexlify(payload.file))
temp.flush()
temp.seek(0)
response = model_instance.invoke_speech2text(
file=temp,
user=user_id,
)
return {
"result": response,
}
@classmethod
def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
"""
invoke moderation
"""
model_instance = ModelManager().get_model_instance(
tenant_id=tenant.id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_moderation(
text=payload.text,
user=user_id,
)
return {
"result": response,
}

View File

@ -74,35 +74,63 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
return v return v
class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
    """
    Request to invoke text embedding.

    Carries the list of texts to embed; provider/model come from the base model.
    """

    # model_type is pinned so the backwards-invocation dispatcher resolves
    # a text-embedding model instance.
    model_type: ModelType = ModelType.TEXT_EMBEDDING
    texts: list[str]
class RequestInvokeRerank(BaseRequestInvokeModel):
    """
    Request to invoke rerank.

    Carries the query, candidate documents, and ranking cut-offs.
    """

    # model_type is pinned so the dispatcher resolves a rerank model instance.
    model_type: ModelType = ModelType.RERANK
    query: str
    docs: list[str]
    score_threshold: float
    top_n: int
class RequestInvokeTTS(BaseRequestInvokeModel):
    """
    Request to invoke TTS.

    Carries the text to synthesize and the voice to use.
    """

    # model_type is pinned so the dispatcher resolves a TTS model instance.
    model_type: ModelType = ModelType.TTS
    content_text: str
    voice: str
class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
    """
    Request to invoke speech2text.

    The audio payload arrives as a hex string over the wire and is decoded to
    raw bytes by the ``convert_file`` validator before model invocation.
    """

    # model_type is pinned so the dispatcher resolves a speech2text model instance.
    model_type: ModelType = ModelType.SPEECH2TEXT
    file: bytes

    @field_validator("file", mode="before")
    @classmethod
    def convert_file(cls, v):
        # Reject anything that is not a hex string; decode it to raw bytes.
        if not isinstance(v, str):
            raise ValueError("file must be a hex string")
        return bytes.fromhex(v)
class RequestInvokeModeration(BaseRequestInvokeModel):
    """
    Request to invoke moderation.

    Carries the text to be checked by the moderation model.
    """

    # model_type is pinned so the dispatcher resolves a moderation model instance.
    model_type: ModelType = ModelType.MODERATION
    text: str
class RequestInvokeParameterExtractorNode(BaseModel): class RequestInvokeParameterExtractorNode(BaseModel):
""" """