Compare commits

...

5 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| -LAN- | 3a02438131 | feat: enhance handling of prompt message content and add error for unsupported types (Signed-off-by: -LAN- <laipz8200@outlook.com>) | 2025-03-18 12:51:43 +08:00 |
| -LAN- | f31fda62c9 | Merge branch 'main' into feat/support-image-generate-for-gemini | 2025-03-17 16:40:52 +08:00 |
| -LAN- | 4ce2819263 | Merge branch 'main' into feat/support-image-generate-for-gemini | 2025-03-17 16:36:25 +08:00 |
| -LAN- | e945afb3cd | feat: enhance prompt message validation and add content type mapping (Signed-off-by: -LAN- <laipz8200@outlook.com>) | 2025-03-17 16:34:33 +08:00 |
| -LAN- | c43f388586 | fix: update datetime usage to use UTC consistently across workflow and task modules (Signed-off-by: -LAN- <laipz8200@outlook.com>) | 2025-03-17 15:16:04 +08:00 |
6 changed files with 110 additions and 41 deletions

View File

@@ -101,7 +101,7 @@ class ModelInstance:
     @overload
     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
        tools: Sequence[PromptMessageTool] | None = None,
        stop: Optional[list[str]] = None,
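
For illustration only (not part of the diff): widening the annotation from list[PromptMessage] to Sequence[PromptMessage] lets callers pass any sequence, such as a tuple. A minimal sketch, assuming an already-configured ModelInstance named model_instance; the invoke_llm call itself is left commented.

```python
# Minimal usage sketch (assumption: `model_instance` is an existing, configured
# ModelInstance; only the prompt_messages typing change is being illustrated).
from core.model_runtime.entities.message_entities import UserPromptMessage

prompt_messages = (
    UserPromptMessage(content="Generate an image of a cat wearing a hat."),
)  # a tuple now satisfies Sequence[PromptMessage]

# result = model_instance.invoke_llm(prompt_messages=prompt_messages, stream=False)
```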

View File

@@ -1,5 +1,5 @@
 from abc import ABC
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from enum import Enum, StrEnum
 from typing import Optional
@@ -119,6 +119,15 @@ class DocumentPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT


+CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
+    PromptMessageContentType.TEXT: TextPromptMessageContent,
+    PromptMessageContentType.IMAGE: ImagePromptMessageContent,
+    PromptMessageContentType.AUDIO: AudioPromptMessageContent,
+    PromptMessageContentType.VIDEO: VideoPromptMessageContent,
+    PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
+}
+
+
 class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
@@ -136,6 +145,23 @@ class PromptMessage(ABC, BaseModel):
         """
         return not self.content

+    @field_validator("content", mode="before")
+    @classmethod
+    def validate_content(cls, v):
+        if isinstance(v, list):
+            prompts = []
+            for prompt in v:
+                if isinstance(prompt, PromptMessageContent):
+                    if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
+                        prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
+                elif isinstance(prompt, dict):
+                    prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
+                else:
+                    raise ValueError(f"invalid prompt message {prompt}")
+                prompts.append(prompt)
+            return prompts
+        return v
+
 class UserPromptMessage(PromptMessage):
     """

View File

@@ -12,7 +12,9 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageTool,
+    TextPromptMessageContent,
 )
 from core.model_runtime.entities.model_entities import (
     ModelType,
@@ -211,7 +213,7 @@ class LargeLanguageModel(AIModel):
     def _invoke_result_generator(
         self,
         model: str,
-        result: Generator,
+        result: Generator[LLMResultChunk, None, None],
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
@@ -228,7 +230,7 @@ class LargeLanguageModel(AIModel):
         :return: result generator
         """
         callbacks = callbacks or []
-        assistant_message = AssistantPromptMessage(content="")
+        message_content: list[PromptMessageContent] = []
        usage = None
        system_fingerprint = None
        real_model = model
@@ -250,7 +252,10 @@
                     callbacks=callbacks,
                 )

-                assistant_message.content += chunk.delta.message.content
+                if isinstance(chunk.delta.message.content, list):
+                    message_content.extend(chunk.delta.message.content)
+                elif isinstance(chunk.delta.message.content, str):
+                    message_content.append(TextPromptMessageContent(data=chunk.delta.message.content))
                 real_model = chunk.model
                 if chunk.delta.usage:
                     usage = chunk.delta.usage
@@ -260,6 +265,7 @@
         except Exception as e:
             raise self._transform_invoke_error(e)

+        assistant_message = AssistantPromptMessage(content=message_content)
         self._trigger_after_invoke_callbacks(
             model=model,
             result=LLMResult(
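
For illustration only (not part of the diff): the streaming path above now accumulates delta content that may arrive either as a plain string or as a list of PromptMessageContent items, and builds the final AssistantPromptMessage from the combined list. A minimal sketch with stand-in delta values:

```python
# Minimal sketch of the aggregation logic; `delta_contents` stands in for the
# chunk.delta.message.content values seen while iterating the result generator.
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageContent,
    TextPromptMessageContent,
)

delta_contents = ["Hello, ", [TextPromptMessageContent(data="world!")]]

message_content: list[PromptMessageContent] = []
for content in delta_contents:
    if isinstance(content, list):
        message_content.extend(content)
    elif isinstance(content, str):
        message_content.append(TextPromptMessageContent(data=content))

assistant_message = AssistantPromptMessage(content=message_content)
```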

View File

@@ -171,7 +171,7 @@ class BasePluginManager:
             line_data = None
             try:
                 line_data = json.loads(line)
-                rep = PluginDaemonBasicResponse[type](**line_data)  # type: ignore
+                rep = PluginDaemonBasicResponse[type](**line_data)
             except Exception:
                 # TODO modify this when line_data has code and message
                 if line_data and "error" in line_data:

View File

@@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
 class FileTypeNotSupportError(LLMNodeError):
     def __init__(self, *, type_name: str):
         super().__init__(f"{type_name} type is not supported by this model")
+
+
+class UnsupportedPromptContentTypeError(LLMNodeError):
+    def __init__(self, *, type_name: str) -> None:
+        super().__init__(f"Prompt content type {type_name} is not supported.")

View File

@@ -19,7 +19,7 @@ from core.model_runtime.entities import (
     PromptMessageContentType,
     TextPromptMessageContent,
 )
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessageContent,
@@ -78,6 +78,7 @@ from .exc import (
     ModelNotExistError,
     NoPromptFoundError,
     TemplateTypeNotSupportError,
+    UnsupportedPromptContentTypeError,
     VariableNotFoundError,
 )
@@ -246,56 +247,62 @@ class LLMNode(BaseNode[LLMNodeData]):
         return self._handle_invoke_result(invoke_result=invoke_result)

-    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
+    def _handle_invoke_result(
+        self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
+    ) -> Generator[NodeEvent, None, None]:
+        # For blocking mode
         if isinstance(invoke_result, LLMResult):
-            content = invoke_result.message.content
-            if content is None:
-                message_text = ""
-            elif isinstance(content, str):
-                message_text = content
-            elif isinstance(content, list):
-                # Assuming the list contains PromptMessageContent objects with a "data" attribute
-                message_text = "".join(
-                    item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
-                )
-            else:
-                message_text = str(content)
-
-            yield ModelInvokeCompletedEvent(
-                text=message_text,
-                usage=invoke_result.usage,
-                finish_reason=None,
-            )
+            event = self._handle_blocking_result(invoke_result=invoke_result)
+            yield event
             return

-        model = None
+        # For streaming mode
+        model = ""
         prompt_messages: list[PromptMessage] = []
         full_text = ""
-        usage = None
+        usage = LLMUsage.empty_usage()
         finish_reason = None
         for result in invoke_result:
-            text = result.delta.message.content
-            full_text += text
+            contents = result.delta.message.content
+            if contents is None:
+                continue

-            yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
+            if isinstance(contents, str):
+                yield RunStreamChunkEvent(chunk_content=contents, from_variable_selector=[self.node_id, "text"])
+                full_text += contents
+            elif isinstance(contents, list):
+                for content in contents:
+                    if isinstance(content, TextPromptMessageContent):
+                        text_chunk = content.data
+                    elif isinstance(content, ImagePromptMessageContent):
+                        text_chunk = self._image_to_markdown(content)
+                    else:
+                        raise UnsupportedPromptContentTypeError(type_name=str(type(content)))
+                    yield RunStreamChunkEvent(chunk_content=text_chunk, from_variable_selector=[self.node_id])
+                    full_text += text_chunk

-            if not model:
+            # Update the whole metadata
+            if not model and result.model:
                 model = result.model
-            if not prompt_messages:
+            if len(prompt_messages) == 0:
                 prompt_messages = result.prompt_messages
-            if not usage and result.delta.usage:
+            if usage.prompt_tokens == 0 and result.delta.usage:
                 usage = result.delta.usage
-            if not finish_reason and result.delta.finish_reason:
+            if finish_reason is None and result.delta.finish_reason:
                 finish_reason = result.delta.finish_reason

-        if not usage:
-            usage = LLMUsage.empty_usage()
-
         yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)

+    def _image_to_markdown(self, content: ImagePromptMessageContent, /):
+        if content.url:
+            text_chunk = f"![]({content.url})"
+        elif content.base64_data:
+            # insert b64 image into markdown text
+            text_chunk = f"![]({content.data})"
+        else:
+            raise ValueError("Image content must have either a URL or base64 data")
+
+        return text_chunk
+
     def _transform_chat_messages(
         self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
     ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
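
For illustration only (not part of the diff): a self-contained sketch of the markdown conversion used for streamed image content, with a stand-in dataclass rather than the real ImagePromptMessageContent, whose full constructor is not shown in the diff.

```python
# Self-contained sketch of _image_to_markdown's behaviour; ImageContent is a
# stand-in for ImagePromptMessageContent, not the real class.
from dataclasses import dataclass


@dataclass
class ImageContent:
    url: str = ""
    base64_data: str = ""
    data: str = ""  # in the real class this resolves to the URL or base64 payload


def image_to_markdown(content: ImageContent) -> str:
    if content.url:
        return f"![]({content.url})"
    if content.base64_data:
        # insert the base64 image into the markdown text
        return f"![]({content.data})"
    raise ValueError("Image content must have either a URL or base64 data")


print(image_to_markdown(ImageContent(url="https://example.com/generated.png")))
# -> ![](https://example.com/generated.png)
```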
@@ -926,6 +933,31 @@ class LLMNode(BaseNode[LLMNodeData]):
         return prompt_messages

+    def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
+        contents = invoke_result.message.content
+        if contents is None:
+            message_text = ""
+        elif isinstance(contents, str):
+            message_text = contents
+        elif isinstance(contents, list):
+            # TODO: support multi modal content
+            message_text = ""
+            for item in contents:
+                if isinstance(item, TextPromptMessageContent):
+                    message_text += item.data
+                elif isinstance(item, ImagePromptMessageContent):
+                    message_text += self._image_to_markdown(item)
+                else:
+                    message_text += str(item)
+        else:
+            message_text = str(contents)
+
+        return ModelInvokeCompletedEvent(
+            text=message_text,
+            usage=invoke_result.usage,
+            finish_reason=None,
+        )
+
 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role: