feat: enhance handling of prompt message content and add error for unsupported types

Signed-off-by: -LAN- <laipz8200@outlook.com>
Author: -LAN-
Date: 2025-03-18 12:51:43 +08:00
Parent: f31fda62c9
Commit: 3a02438131
3 changed files with 81 additions and 38 deletions
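
In short, the assistant message assembled during streaming now carries a list of typed content parts instead of one accumulated string. A rough sketch of the two shapes, using stand-in dataclasses rather than the real entities from core.model_runtime.entities.message_entities:

```python
from dataclasses import dataclass
from typing import Union

# Stand-ins for illustration only; the real classes live in
# core.model_runtime.entities.message_entities.
@dataclass
class TextPromptMessageContent:
    data: str

@dataclass
class AssistantPromptMessage:
    content: Union[str, list]

# Old shape: one accumulated string.
before = AssistantPromptMessage(content="Hello world")

# New shape: a list of typed parts that downstream code flattens as needed.
after = AssistantPromptMessage(
    content=[TextPromptMessageContent(data="Hello "), TextPromptMessageContent(data="world")]
)
```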

Changed file 1 of 3:

@@ -12,7 +12,9 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageTool,
+    TextPromptMessageContent,
 )
 from core.model_runtime.entities.model_entities import (
     ModelType,
@@ -211,7 +213,7 @@ class LargeLanguageModel(AIModel):
     def _invoke_result_generator(
         self,
         model: str,
-        result: Generator,
+        result: Generator[LLMResultChunk, None, None],
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
@@ -228,7 +230,7 @@ class LargeLanguageModel(AIModel):
         :return: result generator
         """
         callbacks = callbacks or []
-        assistant_message = AssistantPromptMessage(content="")
+        message_content: list[PromptMessageContent] = []
         usage = None
         system_fingerprint = None
         real_model = model
@@ -250,7 +252,10 @@ class LargeLanguageModel(AIModel):
                         callbacks=callbacks,
                     )
 
-                    assistant_message.content += chunk.delta.message.content
+                    if isinstance(chunk.delta.message.content, list):
+                        message_content.extend(chunk.delta.message.content)
+                    elif isinstance(chunk.delta.message.content, str):
+                        message_content.append(TextPromptMessageContent(data=chunk.delta.message.content))
                     real_model = chunk.model
                     if chunk.delta.usage:
                         usage = chunk.delta.usage
@@ -260,6 +265,7 @@ class LargeLanguageModel(AIModel):
         except Exception as e:
             raise self._transform_invoke_error(e)
 
+        assistant_message = AssistantPromptMessage(content=message_content)
        self._trigger_after_invoke_callbacks(
             model=model,
             result=LLMResult(
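
The accumulation pattern above can be read in isolation as follows. This is a minimal sketch with a stand-in content class, not the runtime's actual entities:

```python
from dataclasses import dataclass

@dataclass
class TextPromptMessageContent:
    data: str  # stand-in for the real entity

def accumulate(deltas: list) -> list:
    """Collect streamed deltas into one content list.

    Mirrors the isinstance checks added to _invoke_result_generator: list deltas
    are extended in place, string deltas are wrapped as text content.
    """
    message_content: list = []
    for delta in deltas:
        if isinstance(delta, list):
            message_content.extend(delta)
        elif isinstance(delta, str):
            message_content.append(TextPromptMessageContent(data=delta))
    return message_content

chunks = ["Hello, ", [TextPromptMessageContent(data="world")], "!"]
print(accumulate(chunks))  # three text parts, ready for AssistantPromptMessage(content=...)
```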

Changed file 2 of 3:

@@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
 class FileTypeNotSupportError(LLMNodeError):
     def __init__(self, *, type_name: str):
         super().__init__(f"{type_name} type is not supported by this model")
+
+
+class UnsupportedPromptContentTypeError(LLMNodeError):
+    def __init__(self, *, type_name: str) -> None:
+        super().__init__(f"Prompt content type {type_name} is not supported.")
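
The new error takes a keyword-only type name. A self-contained usage sketch, with LLMNodeError stubbed here since in the repo it comes from the node's exception module:

```python
class LLMNodeError(ValueError):
    """Stub of the node's base error, just to make this example runnable."""

class UnsupportedPromptContentTypeError(LLMNodeError):
    def __init__(self, *, type_name: str) -> None:
        super().__init__(f"Prompt content type {type_name} is not supported.")

try:
    raise UnsupportedPromptContentTypeError(type_name="AudioPromptMessageContent")
except LLMNodeError as e:
    print(e)  # Prompt content type AudioPromptMessageContent is not supported.
```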

Changed file 3 of 3:

@@ -19,7 +19,7 @@ from core.model_runtime.entities import (
     PromptMessageContentType,
     TextPromptMessageContent,
 )
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessageContent,
@@ -78,6 +78,7 @@ from .exc import (
     ModelNotExistError,
     NoPromptFoundError,
     TemplateTypeNotSupportError,
+    UnsupportedPromptContentTypeError,
     VariableNotFoundError,
 )
@@ -246,56 +247,62 @@ class LLMNode(BaseNode[LLMNodeData]):
         return self._handle_invoke_result(invoke_result=invoke_result)
 
-    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
+    def _handle_invoke_result(
+        self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
+    ) -> Generator[NodeEvent, None, None]:
+        # For blocking mode
         if isinstance(invoke_result, LLMResult):
-            content = invoke_result.message.content
-            if content is None:
-                message_text = ""
-            elif isinstance(content, str):
-                message_text = content
-            elif isinstance(content, list):
-                # Assuming the list contains PromptMessageContent objects with a "data" attribute
-                message_text = "".join(
-                    item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
-                )
-            else:
-                message_text = str(content)
-
-            yield ModelInvokeCompletedEvent(
-                text=message_text,
-                usage=invoke_result.usage,
-                finish_reason=None,
-            )
+            event = self._handle_blocking_result(invoke_result=invoke_result)
+            yield event
             return
 
-        model = None
+        # For streaming mode
+        model = ""
         prompt_messages: list[PromptMessage] = []
         full_text = ""
-        usage = None
+        usage = LLMUsage.empty_usage()
         finish_reason = None
         for result in invoke_result:
-            text = result.delta.message.content
-            full_text += text
-
-            yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
-
-            if not model:
+            contents = result.delta.message.content
+            if contents is None:
+                continue
+
+            if isinstance(contents, str):
+                yield RunStreamChunkEvent(chunk_content=contents, from_variable_selector=[self.node_id, "text"])
+                full_text += contents
+            elif isinstance(contents, list):
+                for content in contents:
+                    if isinstance(content, TextPromptMessageContent):
+                        text_chunk = content.data
+                    elif isinstance(content, ImagePromptMessageContent):
+                        text_chunk = self._image_to_markdown(content)
+                    else:
+                        raise UnsupportedPromptContentTypeError(type_name=str(type(content)))
+                    yield RunStreamChunkEvent(chunk_content=text_chunk, from_variable_selector=[self.node_id])
+                    full_text += text_chunk
+
+            # Update the whole metadata
+            if not model and result.model:
                 model = result.model
-
-            if not prompt_messages:
+            if len(prompt_messages) == 0:
                 prompt_messages = result.prompt_messages
-
-            if not usage and result.delta.usage:
+            if usage.prompt_tokens == 0 and result.delta.usage:
                 usage = result.delta.usage
-
-            if not finish_reason and result.delta.finish_reason:
+            if finish_reason is None and result.delta.finish_reason:
                 finish_reason = result.delta.finish_reason
-
-        if not usage:
-            usage = LLMUsage.empty_usage()
 
         yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
 
+    def _image_to_markdown(self, content: ImagePromptMessageContent, /):
+        if content.url:
+            text_chunk = f"![]({content.url})"
+        elif content.base64_data:
+            # insert b64 image into markdown text
+            text_chunk = f"![]({content.data})"
+        else:
+            raise ValueError("Image content must have either a URL or base64 data")
+
+        return text_chunk
+
     def _transform_chat_messages(
         self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
     ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
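
The streaming branch and _image_to_markdown boil down to one flattening rule: text parts contribute their data, image parts become Markdown image links, and anything else raises. A minimal sketch with stand-in classes instead of the real content entities:

```python
from dataclasses import dataclass

# Stand-ins for illustration; the real entities carry more fields.
@dataclass
class TextPromptMessageContent:
    data: str

@dataclass
class ImagePromptMessageContent:
    url: str = ""
    base64_data: str = ""
    data: str = ""

def image_to_markdown(content: ImagePromptMessageContent) -> str:
    # Same precedence as _image_to_markdown: prefer the URL, else base64 data.
    if content.url:
        return f"![]({content.url})"
    if content.base64_data:
        return f"![]({content.data})"
    raise ValueError("Image content must have either a URL or base64 data")

def flatten(contents: list) -> str:
    full_text = ""
    for content in contents:
        if isinstance(content, TextPromptMessageContent):
            full_text += content.data
        elif isinstance(content, ImagePromptMessageContent):
            full_text += image_to_markdown(content)
        else:
            # The node raises UnsupportedPromptContentTypeError here; a plain
            # TypeError keeps this sketch self-contained.
            raise TypeError(f"Prompt content type {type(content)} is not supported.")
    return full_text

print(flatten([
    TextPromptMessageContent(data="Chart: "),
    ImagePromptMessageContent(url="https://example.com/chart.png"),
]))
```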
@@ -926,6 +933,31 @@ class LLMNode(BaseNode[LLMNodeData]):
         return prompt_messages
 
+    def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
+        contents = invoke_result.message.content
+        if contents is None:
+            message_text = ""
+        elif isinstance(contents, str):
+            message_text = contents
+        elif isinstance(contents, list):
+            # TODO: support multi modal content
+            message_text = ""
+            for item in contents:
+                if isinstance(item, TextPromptMessageContent):
+                    message_text += item.data
+                elif isinstance(item, ImagePromptMessageContent):
+                    message_text += self._image_to_markdown(item)
+                else:
+                    message_text += str(item)
+        else:
+            message_text = str(contents)
+
+        return ModelInvokeCompletedEvent(
+            text=message_text,
+            usage=invoke_result.usage,
+            finish_reason=None,
+        )
+
 
 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role: