dify/api/core/plugin/entities/request.py

125 lines
3.0 KiB
Python
Raw Normal View History

2024-08-30 14:23:14 +08:00
from collections.abc import Mapping
2024-08-29 20:17:17 +08:00
from typing import Any, Literal, Optional
2024-07-29 22:08:14 +08:00
from pydantic import BaseModel, Field, field_validator
2024-08-30 14:23:14 +08:00
from core.entities.provider_entities import BasicProviderConfig
2024-07-29 22:08:14 +08:00
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
2024-07-29 16:40:04 +08:00
class RequestInvokeTool(BaseModel):
    """Request to invoke a tool."""

    # No fields are declared on this model yet.
2024-07-29 22:08:14 +08:00
class BaseRequestInvokeModel(BaseModel):
    """
    Common fields shared by model-invocation requests.

    All three fields are required here; subclasses may redeclare
    ``model_type`` with a default.
    """

    # Provider and model identifiers, as plain strings.
    provider: str
    model: str
    # Kind of model being invoked.
    model_type: ModelType
class RequestInvokeLLM(BaseRequestInvokeModel):
    """
    Request to invoke an LLM.

    ``model_type`` defaults to ``ModelType.LLM``. Raw ``prompt_messages``
    payloads are converted into the concrete ``PromptMessage`` subclass that
    matches each item's ``role`` before pydantic validates the field.
    """

    model_type: ModelType = ModelType.LLM
    # NOTE(review): unconstrained str here; presumably a mode such as
    # chat/completion — confirm the accepted values with callers.
    mode: str
    model_parameters: dict[str, Any] = Field(default_factory=dict)
    prompt_messages: list[PromptMessage] = Field(default_factory=list)
    tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
    stop: Optional[list[str]] = Field(default_factory=list)
    stream: Optional[bool] = False

    @field_validator("prompt_messages", mode="before")
    @classmethod
    def convert_prompt_messages(cls, v):
        """
        Convert raw prompt-message payloads into typed ``PromptMessage`` objects.

        :param v: the raw value supplied for ``prompt_messages``; must be a list.
        :return: a NEW list (the input is not mutated). Each mapping item is
            replaced by the ``PromptMessage`` subclass selected by its ``role``.
            Unknown or missing roles fall back to the base ``PromptMessage``.
            Items that are already ``PromptMessage`` instances are kept as-is.
        :raises ValueError: if ``v`` is not a list, or an item is neither a
            mapping nor a ``PromptMessage``. ValueError (rather than
            TypeError/KeyError) lets pydantic report a ValidationError.
        """
        if not isinstance(v, list):
            raise ValueError("prompt_messages must be a list")

        # Ordered (role value, message class) pairs. A linear scan with ``==`` is
        # used instead of a dict lookup so an unhashable "role" value (e.g. a JSON
        # list) falls back to PromptMessage instead of raising TypeError.
        role_dispatch = (
            (PromptMessageRole.USER.value, UserPromptMessage),
            (PromptMessageRole.ASSISTANT.value, AssistantPromptMessage),
            (PromptMessageRole.SYSTEM.value, SystemPromptMessage),
            (PromptMessageRole.TOOL.value, ToolPromptMessage),
        )

        converted = []
        for index, item in enumerate(v):
            if isinstance(item, PromptMessage):
                # Already a typed message (e.g. constructed in-process).
                converted.append(item)
                continue
            if not isinstance(item, Mapping):
                raise ValueError(f"prompt_messages[{index}] must be a mapping, got {type(item).__name__}")

            role = item.get("role")
            if isinstance(role, PromptMessageRole):
                # Accept the enum member as well as its raw value.
                role = role.value
            message_cls = next(
                (message_type for role_value, message_type in role_dispatch if role == role_value),
                PromptMessage,
            )
            converted.append(message_cls(**item))
        return converted
2024-07-29 16:40:04 +08:00
class RequestInvokeTextEmbedding(BaseModel):
    """Request to invoke text embedding."""

    # No fields are declared on this model yet.
2024-07-29 22:08:14 +08:00
2024-07-29 16:40:04 +08:00
class RequestInvokeRerank(BaseModel):
    """Request to invoke rerank."""

    # No fields are declared on this model yet.
2024-07-29 22:08:14 +08:00
2024-07-29 16:40:04 +08:00
class RequestInvokeTTS(BaseModel):
    """Request to invoke TTS (text-to-speech)."""

    # No fields are declared on this model yet.
2024-07-29 22:08:14 +08:00
2024-07-29 16:40:04 +08:00
class RequestInvokeSpeech2Text(BaseModel):
    """Request to invoke speech-to-text."""

    # No fields are declared on this model yet.
2024-07-29 22:08:14 +08:00
2024-07-29 16:40:04 +08:00
class RequestInvokeModeration(BaseModel):
    """Request to invoke moderation."""

    # No fields are declared on this model yet.
2024-07-29 22:08:14 +08:00
2024-07-29 16:40:04 +08:00
class RequestInvokeNode(BaseModel):
    """Request to invoke a node."""

    # No fields are declared on this model yet.
2024-08-29 20:17:17 +08:00
2024-09-14 02:47:01 +08:00
2024-08-29 20:17:17 +08:00
class RequestInvokeApp(BaseModel):
    """
    Request to invoke an app.

    ``app_id``, ``inputs`` and ``response_mode`` are required; every other
    field is optional.
    """

    # Required: app identifier and its inputs mapping.
    app_id: str
    inputs: dict[str, Any]
    query: Optional[str] = Field(default=None)
    # Only "blocking" or "streaming" are accepted.
    response_mode: Literal["blocking", "streaming"]
    conversation_id: Optional[str] = Field(default=None)
    user: Optional[str] = Field(default=None)
    # A fresh empty list per instance when omitted.
    files: list[dict] = Field(default_factory=list)
2024-08-30 14:23:14 +08:00
2024-09-14 02:47:01 +08:00
2024-08-30 14:23:14 +08:00
class RequestInvokeEncrypt(BaseModel):
    """
    Request to encrypt or decrypt data.
    """

    # Operation to perform on ``data``.
    opt: Literal["encrypt", "decrypt"]
    # Only the "endpoint" namespace is accepted.
    namespace: Literal["endpoint"]
    # NOTE(review): presumably identifies the owner/record of the data — confirm.
    identity: str
    data: dict = Field(default_factory=dict)
    # Mapping of names to ``BasicProviderConfig``. Defaults to an empty dict:
    # the previous ``default_factory=Mapping`` tried to instantiate the abstract
    # ``collections.abc.Mapping`` and raised TypeError whenever ``config`` was
    # omitted.
    config: Mapping[str, BasicProviderConfig] = Field(default_factory=dict)