From 0ad9dbea63ac532e9194ee403c51a20218e7c6da Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 26 Sep 2024 15:38:22 +0800 Subject: [PATCH] feat: backwards invoke model --- api/controllers/inner_api/plugin/plugin.py | 6 +- api/core/model_manager.py | 4 +- api/core/plugin/backwards_invocation/model.py | 131 +++++++++++++++++- api/core/plugin/entities/request.py | 38 ++++- 4 files changed, 169 insertions(+), 10 deletions(-) diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index ae35332689..7dde4f0148 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -47,7 +47,11 @@ class PluginInvokeTextEmbeddingApi(Resource): @get_tenant @plugin_data(payload_type=RequestInvokeTextEmbedding) def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): - pass + return PluginModelBackwardsInvocation.invoke_text_embedding( + user_id=user_id, + tenant=tenant_model, + payload=payload, + ) class PluginInvokeRerankApi(Resource): diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 482ca2d4b9..1a4a03e277 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -310,7 +310,9 @@ class ModelInstance: user=user, ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str: + def invoke_tts( + self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None + ) -> Generator[bytes, None, None]: """ Invoke large language tts model diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 7904fd6234..b3ecced55c 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -1,9 +1,18 @@ +import tempfile +from binascii import hexlify, unhexlify from collections.abc import Generator from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.plugin.backwards_invocation.base import BaseBackwardsInvocation -from core.plugin.entities.request import RequestInvokeLLM +from core.plugin.entities.request import ( + RequestInvokeLLM, + RequestInvokeModeration, + RequestInvokeRerank, + RequestInvokeSpeech2Text, + RequestInvokeTextEmbedding, + RequestInvokeTTS, +) from core.workflow.nodes.llm.llm_node import LLMNode from models.account import Tenant @@ -48,5 +57,121 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): if response.usage: LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) return response - - \ No newline at end of file + + @classmethod + def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding): + """ + invoke text embedding + """ + model_instance = ModelManager().get_model_instance( + tenant_id=tenant.id, + provider=payload.provider, + model_type=payload.model_type, + model=payload.model, + ) + + # invoke model + response = model_instance.invoke_text_embedding( + texts=payload.texts, + user=user_id, + ) + + return response + + @classmethod + def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank): + """ + invoke rerank + """ + model_instance = ModelManager().get_model_instance( + tenant_id=tenant.id, + provider=payload.provider, + model_type=payload.model_type, + model=payload.model, + ) + + # invoke model + response = model_instance.invoke_rerank( + query=payload.query, + docs=payload.docs, + score_threshold=payload.score_threshold, + top_n=payload.top_n, + user=user_id, + ) + + return response + + @classmethod + def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS): + """ + invoke tts + """ + model_instance = ModelManager().get_model_instance( + tenant_id=tenant.id, + provider=payload.provider, + model_type=payload.model_type, + model=payload.model, + ) + + # invoke model + response = model_instance.invoke_tts( + content_text=payload.content_text, + tenant_id=tenant.id, + voice=payload.voice, + user=user_id, + ) + + def handle() -> Generator[dict, None, None]: + for chunk in response: + yield {"result": hexlify(chunk).decode("utf-8")} + + return handle() + + @classmethod + def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text): + """ + invoke speech2text + """ + model_instance = ModelManager().get_model_instance( + tenant_id=tenant.id, + provider=payload.provider, + model_type=payload.model_type, + model=payload.model, + ) + + # invoke model + with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp: + temp.write(unhexlify(payload.file)) + temp.flush() + temp.seek(0) + + response = model_instance.invoke_speech2text( + file=temp, + user=user_id, + ) + + return { + "result": response, + } + + @classmethod + def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration): + """ + invoke moderation + """ + model_instance = ModelManager().get_model_instance( + tenant_id=tenant.id, + provider=payload.provider, + model_type=payload.model_type, + model=payload.model, + ) + + # invoke model + response = model_instance.invoke_moderation( + text=payload.text, + user=user_id, + ) + + return { + "result": response, + } diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 00ac53ca72..a27f8751a6 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -74,35 +74,63 @@ class RequestInvokeLLM(BaseRequestInvokeModel): return v -class RequestInvokeTextEmbedding(BaseModel): +class RequestInvokeTextEmbedding(BaseRequestInvokeModel): """ Request to invoke text embedding """ + model_type: ModelType = ModelType.TEXT_EMBEDDING + texts: list[str] -class RequestInvokeRerank(BaseModel): + +class RequestInvokeRerank(BaseRequestInvokeModel): """ Request to invoke rerank """ + model_type: ModelType = ModelType.RERANK + query: str + docs: list[str] + score_threshold: float + top_n: int -class RequestInvokeTTS(BaseModel): + +class RequestInvokeTTS(BaseRequestInvokeModel): """ Request to invoke TTS """ + model_type: ModelType = ModelType.TTS + content_text: str + voice: str -class RequestInvokeSpeech2Text(BaseModel): + +class RequestInvokeSpeech2Text(BaseRequestInvokeModel): """ Request to invoke speech2text """ + model_type: ModelType = ModelType.SPEECH2TEXT + file: bytes -class RequestInvokeModeration(BaseModel): + @field_validator("file", mode="before") + @classmethod + def convert_file(cls, v): + # hex string to bytes + if isinstance(v, str): + return bytes.fromhex(v) + else: + raise ValueError("file must be a hex string") + + +class RequestInvokeModeration(BaseRequestInvokeModel): """ Request to invoke moderation """ + model_type: ModelType = ModelType.MODERATION + text: str + class RequestInvokeParameterExtractorNode(BaseModel): """