chore: the consistency of MultiModalPromptMessageContent (#11721)

This commit is contained in:
非法操作 2024-12-17 15:01:38 +08:00 committed by GitHub
parent 78c3051585
commit c9b4029ce7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 108 additions and 99 deletions

View File

@ -313,8 +313,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model configuration # Model configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64 MULTIMODAL_SEND_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512 PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024 CODE_GENERATION_MAX_TOKENS=1024

View File

@ -665,14 +665,9 @@ class IndexingConfig(BaseSettings):
) )
class VisionFormatConfig(BaseSettings): class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field( MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
default="base64", default="base64",
) )
@ -778,13 +773,13 @@ class FeatureConfig(
FileAccessConfig, FileAccessConfig,
FileUploadConfig, FileUploadConfig,
HttpConfig, HttpConfig,
VisionFormatConfig,
InnerAPIConfig, InnerAPIConfig,
IndexingConfig, IndexingConfig,
LoggingConfig, LoggingConfig,
MailConfig, MailConfig,
ModelLoadBalanceConfig, ModelLoadBalanceConfig,
ModerationConfig, ModerationConfig,
MultiModalTransferConfig,
PositionConfig, PositionConfig,
RagEtlConfig, RagEtlConfig,
SecurityConfig, SecurityConfig,

View File

@ -42,33 +42,31 @@ def to_prompt_message_content(
*, *,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
): ):
match f.type: if f.extension is None:
case FileType.IMAGE: raise ValueError("Missing file extension")
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW if f.mime_type is None:
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": raise ValueError("Missing file mime_type")
data = _to_url(f)
else:
data = _to_base64_data_string(f)
return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip(".")) params = {
case FileType.AUDIO: "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
data = _to_base64_data_string(f) "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
if f.extension is None: "format": f.extension.removeprefix("."),
raise ValueError("Missing file extension") "mime_type": f.mime_type,
return AudioPromptMessageContent(data=data, format=f.extension.lstrip(".")) }
case FileType.VIDEO: if f.type == FileType.IMAGE:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url": params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
data = _to_url(f)
else: prompt_class_map = {
data = _to_base64_data_string(f) FileType.IMAGE: ImagePromptMessageContent,
if f.extension is None: FileType.AUDIO: AudioPromptMessageContent,
raise ValueError("Missing file extension") FileType.VIDEO: VideoPromptMessageContent,
return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) FileType.DOCUMENT: DocumentPromptMessageContent,
case FileType.DOCUMENT: }
data = _to_base64_data_string(f)
return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip(".")) try:
case _: return prompt_class_map[f.type](**params)
raise ValueError(f"file type {f.type} is not supported") except KeyError:
raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /): def download(f: File, /):
@ -122,11 +120,6 @@ def _get_encoded_string(f: File, /):
return encoded_string return encoded_string
def _to_base64_data_string(f: File, /):
encoded_string = _get_encoded_string(f)
return f"data:{f.mime_type};base64,{encoded_string}"
def _to_url(f: File, /): def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None: if f.remote_url is None:

View File

@ -1,9 +1,9 @@
from abc import ABC from abc import ABC
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Literal, Optional from typing import Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, computed_field, field_validator
class PromptMessageRole(Enum): class PromptMessageRole(Enum):
@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
""" """
type: PromptMessageContentType type: PromptMessageContentType
data: str
class TextPromptMessageContent(PromptMessageContent): class TextPromptMessageContent(PromptMessageContent):
@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
""" """
type: PromptMessageContentType = PromptMessageContentType.TEXT type: PromptMessageContentType = PromptMessageContentType.TEXT
data: str
class VideoPromptMessageContent(PromptMessageContent): class MultiModalPromptMessageContent(PromptMessageContent):
"""
Model class for multi-modal prompt message content.
"""
type: PromptMessageContentType
format: str = Field(..., description="the format of multi-modal file")
base64_data: str = Field("", description="the base64 data of multi-modal file")
url: str = Field("", description="the url of multi-modal file")
mime_type: str = Field(..., description="the mime type of multi-modal file")
@computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")
class AudioPromptMessageContent(PromptMessageContent): class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")
format: str = Field(..., description="Audio format")
class ImagePromptMessageContent(PromptMessageContent): class ImagePromptMessageContent(MultiModalPromptMessageContent):
""" """
Model class for image prompt message content. Model class for image prompt message content.
""" """
@ -101,14 +114,10 @@ class ImagePromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.IMAGE type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW detail: DETAIL = DETAIL.LOW
format: str = Field("jpg", description="Image format")
class DocumentPromptMessageContent(PromptMessageContent): class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
data: str
format: str = Field(..., description="Document format")
class PromptMessage(ABC, BaseModel): class PromptMessage(ABC, BaseModel):

View File

@ -1,5 +1,4 @@
import base64 import base64
import io
import json import json
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from typing import Optional, Union, cast from typing import Optional, Union, cast
@ -18,7 +17,6 @@ from anthropic.types import (
) )
from anthropic.types.beta.tools import ToolsBetaMessage from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout from httpx import Timeout
from PIL import Image
from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities import ( from core.model_runtime.entities import (
@ -498,22 +496,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE: elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content) message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"): if not message_content.base64_data:
# fetch image data from url # fetch image data from url
try: try:
image_content = requests.get(message_content.data).content image_content = requests.get(message_content.url).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8") base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex: except Exception as ex:
raise ValueError( raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}" f"Failed to fetch image data from url {message_content.data}, {ex}"
) )
else: else:
data_split = message_content.data.split(";base64,") base64_data = message_content.base64_data
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
mime_type = message_content.mime_type
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError( raise ValueError(
f"Unsupported image type {mime_type}, " f"Unsupported image type {mime_type}, "
@ -526,19 +521,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
} }
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent): elif isinstance(message_content, DocumentPromptMessageContent):
data_split = message_content.data.split(";base64,") if message_content.mime_type != "application/pdf":
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type != "application/pdf":
raise ValueError( raise ValueError(
f"Unsupported document type {mime_type}, " "only support application/pdf" f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
) )
sub_message_dict = { sub_message_dict = {
"type": "document", "type": "document",
"source": { "source": {
"type": message_content.encode_format, "type": "base64",
"media_type": mime_type, "media_type": message_content.mime_type,
"data": base64_data, "data": message_content.data,
}, },
} }
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)

View File

@ -434,9 +434,9 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.VIDEO: elif message_content.type == PromptMessageContentType.VIDEO:
message_content = cast(VideoPromptMessageContent, message_content) message_content = cast(VideoPromptMessageContent, message_content)
video_url = message_content.data video_url = message_content.url
if message_content.data.startswith("data:"): if not video_url:
raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url") raise InvokeError("not support base64, please set MULTIMODAL_SEND_FORMAT to url")
sub_message_dict = {"video": video_url} sub_message_dict = {"video": video_url}
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from configs import dify_config
from core.app.app_config.entities import ModelConfigEntity from core.app.app_config.entities import ModelConfigEntity
from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
model_config_mock, _, messages, inputs, context = get_chat_model_args model_config_mock, _, messages, inputs, context = get_chat_model_args
dify_config.MULTIMODAL_SEND_FORMAT = "url"
files = [ files = [
File( File(
@ -140,7 +142,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
prompt_transform = AdvancedPromptTransform() prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000) prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string: with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url)) mock_get_encoded_string.return_value = ImagePromptMessageContent(
url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
)
prompt_messages = prompt_transform._get_chat_model_prompt_messages( prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages, prompt_template=messages,
inputs=inputs, inputs=inputs,

View File

@ -18,8 +18,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@ -249,8 +248,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
def test_fetch_prompt_messages__basic(faker, llm_node, model_config): def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Setup dify config # Setup dify config
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" dify_config.MULTIMODAL_SEND_FORMAT = "url"
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
# Generate fake values for prompt template # Generate fake values for prompt template
fake_assistant_prompt = faker.sentence() fake_assistant_prompt = faker.sentence()
@ -326,9 +324,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
tenant_id="test", tenant_id="test",
type=FileType.IMAGE, type=FileType.IMAGE,
filename="test1.jpg", filename="test1.jpg",
extension=".jpg",
transfer_method=FileTransferMethod.REMOTE_URL, transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url, remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
) )
], ],
vision_enabled=True, vision_enabled=True,
@ -362,7 +361,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
UserPromptMessage( UserPromptMessage(
content=[ content=[
TextPromptMessageContent(data=fake_query), TextPromptMessageContent(data=fake_query),
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), ImagePromptMessageContent(
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
),
] ]
), ),
], ],
@ -385,7 +386,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
expected_messages=[ expected_messages=[
UserPromptMessage( UserPromptMessage(
content=[ content=[
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), ImagePromptMessageContent(
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
),
] ]
), ),
] ]
@ -396,9 +399,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
tenant_id="test", tenant_id="test",
type=FileType.IMAGE, type=FileType.IMAGE,
filename="test1.jpg", filename="test1.jpg",
extension=".jpg",
transfer_method=FileTransferMethod.REMOTE_URL, transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url, remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
) )
}, },
), ),

View File

@ -614,13 +614,12 @@ CODE_GENERATION_MAX_TOKENS=1024
# Multi-modal Configuration # Multi-modal Configuration
# ------------------------------ # ------------------------------
# Format used to send image/video files to the multi-modal model: # Format used to send image/video/audio/document files to the multi-modal model:
# base64 (default) or url. # base64 (default) or url.
# Calls in url mode have lower latency than calls in base64 mode. # Calls in url mode have lower latency than calls in base64 mode.
# The more compatible base64 mode is generally recommended. # The more compatible base64 mode is generally recommended.
# If set to url, FILES_URL must be configured as an externally accessible address so that the multi-modal model can access the image/video. # If set to url, FILES_URL must be configured as an externally accessible address so that the multi-modal model can access the image/video/audio/document.
MULTIMODAL_SEND_IMAGE_FORMAT=base64 MULTIMODAL_SEND_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
# Upload image file size limit, default 10M. # Upload image file size limit, default 10M.
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 UPLOAD_IMAGE_FILE_SIZE_LIMIT=10

View File

@ -225,8 +225,7 @@ x-shared-env: &shared-api-worker-env
UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-} UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-}
PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512} PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024} CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64} MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
MULTIMODAL_SEND_VIDEO_FORMAT: ${MULTIMODAL_SEND_VIDEO_FORMAT:-base64}
UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10} UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100} UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50} UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50}