2024-08-30 14:23:14 +08:00
|
|
|
from collections.abc import Mapping
|
2024-08-29 20:17:17 +08:00
|
|
|
from typing import Any, Literal, Optional
|
2024-07-29 22:08:14 +08:00
|
|
|
|
|
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
|
2024-08-30 14:23:14 +08:00
|
|
|
from core.entities.provider_entities import BasicProviderConfig
|
2024-07-29 22:08:14 +08:00
|
|
|
from core.model_runtime.entities.message_entities import (
|
|
|
|
AssistantPromptMessage,
|
|
|
|
PromptMessage,
|
|
|
|
PromptMessageRole,
|
|
|
|
PromptMessageTool,
|
|
|
|
SystemPromptMessage,
|
|
|
|
ToolPromptMessage,
|
|
|
|
UserPromptMessage,
|
|
|
|
)
|
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
2024-07-29 16:40:04 +08:00
|
|
|
|
|
|
|
|
|
|
|
class RequestInvokeTool(BaseModel):
    """
    Request payload for invoking a tool.

    NOTE(review): this model declares no fields yet.
    """
|
|
|
|
|
2024-07-29 22:08:14 +08:00
|
|
|
|
|
|
|
class BaseRequestInvokeModel(BaseModel):
    """
    Fields shared by every model-invocation request.

    Subclasses may override ``model_type`` with a default so callers do not
    have to supply it.
    """

    # Provider identifier, as a plain string.
    provider: str
    # Model identifier within that provider, as a plain string.
    model: str
    # Kind of model being invoked (required on this base class).
    model_type: ModelType
|
|
|
|
|
|
|
|
|
|
|
|
class RequestInvokeLLM(BaseRequestInvokeModel):
    """
    Request to invoke LLM

    ``model_type`` defaults to ``ModelType.LLM``. Raw dicts given for
    ``prompt_messages`` are converted into role-specific ``PromptMessage``
    subclasses before pydantic validates the field.
    """

    model_type: ModelType = ModelType.LLM
    # Free-form string; allowed values are not constrained by this model.
    mode: str
    # Passed through as an untyped dict.
    model_parameters: dict[str, Any] = Field(default_factory=dict)
    prompt_messages: list[PromptMessage] = Field(default_factory=list)
    tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
    stop: Optional[list[str]] = Field(default_factory=list)
    stream: Optional[bool] = False

    @field_validator('prompt_messages', mode='before')
    @classmethod
    def convert_prompt_messages(cls, v):
        """
        Turn raw message mappings into the PromptMessage subclass for their role.

        Each item is handled as follows:
        - An item that is already a ``PromptMessage`` is kept unchanged.
        - A mapping is dispatched on its ``role`` key. The user, assistant,
          system and tool roles build the matching subclass.
        - A mapping whose role is missing or unrecognised is passed to the base
          ``PromptMessage``.

        A role given as a ``PromptMessageRole`` member is treated like its
        ``.value``.

        Args:
            v: the raw value supplied for ``prompt_messages``.

        Returns:
            A new list. The caller's list is not modified.

        Raises:
            ValueError: if ``v`` is not a list, or an item is neither a
                mapping nor a ``PromptMessage``. Pydantic reports ValueError
                as a validation error.
        """
        if not isinstance(v, list):
            raise ValueError('prompt_messages must be a list')

        # Built per call because a dict attribute on a pydantic model class
        # would be treated as a field declaration.
        message_classes = {
            PromptMessageRole.USER.value: UserPromptMessage,
            PromptMessageRole.ASSISTANT.value: AssistantPromptMessage,
            PromptMessageRole.SYSTEM.value: SystemPromptMessage,
            PromptMessageRole.TOOL.value: ToolPromptMessage,
        }

        converted = []
        for item in v:
            # Already-built messages: indexing a model with ['role'] would
            # raise TypeError, so pass them through untouched.
            if isinstance(item, PromptMessage):
                converted.append(item)
                continue
            if not isinstance(item, Mapping):
                raise ValueError(
                    f'prompt_messages items must be mappings, got {type(item).__name__}'
                )

            role = item.get('role')
            # Normalise enum members to their value so the lookup matches.
            if isinstance(role, PromptMessageRole):
                role = role.value
            # Only string roles are looked up; anything else (including a
            # missing role) falls back to the base class.
            if isinstance(role, str):
                message_cls = message_classes.get(role, PromptMessage)
            else:
                message_cls = PromptMessage
            converted.append(message_cls(**item))

        return converted
|
|
|
|
|
|
|
|
|
2024-07-29 16:40:04 +08:00
|
|
|
class RequestInvokeTextEmbedding(BaseModel):
    """
    Request payload for a text-embedding invocation.

    NOTE(review): this model declares no fields yet.
    """
|
|
|
|
|
2024-07-29 22:08:14 +08:00
|
|
|
|
2024-07-29 16:40:04 +08:00
|
|
|
class RequestInvokeRerank(BaseModel):
    """
    Request payload for a rerank invocation.

    NOTE(review): this model declares no fields yet.
    """
|
|
|
|
|
2024-07-29 22:08:14 +08:00
|
|
|
|
2024-07-29 16:40:04 +08:00
|
|
|
class RequestInvokeTTS(BaseModel):
    """
    Request payload for a text-to-speech (TTS) invocation.

    NOTE(review): this model declares no fields yet.
    """
|
|
|
|
|
2024-07-29 22:08:14 +08:00
|
|
|
|
2024-07-29 16:40:04 +08:00
|
|
|
class RequestInvokeSpeech2Text(BaseModel):
    """
    Request payload for a speech-to-text invocation.

    NOTE(review): this model declares no fields yet.
    """
|
|
|
|
|
2024-07-29 22:08:14 +08:00
|
|
|
|
2024-07-29 16:40:04 +08:00
|
|
|
class RequestInvokeModeration(BaseModel):
    """
    Request payload for a moderation invocation.

    NOTE(review): this model declares no fields yet.
    """
|
|
|
|
|
2024-07-29 22:08:14 +08:00
|
|
|
|
2024-07-29 16:40:04 +08:00
|
|
|
class RequestInvokeNode(BaseModel):
    """
    Request payload for invoking a node.

    NOTE(review): this model declares no fields yet.
    """
|
2024-08-29 20:17:17 +08:00
|
|
|
|
|
|
|
class RequestInvokeApp(BaseModel):
    """
    Request payload for invoking an app.

    ``app_id``, ``inputs`` and ``response_mode`` are required. All other
    fields are optional.
    """

    # Identifier of the target app.
    app_id: str
    # Input values keyed by name; values are untyped.
    inputs: dict[str, Any]
    # Optional query text; None when not supplied.
    query: Optional[str] = Field(default=None)
    # No default: the caller must pick "blocking" or "streaming".
    response_mode: Literal["blocking", "streaming"]
    conversation_id: Optional[str] = Field(default=None)
    user: Optional[str] = Field(default=None)
    # Untyped dict entries; their schema is not defined in this module.
    files: list[dict] = Field(default_factory=list)
|
2024-08-30 14:23:14 +08:00
|
|
|
|
|
|
|
class RequestInvokeEncrypt(BaseModel):
    """
    Request to encrypt or decrypt data.

    ``opt`` selects the direction, and ``type`` is restricted to "endpoint".
    ``data`` and ``config`` default to empty dicts when omitted.
    """

    # Direction of the operation.
    opt: Literal["encrypt", "decrypt"]
    # Only "endpoint" is accepted.
    type: Literal["endpoint"]
    identity: str
    data: dict = Field(default_factory=dict)
    # collections.abc.Mapping is an abstract base class, and calling it raises
    # TypeError. The previous default_factory=Mapping therefore broke every
    # instantiation that omitted `config`. A plain dict satisfies the Mapping
    # annotation.
    config: Mapping[str, BasicProviderConfig] = Field(default_factory=dict)
|