# dify/api/core/plugin/backwards_invocation/model.py
import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeLLM,
RequestInvokeModeration,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
RequestInvokeTextEmbedding,
RequestInvokeTTS,
)
from core.workflow.nodes.llm.llm_node import LLMNode
from models.account import Tenant
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
    """
    Backwards-invocation endpoints that let a plugin call tenant-scoped models
    (LLM, text embedding, rerank, TTS, speech-to-text, moderation).

    Each method resolves a model instance for the tenant via ModelManager,
    forwards the payload to the model runtime, and returns (or streams) the
    raw runtime response.
    """

    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        Invoke an LLM on behalf of a plugin.

        :param user_id: id of the end user attributed to the invocation
        :param tenant: tenant that owns the model credentials and quota
        :param payload: provider/model selection plus prompt, tools, stop, stream
        :return: a chunk generator when streaming, otherwise a single LLMResult;
                 quota is deducted per usage-bearing chunk (streaming) or once
                 (blocking) via LLMNode.deduct_llm_quota.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        # FIX: the previous `payload.stream or True` always evaluated to True,
        # silently forcing streaming even when the caller passed stream=False.
        # Only default to streaming when the payload leaves stream unset.
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.model_parameters,
            tools=payload.tools,
            stop=payload.stop,
            stream=True if payload.stream is None else payload.stream,
            user=user_id,
        )

        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                # Deduct quota as usage arrives, then forward each chunk.
                for chunk in response:
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    yield chunk

            return handle()
        else:
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
            return response

    @classmethod
    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
        """
        Invoke a text-embedding model on behalf of a plugin.

        :param payload: provider/model selection plus the texts to embed
        :return: the runtime's embedding result, unmodified
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
        """
        Invoke a rerank model on behalf of a plugin.

        :param payload: provider/model selection plus query, docs,
                        score_threshold and top_n
        :return: the runtime's rerank result, unmodified
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_rerank(
            query=payload.query,
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
        """
        Invoke a TTS model on behalf of a plugin.

        :param payload: provider/model selection plus content_text and voice
        :return: a generator of {"result": <hex-encoded audio chunk>} dicts;
                 binary audio is hex-encoded so it survives JSON transport
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )

        def handle() -> Generator[dict, None, None]:
            for chunk in response:
                yield {"result": hexlify(chunk).decode("utf-8")}

        return handle()

    @classmethod
    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
        """
        Invoke a speech-to-text model on behalf of a plugin.

        :param payload: provider/model selection plus `file`, a hex-encoded
                        audio blob (assumed mp3 — the temp file suffix hints
                        the format to the runtime)
        :return: {"result": <transcription>}
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model: decode the hex payload into a real file object, since
        # the runtime expects a readable file handle. delete=True cleans up
        # the temp file when the `with` block exits.
        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
            temp.write(unhexlify(payload.file))
            temp.flush()
            temp.seek(0)  # rewind so the runtime reads from the start

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )

            return {
                "result": response,
            }

    @classmethod
    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
        """
        Invoke a moderation model on behalf of a plugin.

        :param payload: provider/model selection plus the text to moderate
        :return: {"result": <moderation verdict from the runtime>}
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )

        return {
            "result": response,
        }