feat(llm_node): support order in text and files (#11837)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
3599751f93
commit
996a9135f6
@ -1,15 +1,14 @@
|
|||||||
import base64
|
import base64
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.file import file_repository
|
|
||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
AudioPromptMessageContent,
|
AudioPromptMessageContent,
|
||||||
DocumentPromptMessageContent,
|
DocumentPromptMessageContent,
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
|
MultiModalPromptMessageContent,
|
||||||
VideoPromptMessageContent,
|
VideoPromptMessageContent,
|
||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
|
|
||||||
from . import helpers
|
from . import helpers
|
||||||
@ -41,7 +40,7 @@ def to_prompt_message_content(
|
|||||||
/,
|
/,
|
||||||
*,
|
*,
|
||||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||||
):
|
) -> MultiModalPromptMessageContent:
|
||||||
if f.extension is None:
|
if f.extension is None:
|
||||||
raise ValueError("Missing file extension")
|
raise ValueError("Missing file extension")
|
||||||
if f.mime_type is None:
|
if f.mime_type is None:
|
||||||
@ -70,16 +69,13 @@ def to_prompt_message_content(
|
|||||||
|
|
||||||
|
|
||||||
def download(f: File, /):
|
def download(f: File, /):
|
||||||
if f.transfer_method == FileTransferMethod.TOOL_FILE:
|
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
|
||||||
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
|
return _download_file_content(f._storage_key)
|
||||||
return _download_file_content(tool_file.file_key)
|
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
|
|
||||||
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
|
||||||
return _download_file_content(upload_file.key)
|
|
||||||
# remote file
|
|
||||||
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.content
|
return response.content
|
||||||
|
raise ValueError(f"unsupported transfer method: {f.transfer_method}")
|
||||||
|
|
||||||
|
|
||||||
def _download_file_content(path: str, /):
|
def _download_file_content(path: str, /):
|
||||||
@ -110,11 +106,9 @@ def _get_encoded_string(f: File, /):
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.content
|
data = response.content
|
||||||
case FileTransferMethod.LOCAL_FILE:
|
case FileTransferMethod.LOCAL_FILE:
|
||||||
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
data = _download_file_content(f._storage_key)
|
||||||
data = _download_file_content(upload_file.key)
|
|
||||||
case FileTransferMethod.TOOL_FILE:
|
case FileTransferMethod.TOOL_FILE:
|
||||||
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
|
data = _download_file_content(f._storage_key)
|
||||||
data = _download_file_content(tool_file.file_key)
|
|
||||||
|
|
||||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||||
return encoded_string
|
return encoded_string
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from models import ToolFile, UploadFile
|
|
||||||
|
|
||||||
from .models import File
|
|
||||||
|
|
||||||
|
|
||||||
def get_upload_file(*, session: Session, file: File):
|
|
||||||
if file.related_id is None:
|
|
||||||
raise ValueError("Missing file related_id")
|
|
||||||
stmt = select(UploadFile).filter(
|
|
||||||
UploadFile.id == file.related_id,
|
|
||||||
UploadFile.tenant_id == file.tenant_id,
|
|
||||||
)
|
|
||||||
record = session.scalar(stmt)
|
|
||||||
if not record:
|
|
||||||
raise ValueError(f"upload file {file.related_id} not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
def get_tool_file(*, session: Session, file: File):
|
|
||||||
if file.related_id is None:
|
|
||||||
raise ValueError("Missing file related_id")
|
|
||||||
stmt = select(ToolFile).filter(
|
|
||||||
ToolFile.id == file.related_id,
|
|
||||||
ToolFile.tenant_id == file.tenant_id,
|
|
||||||
)
|
|
||||||
record = session.scalar(stmt)
|
|
||||||
if not record:
|
|
||||||
raise ValueError(f"tool file {file.related_id} not found")
|
|
||||||
return record
|
|
@ -47,6 +47,38 @@ class File(BaseModel):
|
|||||||
mime_type: Optional[str] = None
|
mime_type: Optional[str] = None
|
||||||
size: int = -1
|
size: int = -1
|
||||||
|
|
||||||
|
# Those properties are private, should not be exposed to the outside.
|
||||||
|
_storage_key: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
id: Optional[str] = None,
|
||||||
|
tenant_id: str,
|
||||||
|
type: FileType,
|
||||||
|
transfer_method: FileTransferMethod,
|
||||||
|
remote_url: Optional[str] = None,
|
||||||
|
related_id: Optional[str] = None,
|
||||||
|
filename: Optional[str] = None,
|
||||||
|
extension: Optional[str] = None,
|
||||||
|
mime_type: Optional[str] = None,
|
||||||
|
size: int = -1,
|
||||||
|
storage_key: str,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
id=id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
type=type,
|
||||||
|
transfer_method=transfer_method,
|
||||||
|
remote_url=remote_url,
|
||||||
|
related_id=related_id,
|
||||||
|
filename=filename,
|
||||||
|
extension=extension,
|
||||||
|
mime_type=mime_type,
|
||||||
|
size=size,
|
||||||
|
)
|
||||||
|
self._storage_key = storage_key
|
||||||
|
|
||||||
def to_dict(self) -> Mapping[str, str | int | None]:
|
def to_dict(self) -> Mapping[str, str | int | None]:
|
||||||
data = self.model_dump(mode="json")
|
data = self.model_dump(mode="json")
|
||||||
return {
|
return {
|
||||||
|
@ -4,6 +4,7 @@ from .message_entities import (
|
|||||||
AudioPromptMessageContent,
|
AudioPromptMessageContent,
|
||||||
DocumentPromptMessageContent,
|
DocumentPromptMessageContent,
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
|
MultiModalPromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContent,
|
PromptMessageContent,
|
||||||
PromptMessageContentType,
|
PromptMessageContentType,
|
||||||
@ -27,6 +28,7 @@ __all__ = [
|
|||||||
"LLMResultChunkDelta",
|
"LLMResultChunkDelta",
|
||||||
"LLMUsage",
|
"LLMUsage",
|
||||||
"ModelPropertyKey",
|
"ModelPropertyKey",
|
||||||
|
"MultiModalPromptMessageContent",
|
||||||
"PromptMessage",
|
"PromptMessage",
|
||||||
"PromptMessage",
|
"PromptMessage",
|
||||||
"PromptMessageContent",
|
"PromptMessageContent",
|
||||||
|
@ -84,10 +84,10 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
type: PromptMessageContentType
|
type: PromptMessageContentType
|
||||||
format: str = Field(..., description="the format of multi-modal file")
|
format: str = Field(default=..., description="the format of multi-modal file")
|
||||||
base64_data: str = Field("", description="the base64 data of multi-modal file")
|
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
|
||||||
url: str = Field("", description="the url of multi-modal file")
|
url: str = Field(default="", description="the url of multi-modal file")
|
||||||
mime_type: str = Field(..., description="the mime type of multi-modal file")
|
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
|
||||||
|
|
||||||
@computed_field(return_type=str)
|
@computed_field(return_type=str)
|
||||||
@property
|
@property
|
||||||
|
@ -50,6 +50,7 @@ class PromptConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||||
|
text: str = ""
|
||||||
jinja2_text: Optional[str] = None
|
jinja2_text: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
query = query_variable.text
|
query = query_variable.text
|
||||||
|
|
||||||
prompt_messages, stop = self._fetch_prompt_messages(
|
prompt_messages, stop = self._fetch_prompt_messages(
|
||||||
user_query=query,
|
sys_query=query,
|
||||||
user_files=files,
|
sys_files=files,
|
||||||
context=context,
|
context=context,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@ -545,8 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
def _fetch_prompt_messages(
|
def _fetch_prompt_messages(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
user_query: str | None = None,
|
sys_query: str | None = None,
|
||||||
user_files: Sequence["File"],
|
sys_files: Sequence["File"],
|
||||||
context: str | None = None,
|
context: str | None = None,
|
||||||
memory: TokenBufferMemory | None = None,
|
memory: TokenBufferMemory | None = None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
@ -562,7 +562,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
if isinstance(prompt_template, list):
|
if isinstance(prompt_template, list):
|
||||||
# For chat model
|
# For chat model
|
||||||
prompt_messages.extend(
|
prompt_messages.extend(
|
||||||
_handle_list_messages(
|
self._handle_list_messages(
|
||||||
messages=prompt_template,
|
messages=prompt_template,
|
||||||
context=context,
|
context=context,
|
||||||
jinja2_variables=jinja2_variables,
|
jinja2_variables=jinja2_variables,
|
||||||
@ -581,14 +581,14 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
prompt_messages.extend(memory_messages)
|
prompt_messages.extend(memory_messages)
|
||||||
|
|
||||||
# Add current query to the prompt messages
|
# Add current query to the prompt messages
|
||||||
if user_query:
|
if sys_query:
|
||||||
message = LLMNodeChatModelMessage(
|
message = LLMNodeChatModelMessage(
|
||||||
text=user_query,
|
text=sys_query,
|
||||||
role=PromptMessageRole.USER,
|
role=PromptMessageRole.USER,
|
||||||
edition_type="basic",
|
edition_type="basic",
|
||||||
)
|
)
|
||||||
prompt_messages.extend(
|
prompt_messages.extend(
|
||||||
_handle_list_messages(
|
self._handle_list_messages(
|
||||||
messages=[message],
|
messages=[message],
|
||||||
context="",
|
context="",
|
||||||
jinja2_variables=[],
|
jinja2_variables=[],
|
||||||
@ -635,24 +635,27 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
raise ValueError("Invalid prompt content type")
|
raise ValueError("Invalid prompt content type")
|
||||||
|
|
||||||
# Add current query to the prompt message
|
# Add current query to the prompt message
|
||||||
if user_query:
|
if sys_query:
|
||||||
if prompt_content_type == str:
|
if prompt_content_type == str:
|
||||||
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
|
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
|
||||||
prompt_messages[0].content = prompt_content
|
prompt_messages[0].content = prompt_content
|
||||||
elif prompt_content_type == list:
|
elif prompt_content_type == list:
|
||||||
for content_item in prompt_content:
|
for content_item in prompt_content:
|
||||||
if content_item.type == PromptMessageContentType.TEXT:
|
if content_item.type == PromptMessageContentType.TEXT:
|
||||||
content_item.data = user_query + "\n" + content_item.data
|
content_item.data = sys_query + "\n" + content_item.data
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid prompt content type")
|
raise ValueError("Invalid prompt content type")
|
||||||
else:
|
else:
|
||||||
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
|
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
|
||||||
|
|
||||||
if vision_enabled and user_files:
|
# The sys_files will be deprecated later
|
||||||
|
if vision_enabled and sys_files:
|
||||||
file_prompts = []
|
file_prompts = []
|
||||||
for file in user_files:
|
for file in sys_files:
|
||||||
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||||
file_prompts.append(file_prompt)
|
file_prompts.append(file_prompt)
|
||||||
|
# If last prompt is a user prompt, add files into its contents,
|
||||||
|
# otherwise append a new user prompt
|
||||||
if (
|
if (
|
||||||
len(prompt_messages) > 0
|
len(prompt_messages) > 0
|
||||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||||
@ -662,7 +665,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
else:
|
else:
|
||||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||||
|
|
||||||
# Filter prompt messages
|
# Remove empty messages and filter unsupported content
|
||||||
filtered_prompt_messages = []
|
filtered_prompt_messages = []
|
||||||
for prompt_message in prompt_messages:
|
for prompt_message in prompt_messages:
|
||||||
if isinstance(prompt_message.content, list):
|
if isinstance(prompt_message.content, list):
|
||||||
@ -846,6 +849,58 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _handle_list_messages(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
messages: Sequence[LLMNodeChatModelMessage],
|
||||||
|
context: Optional[str],
|
||||||
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
prompt_messages: list[PromptMessage] = []
|
||||||
|
for message in messages:
|
||||||
|
contents: list[PromptMessageContent] = []
|
||||||
|
if message.edition_type == "jinja2":
|
||||||
|
result_text = _render_jinja2_message(
|
||||||
|
template=message.jinja2_text or "",
|
||||||
|
jinjia2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
contents.append(TextPromptMessageContent(data=result_text))
|
||||||
|
else:
|
||||||
|
# Get segment group from basic message
|
||||||
|
if context:
|
||||||
|
template = message.text.replace("{#context#}", context)
|
||||||
|
else:
|
||||||
|
template = message.text
|
||||||
|
segment_group = variable_pool.convert_template(template)
|
||||||
|
|
||||||
|
# Process segments for images
|
||||||
|
for segment in segment_group.value:
|
||||||
|
if isinstance(segment, ArrayFileSegment):
|
||||||
|
for file in segment.value:
|
||||||
|
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||||
|
file_content = file_manager.to_prompt_message_content(
|
||||||
|
file, image_detail_config=vision_detail_config
|
||||||
|
)
|
||||||
|
contents.append(file_content)
|
||||||
|
elif isinstance(segment, FileSegment):
|
||||||
|
file = segment.value
|
||||||
|
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||||
|
file_content = file_manager.to_prompt_message_content(
|
||||||
|
file, image_detail_config=vision_detail_config
|
||||||
|
)
|
||||||
|
contents.append(file_content)
|
||||||
|
else:
|
||||||
|
plain_text = segment.markdown.strip()
|
||||||
|
if plain_text:
|
||||||
|
contents.append(TextPromptMessageContent(data=plain_text))
|
||||||
|
prompt_message = _combine_message_content_with_role(contents=contents, role=message.role)
|
||||||
|
prompt_messages.append(prompt_message)
|
||||||
|
|
||||||
|
return prompt_messages
|
||||||
|
|
||||||
|
|
||||||
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
|
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
|
||||||
match role:
|
match role:
|
||||||
@ -880,68 +935,6 @@ def _render_jinja2_message(
|
|||||||
return result_text
|
return result_text
|
||||||
|
|
||||||
|
|
||||||
def _handle_list_messages(
|
|
||||||
*,
|
|
||||||
messages: Sequence[LLMNodeChatModelMessage],
|
|
||||||
context: Optional[str],
|
|
||||||
jinja2_variables: Sequence[VariableSelector],
|
|
||||||
variable_pool: VariablePool,
|
|
||||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
|
||||||
) -> Sequence[PromptMessage]:
|
|
||||||
prompt_messages = []
|
|
||||||
for message in messages:
|
|
||||||
if message.edition_type == "jinja2":
|
|
||||||
result_text = _render_jinja2_message(
|
|
||||||
template=message.jinja2_text or "",
|
|
||||||
jinjia2_variables=jinja2_variables,
|
|
||||||
variable_pool=variable_pool,
|
|
||||||
)
|
|
||||||
prompt_message = _combine_message_content_with_role(
|
|
||||||
contents=[TextPromptMessageContent(data=result_text)], role=message.role
|
|
||||||
)
|
|
||||||
prompt_messages.append(prompt_message)
|
|
||||||
else:
|
|
||||||
# Get segment group from basic message
|
|
||||||
if context:
|
|
||||||
template = message.text.replace("{#context#}", context)
|
|
||||||
else:
|
|
||||||
template = message.text
|
|
||||||
segment_group = variable_pool.convert_template(template)
|
|
||||||
|
|
||||||
# Process segments for images
|
|
||||||
file_contents = []
|
|
||||||
for segment in segment_group.value:
|
|
||||||
if isinstance(segment, ArrayFileSegment):
|
|
||||||
for file in segment.value:
|
|
||||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
|
||||||
file_content = file_manager.to_prompt_message_content(
|
|
||||||
file, image_detail_config=vision_detail_config
|
|
||||||
)
|
|
||||||
file_contents.append(file_content)
|
|
||||||
if isinstance(segment, FileSegment):
|
|
||||||
file = segment.value
|
|
||||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
|
||||||
file_content = file_manager.to_prompt_message_content(
|
|
||||||
file, image_detail_config=vision_detail_config
|
|
||||||
)
|
|
||||||
file_contents.append(file_content)
|
|
||||||
|
|
||||||
# Create message with text from all segments
|
|
||||||
plain_text = segment_group.text
|
|
||||||
if plain_text:
|
|
||||||
prompt_message = _combine_message_content_with_role(
|
|
||||||
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
|
|
||||||
)
|
|
||||||
prompt_messages.append(prompt_message)
|
|
||||||
|
|
||||||
if file_contents:
|
|
||||||
# Create message with image contents
|
|
||||||
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
|
|
||||||
prompt_messages.append(prompt_message)
|
|
||||||
|
|
||||||
return prompt_messages
|
|
||||||
|
|
||||||
|
|
||||||
def _calculate_rest_token(
|
def _calculate_rest_token(
|
||||||
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
||||||
) -> int:
|
) -> int:
|
||||||
|
@ -86,10 +86,10 @@ class QuestionClassifierNode(LLMNode):
|
|||||||
)
|
)
|
||||||
prompt_messages, stop = self._fetch_prompt_messages(
|
prompt_messages, stop = self._fetch_prompt_messages(
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
user_query=query,
|
sys_query=query,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
user_files=files,
|
sys_files=files,
|
||||||
vision_enabled=node_data.vision.enabled,
|
vision_enabled=node_data.vision.enabled,
|
||||||
vision_detail=node_data.vision.configs.detail,
|
vision_detail=node_data.vision.configs.detail,
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
|
@ -139,6 +139,7 @@ def _build_from_local_file(
|
|||||||
remote_url=row.source_url,
|
remote_url=row.source_url,
|
||||||
related_id=mapping.get("upload_file_id"),
|
related_id=mapping.get("upload_file_id"),
|
||||||
size=row.size,
|
size=row.size,
|
||||||
|
storage_key=row.key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -168,6 +169,7 @@ def _build_from_remote_url(
|
|||||||
mime_type=mime_type,
|
mime_type=mime_type,
|
||||||
extension=extension,
|
extension=extension,
|
||||||
size=file_size,
|
size=file_size,
|
||||||
|
storage_key="",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -220,6 +222,7 @@ def _build_from_tool_file(
|
|||||||
extension=extension,
|
extension=extension,
|
||||||
mime_type=tool_file.mimetype,
|
mime_type=tool_file.mimetype,
|
||||||
size=tool_file.size,
|
size=tool_file.size,
|
||||||
|
storage_key=tool_file.file_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -560,13 +560,29 @@ class Conversation(db.Model):
|
|||||||
@property
|
@property
|
||||||
def inputs(self):
|
def inputs(self):
|
||||||
inputs = self._inputs.copy()
|
inputs = self._inputs.copy()
|
||||||
|
|
||||||
|
# Convert file mapping to File object
|
||||||
for key, value in inputs.items():
|
for key, value in inputs.items():
|
||||||
|
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
|
||||||
|
from factories import file_factory
|
||||||
|
|
||||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||||
inputs[key] = File.model_validate(value)
|
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||||
|
value["tool_file_id"] = value["related_id"]
|
||||||
|
elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||||
|
value["upload_file_id"] = value["related_id"]
|
||||||
|
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
|
||||||
elif isinstance(value, list) and all(
|
elif isinstance(value, list) and all(
|
||||||
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
|
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
|
||||||
):
|
):
|
||||||
inputs[key] = [File.model_validate(item) for item in value]
|
inputs[key] = []
|
||||||
|
for item in value:
|
||||||
|
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||||
|
item["tool_file_id"] = item["related_id"]
|
||||||
|
elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||||
|
item["upload_file_id"] = item["related_id"]
|
||||||
|
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@inputs.setter
|
@inputs.setter
|
||||||
@ -758,12 +774,25 @@ class Message(db.Model):
|
|||||||
def inputs(self):
|
def inputs(self):
|
||||||
inputs = self._inputs.copy()
|
inputs = self._inputs.copy()
|
||||||
for key, value in inputs.items():
|
for key, value in inputs.items():
|
||||||
|
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
|
||||||
|
from factories import file_factory
|
||||||
|
|
||||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||||
inputs[key] = File.model_validate(value)
|
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||||
|
value["tool_file_id"] = value["related_id"]
|
||||||
|
elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||||
|
value["upload_file_id"] = value["related_id"]
|
||||||
|
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
|
||||||
elif isinstance(value, list) and all(
|
elif isinstance(value, list) and all(
|
||||||
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
|
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
|
||||||
):
|
):
|
||||||
inputs[key] = [File.model_validate(item) for item in value]
|
inputs[key] = []
|
||||||
|
for item in value:
|
||||||
|
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
|
||||||
|
item["tool_file_id"] = item["related_id"]
|
||||||
|
elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
|
||||||
|
item["upload_file_id"] = item["related_id"]
|
||||||
|
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@inputs.setter
|
@inputs.setter
|
||||||
|
@ -136,6 +136,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
|||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
remote_url="https://example.com/image1.jpg",
|
remote_url="https://example.com/image1.jpg",
|
||||||
|
storage_key="",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1,34 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig
|
from core.file import File, FileTransferMethod, FileType, FileUploadConfig
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
|
||||||
def test_file_loads_and_dumps():
|
|
||||||
file = File(
|
|
||||||
id="file1",
|
|
||||||
tenant_id="tenant1",
|
|
||||||
type=FileType.IMAGE,
|
|
||||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
|
||||||
remote_url="https://example.com/image1.jpg",
|
|
||||||
)
|
|
||||||
|
|
||||||
file_dict = file.model_dump()
|
|
||||||
assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY
|
|
||||||
assert file_dict["type"] == file.type.value
|
|
||||||
assert isinstance(file_dict["type"], str)
|
|
||||||
assert file_dict["transfer_method"] == file.transfer_method.value
|
|
||||||
assert isinstance(file_dict["transfer_method"], str)
|
|
||||||
assert "_extra_config" not in file_dict
|
|
||||||
|
|
||||||
file_obj = File.model_validate(file_dict)
|
|
||||||
assert file_obj.id == file.id
|
|
||||||
assert file_obj.tenant_id == file.tenant_id
|
|
||||||
assert file_obj.type == file.type
|
|
||||||
assert file_obj.transfer_method == file.transfer_method
|
|
||||||
assert file_obj.remote_url == file.remote_url
|
|
||||||
|
|
||||||
|
|
||||||
def test_file_to_dict():
|
def test_file_to_dict():
|
||||||
file = File(
|
file = File(
|
||||||
id="file1",
|
id="file1",
|
||||||
@ -36,10 +11,11 @@ def test_file_to_dict():
|
|||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
remote_url="https://example.com/image1.jpg",
|
remote_url="https://example.com/image1.jpg",
|
||||||
|
storage_key="storage_key",
|
||||||
)
|
)
|
||||||
|
|
||||||
file_dict = file.to_dict()
|
file_dict = file.to_dict()
|
||||||
assert "_extra_config" not in file_dict
|
assert "_storage_key" not in file_dict
|
||||||
assert "url" in file_dict
|
assert "url" in file_dict
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch):
|
|||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="1111",
|
related_id="1111",
|
||||||
|
storage_key="",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -138,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch):
|
|||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="1111",
|
related_id="1111",
|
||||||
|
storage_key="",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -21,7 +21,8 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
|
||||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment
|
||||||
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||||
@ -157,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node):
|
|||||||
filename="test.jpg",
|
filename="test.jpg",
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="1",
|
related_id="1",
|
||||||
|
storage_key="",
|
||||||
)
|
)
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
||||||
|
|
||||||
@ -173,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
|
|||||||
filename="test1.jpg",
|
filename="test1.jpg",
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="1",
|
related_id="1",
|
||||||
|
storage_key="",
|
||||||
),
|
),
|
||||||
File(
|
File(
|
||||||
id="2",
|
id="2",
|
||||||
@ -181,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
|
|||||||
filename="test2.jpg",
|
filename="test2.jpg",
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="2",
|
related_id="2",
|
||||||
|
storage_key="",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
||||||
@ -224,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
|
|||||||
filename="test1.jpg",
|
filename="test1.jpg",
|
||||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
remote_url=fake_remote_url,
|
remote_url=fake_remote_url,
|
||||||
|
storage_key="",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
fake_query = faker.sentence()
|
fake_query = faker.sentence()
|
||||||
|
|
||||||
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||||
user_query=fake_query,
|
sys_query=fake_query,
|
||||||
user_files=files,
|
sys_files=files,
|
||||||
context=None,
|
context=None,
|
||||||
memory=None,
|
memory=None,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@ -283,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
|||||||
test_scenarios = [
|
test_scenarios = [
|
||||||
LLMNodeTestScenario(
|
LLMNodeTestScenario(
|
||||||
description="No files",
|
description="No files",
|
||||||
user_query=fake_query,
|
sys_query=fake_query,
|
||||||
user_files=[],
|
sys_files=[],
|
||||||
features=[],
|
features=[],
|
||||||
vision_enabled=False,
|
vision_enabled=False,
|
||||||
vision_detail=None,
|
vision_detail=None,
|
||||||
@ -318,8 +323,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
|||||||
),
|
),
|
||||||
LLMNodeTestScenario(
|
LLMNodeTestScenario(
|
||||||
description="User files",
|
description="User files",
|
||||||
user_query=fake_query,
|
sys_query=fake_query,
|
||||||
user_files=[
|
sys_files=[
|
||||||
File(
|
File(
|
||||||
tenant_id="test",
|
tenant_id="test",
|
||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
@ -328,6 +333,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
|||||||
remote_url=fake_remote_url,
|
remote_url=fake_remote_url,
|
||||||
extension=".jpg",
|
extension=".jpg",
|
||||||
mime_type="image/jpg",
|
mime_type="image/jpg",
|
||||||
|
storage_key="",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
vision_enabled=True,
|
vision_enabled=True,
|
||||||
@ -370,8 +376,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
|||||||
),
|
),
|
||||||
LLMNodeTestScenario(
|
LLMNodeTestScenario(
|
||||||
description="Prompt template with variable selector of File",
|
description="Prompt template with variable selector of File",
|
||||||
user_query=fake_query,
|
sys_query=fake_query,
|
||||||
user_files=[],
|
sys_files=[],
|
||||||
vision_enabled=False,
|
vision_enabled=False,
|
||||||
vision_detail=fake_vision_detail,
|
vision_detail=fake_vision_detail,
|
||||||
features=[ModelFeature.VISION],
|
features=[ModelFeature.VISION],
|
||||||
@ -403,6 +409,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
|||||||
remote_url=fake_remote_url,
|
remote_url=fake_remote_url,
|
||||||
extension=".jpg",
|
extension=".jpg",
|
||||||
mime_type="image/jpg",
|
mime_type="image/jpg",
|
||||||
|
storage_key="",
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@ -417,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
|||||||
|
|
||||||
# Call the method under test
|
# Call the method under test
|
||||||
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||||
user_query=scenario.user_query,
|
sys_query=scenario.sys_query,
|
||||||
user_files=scenario.user_files,
|
sys_files=scenario.sys_files,
|
||||||
context=fake_context,
|
context=fake_context,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@ -435,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
|||||||
assert (
|
assert (
|
||||||
prompt_messages == scenario.expected_messages
|
prompt_messages == scenario.expected_messages
|
||||||
), f"Message content mismatch in scenario: {scenario.description}"
|
), f"Message content mismatch in scenario: {scenario.description}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_list_messages_basic(llm_node):
|
||||||
|
messages = [
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="Hello, {#context#}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
context = "world"
|
||||||
|
jinja2_variables = []
|
||||||
|
variable_pool = llm_node.graph_runtime_state.variable_pool
|
||||||
|
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
|
||||||
|
|
||||||
|
result = llm_node._handle_list_messages(
|
||||||
|
messages=messages,
|
||||||
|
context=context,
|
||||||
|
jinja2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
vision_detail_config=vision_detail_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert isinstance(result[0], UserPromptMessage)
|
||||||
|
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
|
||||||
|
@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel):
|
|||||||
"""Test scenario for LLM node testing."""
|
"""Test scenario for LLM node testing."""
|
||||||
|
|
||||||
description: str = Field(..., description="Description of the test scenario")
|
description: str = Field(..., description="Description of the test scenario")
|
||||||
user_query: str = Field(..., description="User query input")
|
sys_query: str = Field(..., description="User query input")
|
||||||
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
|
sys_files: Sequence[File] = Field(default_factory=list, description="List of user files")
|
||||||
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
|
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
|
||||||
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
|
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
|
||||||
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
|
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
|
||||||
|
@ -248,6 +248,7 @@ def test_array_file_contains_file_name():
|
|||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="1",
|
related_id="1",
|
||||||
filename="ab",
|
filename="ab",
|
||||||
|
storage_key="",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node):
|
|||||||
tenant_id="tenant1",
|
tenant_id="tenant1",
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="related1",
|
related_id="related1",
|
||||||
|
storage_key="",
|
||||||
),
|
),
|
||||||
File(
|
File(
|
||||||
filename="document1.pdf",
|
filename="document1.pdf",
|
||||||
@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node):
|
|||||||
tenant_id="tenant1",
|
tenant_id="tenant1",
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="related2",
|
related_id="related2",
|
||||||
|
storage_key="",
|
||||||
),
|
),
|
||||||
File(
|
File(
|
||||||
filename="image2.png",
|
filename="image2.png",
|
||||||
@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node):
|
|||||||
tenant_id="tenant1",
|
tenant_id="tenant1",
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="related3",
|
related_id="related3",
|
||||||
|
storage_key="",
|
||||||
),
|
),
|
||||||
File(
|
File(
|
||||||
filename="audio1.mp3",
|
filename="audio1.mp3",
|
||||||
@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node):
|
|||||||
tenant_id="tenant1",
|
tenant_id="tenant1",
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="related4",
|
related_id="related4",
|
||||||
|
storage_key="",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
variable = ArrayFileSegment(value=files)
|
variable = ArrayFileSegment(value=files)
|
||||||
@ -130,6 +134,7 @@ def test_get_file_extract_string_func():
|
|||||||
mime_type="text/plain",
|
mime_type="text/plain",
|
||||||
remote_url="https://example.com/test_file.txt",
|
remote_url="https://example.com/test_file.txt",
|
||||||
related_id="test_related_id",
|
related_id="test_related_id",
|
||||||
|
storage_key="",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test each case
|
# Test each case
|
||||||
@ -150,6 +155,7 @@ def test_get_file_extract_string_func():
|
|||||||
mime_type=None,
|
mime_type=None,
|
||||||
remote_url=None,
|
remote_url=None,
|
||||||
related_id="test_related_id",
|
related_id="test_related_id",
|
||||||
|
storage_key="",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert _get_file_extract_string_func(key="name")(empty_file) == ""
|
assert _get_file_extract_string_func(key="name")(empty_file) == ""
|
||||||
|
@ -19,6 +19,7 @@ def file():
|
|||||||
related_id="test_related_id",
|
related_id="test_related_id",
|
||||||
remote_url="test_url",
|
remote_url="test_url",
|
||||||
filename="test_file.txt",
|
filename="test_file.txt",
|
||||||
|
storage_key="",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user