feat: enhance handling of prompt message content and add error for unsupported types

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 2025-03-18 12:51:43 +08:00
parent f31fda62c9
commit 3a02438131
3 changed files with 81 additions and 38 deletions


@@ -12,7 +12,9 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContent,
PromptMessageTool,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import (
ModelType,
@@ -211,7 +213,7 @@ class LargeLanguageModel(AIModel):
def _invoke_result_generator(
self,
model: str,
result: Generator,
result: Generator[LLMResultChunk, None, None],
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
@@ -228,7 +230,7 @@ class LargeLanguageModel(AIModel):
:return: result generator
"""
callbacks = callbacks or []
assistant_message = AssistantPromptMessage(content="")
message_content: list[PromptMessageContent] = []
usage = None
system_fingerprint = None
real_model = model
@@ -250,7 +252,10 @@ class LargeLanguageModel(AIModel):
callbacks=callbacks,
)
assistant_message.content += chunk.delta.message.content
if isinstance(chunk.delta.message.content, list):
message_content.extend(chunk.delta.message.content)
elif isinstance(chunk.delta.message.content, str):
message_content.append(TextPromptMessageContent(data=chunk.delta.message.content))
real_model = chunk.model
if chunk.delta.usage:
usage = chunk.delta.usage
@@ -260,6 +265,7 @@ class LargeLanguageModel(AIModel):
except Exception as e:
raise self._transform_invoke_error(e)
assistant_message = AssistantPromptMessage(content=message_content)
self._trigger_after_invoke_callbacks(
model=model,
result=LLMResult(
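Taken together, the hunks above change how streaming output is accumulated in LargeLanguageModel._invoke_result_generator: instead of concatenating each chunk's delta onto an AssistantPromptMessage(content="") string, the generator collects typed parts in a message_content list and builds the assistant message once the stream ends. Below is a minimal, self-contained sketch of that accumulation strategy; the dataclasses are hypothetical stand-ins for the real classes in core.model_runtime.entities, so the example runs on its own.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical stand-ins for the real content classes in
# core.model_runtime.entities.message_entities.
@dataclass
class TextPromptMessageContent:
    data: str


@dataclass
class AssistantPromptMessage:
    content: Union[str, list]


def accumulate(deltas: list) -> AssistantPromptMessage:
    """Collect streamed delta contents, then build one assistant message."""
    message_content: list = []
    for delta in deltas:
        if isinstance(delta, list):
            # Multimodal chunks arrive as a list of content parts.
            message_content.extend(delta)
        elif isinstance(delta, str):
            # Plain-text chunks are wrapped so every element has the same shape.
            message_content.append(TextPromptMessageContent(data=delta))
    # The message is assembled once, after the stream is exhausted, instead of
    # mutating a string in place on every chunk.
    return AssistantPromptMessage(content=message_content)


print(accumulate(["Hello, ", [TextPromptMessageContent(data="world")]]).content)
```

Collecting parts rather than concatenating strings keeps non-text content (such as images) intact instead of coercing it into a string.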


@@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
class FileTypeNotSupportError(LLMNodeError):
def __init__(self, *, type_name: str):
super().__init__(f"{type_name} type is not supported by this model")
class UnsupportedPromptContentTypeError(LLMNodeError):
def __init__(self, *, type_name: str) -> None:
super().__init__(f"Prompt content type {type_name} is not supported.")


@@ -19,7 +19,7 @@ from core.model_runtime.entities import (
PromptMessageContentType,
TextPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContent,
@@ -78,6 +78,7 @@ from .exc import (
ModelNotExistError,
NoPromptFoundError,
TemplateTypeNotSupportError,
UnsupportedPromptContentTypeError,
VariableNotFoundError,
)
@@ -246,56 +247,62 @@ class LLMNode(BaseNode[LLMNodeData]):
return self._handle_invoke_result(invoke_result=invoke_result)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
def _handle_invoke_result(
self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
) -> Generator[NodeEvent, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
content = invoke_result.message.content
if content is None:
message_text = ""
elif isinstance(content, str):
message_text = content
elif isinstance(content, list):
# Assuming the list contains PromptMessageContent objects with a "data" attribute
message_text = "".join(
item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
)
else:
message_text = str(content)
yield ModelInvokeCompletedEvent(
text=message_text,
usage=invoke_result.usage,
finish_reason=None,
)
event = self._handle_blocking_result(invoke_result=invoke_result)
yield event
return
model = None
# For streaming mode
model = ""
prompt_messages: list[PromptMessage] = []
full_text = ""
usage = None
usage = LLMUsage.empty_usage()
finish_reason = None
for result in invoke_result:
text = result.delta.message.content
full_text += text
contents = result.delta.message.content
if contents is None:
continue
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
if isinstance(contents, str):
yield RunStreamChunkEvent(chunk_content=contents, from_variable_selector=[self.node_id, "text"])
full_text += contents
elif isinstance(contents, list):
for content in contents:
if isinstance(content, TextPromptMessageContent):
text_chunk = content.data
elif isinstance(content, ImagePromptMessageContent):
text_chunk = self._image_to_markdown(content)
else:
raise UnsupportedPromptContentTypeError(type_name=str(type(content)))
yield RunStreamChunkEvent(chunk_content=text_chunk, from_variable_selector=[self.node_id])
full_text += text_chunk
if not model:
# Update the whole metadata
if not model and result.model:
model = result.model
if not prompt_messages:
if len(prompt_messages) == 0:
prompt_messages = result.prompt_messages
if not usage and result.delta.usage:
if usage.prompt_tokens == 0 and result.delta.usage:
usage = result.delta.usage
if not finish_reason and result.delta.finish_reason:
if finish_reason is None and result.delta.finish_reason:
finish_reason = result.delta.finish_reason
if not usage:
usage = LLMUsage.empty_usage()
yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
def _image_to_markdown(self, content: ImagePromptMessageContent, /):
if content.url:
text_chunk = f"![]({content.url})"
elif content.base64_data:
# insert b64 image into markdown text
text_chunk = f"![]({content.data})"
else:
raise ValueError("Image content must have either a URL or base64 data")
return text_chunk
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
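The streaming branch now dispatches on the type of each content part: text parts stream through as-is, image parts are converted to markdown via the new _image_to_markdown helper, and anything else raises UnsupportedPromptContentTypeError instead of being silently dropped. Below is a runnable sketch of that dispatch (event emission and metadata bookkeeping omitted); TextPart and ImagePart are hypothetical stand-ins for the real TextPromptMessageContent and ImagePromptMessageContent classes.

```python
from dataclasses import dataclass


@dataclass
class TextPart:
    data: str


@dataclass
class ImagePart:
    url: str = ""
    base64_data: str = ""
    data: str = ""  # the real class exposes its base64 payload via `data`; simplified here


def image_to_markdown(image: ImagePart) -> str:
    # Mirrors _image_to_markdown: prefer the URL, fall back to embedded base64,
    # otherwise the image cannot be rendered inline.
    if image.url:
        return f"![]({image.url})"
    if image.base64_data:
        return f"![]({image.data})"
    raise ValueError("Image content must have either a URL or base64 data")


def to_text_chunks(parts: list) -> list[str]:
    chunks: list[str] = []
    for part in parts:
        if isinstance(part, TextPart):
            chunks.append(part.data)
        elif isinstance(part, ImagePart):
            chunks.append(image_to_markdown(part))
        else:
            # Unknown content types now fail loudly in the streaming path.
            raise TypeError(f"Prompt content type {type(part)} is not supported.")
    return chunks


print(to_text_chunks([TextPart(data="See the chart: "), ImagePart(url="https://example.com/chart.png")]))
```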
@@ -926,6 +933,31 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
contents = invoke_result.message.content
if contents is None:
message_text = ""
elif isinstance(contents, str):
message_text = contents
elif isinstance(contents, list):
# TODO: support multi modal content
message_text = ""
for item in contents:
if isinstance(item, TextPromptMessageContent):
message_text += item.data
elif isinstance(item, ImagePromptMessageContent):
message_text += self._image_to_markdown(item)
else:
message_text += str(item)
else:
message_text = str(contents)
return ModelInvokeCompletedEvent(
text=message_text,
usage=invoke_result.usage,
finish_reason=None,
)
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role:
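The final hunk extracts blocking-mode handling into _handle_blocking_result, which flattens None, string, or list content into a single text value; unlike the streaming path, unknown list items are stringified here rather than rejected. A sketch of that flattening, reusing the hypothetical TextPart/ImagePart stand-ins and image_to_markdown helper from the previous example rather than the repository's real classes.

```python
def flatten_blocking_content(contents) -> str:
    """Collapse a blocking result's message content into plain text."""
    if contents is None:
        return ""
    if isinstance(contents, str):
        return contents
    if isinstance(contents, list):
        pieces: list[str] = []
        for item in contents:
            if isinstance(item, TextPart):
                pieces.append(item.data)
            elif isinstance(item, ImagePart):
                pieces.append(image_to_markdown(item))
            else:
                # Blocking mode tolerates unknown parts by stringifying them.
                pieces.append(str(item))
        return "".join(pieces)
    return str(contents)


print(flatten_blocking_content([TextPart(data="Result: "), ImagePart(url="https://example.com/a.png")]))
```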