dify/api/core/plugin/entities/request.py

207 lines
4.9 KiB
Python
Raw Normal View History

2024-08-29 20:17:17 +08:00
from typing import Any, Literal, Optional
2024-07-29 22:08:14 +08:00
from pydantic import BaseModel, ConfigDict, Field, field_validator
2024-07-29 22:08:14 +08:00
2024-08-30 14:23:14 +08:00
from core.entities.provider_entities import BasicProviderConfig
2024-07-29 22:08:14 +08:00
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
2024-09-24 18:03:48 +08:00
from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
ParameterConfig,
)
2024-09-24 20:15:13 +08:00
from core.workflow.nodes.question_classifier.entities import (
ClassConfig,
)
from core.workflow.nodes.question_classifier.entities import (
ModelConfig as QuestionClassifierModelConfig,
)
2024-07-29 16:40:04 +08:00
class RequestInvokeTool(BaseModel):
    """
    Payload a plugin sends to run a tool.

    Names the kind of tool, the provider that owns it, the tool itself,
    and the arguments to run it with.
    """

    # which family of tool is being addressed
    tool_type: Literal["builtin", "workflow", "api"]
    # provider that owns the tool
    provider: str
    # tool name within that provider
    tool: str
    # arguments forwarded to the tool, as a plain dict
    tool_parameters: dict


class BaseRequestInvokeModel(BaseModel):
    """
    Common fields shared by every model-invocation request.

    Subclasses pin ``model_type`` to their own default and add the
    inputs specific to that kind of model.
    """

    # pydantic reserves the "model_" prefix by default; clearing the
    # protected namespaces lets ``model_type`` be declared as a field.
    model_config = ConfigDict(protected_namespaces=())

    provider: str
    model: str
    model_type: ModelType


class RequestInvokeLLM(BaseRequestInvokeModel):
    """
    Request to invoke LLM.

    ``prompt_messages`` may arrive as plain dicts that carry a ``role`` key.
    The ``convert_prompt_messages`` validator turns each dict into the
    concrete ``PromptMessage`` subclass for that role before pydantic
    validates the field.
    """

    model_type: ModelType = ModelType.LLM
    mode: str
    completion_params: dict[str, Any] = Field(default_factory=dict)
    prompt_messages: list[PromptMessage] = Field(default_factory=list)
    tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
    stop: Optional[list[str]] = Field(default_factory=list)
    stream: Optional[bool] = False

    # Same setting as the base class; kept explicit on this subclass.
    model_config = ConfigDict(protected_namespaces=())

    @field_validator("prompt_messages", mode="before")
    @classmethod
    def convert_prompt_messages(cls, v):
        """
        Convert raw prompt-message dicts into role-specific message models.

        Dicts whose role is user/assistant/system/tool become
        ``UserPromptMessage``/``AssistantPromptMessage``/``SystemPromptMessage``
        /``ToolPromptMessage``. Any other role falls back to ``PromptMessage``.
        Items that are already ``PromptMessage`` instances are kept as-is.

        :param v: raw value supplied for ``prompt_messages``
        :return: a new list of message models (the input list is not modified)
        :raises ValueError: if ``v`` is not a list, or an item is neither a
            ``PromptMessage`` nor a dict containing a ``role`` key
        """
        if not isinstance(v, list):
            raise ValueError("prompt_messages must be a list")

        role_to_message_class = {
            PromptMessageRole.USER.value: UserPromptMessage,
            PromptMessageRole.ASSISTANT.value: AssistantPromptMessage,
            PromptMessageRole.SYSTEM.value: SystemPromptMessage,
            PromptMessageRole.TOOL.value: ToolPromptMessage,
        }

        # Build a fresh list rather than assigning into ``v``, so the
        # caller's payload is left untouched.
        converted = []
        for index, item in enumerate(v):
            if isinstance(item, PromptMessage):
                # Already a model (e.g. constructed programmatically).
                converted.append(item)
                continue
            # Raise ValueError (not TypeError/KeyError) so pydantic reports
            # a ValidationError instead of leaking a raw exception.
            if not isinstance(item, dict) or "role" not in item:
                raise ValueError(f"prompt_messages[{index}] must be a prompt message object with a 'role' key")
            role = item["role"]
            if isinstance(role, PromptMessageRole):
                # Accept enum members as well as their string values.
                role = role.value
            message_class = role_to_message_class.get(role, PromptMessage)
            converted.append(message_class(**item))
        return converted


class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
    """
    Payload for running a text-embedding model over a batch of strings.
    """

    # defaults the inherited field to the text-embedding model kind
    model_type: ModelType = ModelType.TEXT_EMBEDDING
    # the strings to embed
    texts: list[str]


class RequestInvokeRerank(BaseRequestInvokeModel):
    """
    Payload for running a rerank model: order ``docs`` against ``query``.

    ``score_threshold`` and ``top_n`` are both required here.
    """

    # defaults the inherited field to the rerank model kind
    model_type: ModelType = ModelType.RERANK
    query: str
    docs: list[str]
    score_threshold: float
    top_n: int


class RequestInvokeTTS(BaseRequestInvokeModel):
    """
    Payload for running a text-to-speech model.
    """

    # defaults the inherited field to the TTS model kind
    model_type: ModelType = ModelType.TTS
    # text to synthesize
    content_text: str
    # voice identifier to synthesize with
    voice: str


class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
    """
    Request to invoke speech2text.

    The audio payload is sent as a hex-encoded string. ``convert_file``
    decodes it to raw bytes before pydantic validates ``file``.
    """

    model_type: ModelType = ModelType.SPEECH2TEXT
    file: bytes

    @field_validator("file", mode="before")
    @classmethod
    def convert_file(cls, v):
        """
        Decode the hex-encoded audio payload into bytes.

        :param v: raw value of ``file``: a hex string, or bytes/bytearray
            that are already decoded
        :return: the audio data as ``bytes``
        :raises ValueError: if ``v`` is neither a string nor bytes-like, or
            the string is not valid hex (raised by ``bytes.fromhex``)
        """
        if isinstance(v, (bytes, bytearray)):
            # Already decoded, e.g. when re-validating the output of
            # ``model_dump()``; previously this was rejected.
            return bytes(v)
        if isinstance(v, str):
            # hex string to bytes
            return bytes.fromhex(v)
        raise ValueError("file must be a hex string")


class RequestInvokeModeration(BaseRequestInvokeModel):
    """
    Payload for running a moderation model over a single piece of text.
    """

    # defaults the inherited field to the moderation model kind
    model_type: ModelType = ModelType.MODERATION
    # text to check
    text: str


class RequestInvokeParameterExtractorNode(BaseModel):
    """
    Payload for running a parameter-extractor workflow node.

    Reuses the node's own ``ParameterConfig`` and ``ModelConfig`` entities.
    """

    # parameters the node should extract
    parameters: list[ParameterConfig]
    # model configuration, from the parameter-extractor node entities
    model: ParameterExtractorModelConfig
    instruction: str
    query: str


class RequestInvokeQuestionClassifierNode(BaseModel):
    """
    Payload for running a question-classifier workflow node.

    Reuses the node's own ``ClassConfig`` and ``ModelConfig`` entities.
    """

    query: str
    # model configuration, from the question-classifier node entities
    model: QuestionClassifierModelConfig
    # candidate classes for the query
    classes: list[ClassConfig]
    instruction: str


class RequestInvokeApp(BaseModel):
    """
    Payload for invoking an app by id.

    Only ``app_id``, ``inputs`` and ``response_mode`` are required. Every
    other field is optional.
    """

    app_id: str
    inputs: dict[str, Any]
    query: Optional[str] = None
    # "blocking" or "streaming"
    response_mode: Literal["blocking", "streaming"]
    conversation_id: Optional[str] = None
    user: Optional[str] = None
    # each entry is a plain dict; defaults to a fresh empty list per instance
    files: list[dict] = Field(default_factory=list)


class RequestInvokeEncrypt(BaseModel):
    """
    Request to encrypt, decrypt, or clear data.

    ``opt`` selects the operation. Only the "endpoint" namespace is
    accepted. ``config`` describes the fields of ``data``; presumably it
    decides which fields are processed (verify against the handler).
    """

    # operation to perform
    opt: Literal["encrypt", "decrypt", "clear"]
    namespace: Literal["endpoint"]
    identity: str
    data: dict = Field(default_factory=dict)
    config: list[BasicProviderConfig] = Field(default_factory=list)


class RequestInvokeSummary(BaseModel):
    """
    Request to summarize ``text`` following ``instruction``.
    """

    # text to summarize
    text: str
    # guidance for the summary
    instruction: str


class RequestRequestUploadFile(BaseModel):
    """
    Request to upload a file, described by its name and MIME type.
    """

    # NOTE(review): the doubled "Request" prefix in the class name looks
    # accidental; it is kept unchanged for existing importers.
    filename: str
    mimetype: str