From 3a024381312b027b7bd38e63f3c1bcc58a5bb44a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 18 Mar 2025 12:51:43 +0800 Subject: [PATCH] feat: enhance handling of prompt message content and add error for unsupported types Signed-off-by: -LAN- --- .../__base/large_language_model.py | 12 ++- api/core/workflow/nodes/llm/exc.py | 5 + api/core/workflow/nodes/llm/node.py | 102 ++++++++++++------ 3 files changed, 81 insertions(+), 38 deletions(-) diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index ed67fef768..017aa6f244 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -12,7 +12,9 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContent, PromptMessageTool, + TextPromptMessageContent, ) from core.model_runtime.entities.model_entities import ( ModelType, @@ -211,7 +213,7 @@ class LargeLanguageModel(AIModel): def _invoke_result_generator( self, model: str, - result: Generator, + result: Generator[LLMResultChunk, None, None], credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, @@ -228,7 +230,7 @@ class LargeLanguageModel(AIModel): :return: result generator """ callbacks = callbacks or [] - assistant_message = AssistantPromptMessage(content="") + message_content: list[PromptMessageContent] = [] usage = None system_fingerprint = None real_model = model @@ -250,7 +252,10 @@ class LargeLanguageModel(AIModel): callbacks=callbacks, ) - assistant_message.content += chunk.delta.message.content + if isinstance(chunk.delta.message.content, list): + message_content.extend(chunk.delta.message.content) + elif isinstance(chunk.delta.message.content, str): + message_content.append(TextPromptMessageContent(data=chunk.delta.message.content)) real_model = chunk.model if chunk.delta.usage: usage = chunk.delta.usage @@ -260,6 +265,7 @@ class LargeLanguageModel(AIModel): except Exception as e: raise self._transform_invoke_error(e) + assistant_message = AssistantPromptMessage(content=message_content) self._trigger_after_invoke_callbacks( model=model, result=LLMResult( diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py index 6599221691..42b8f4e6ce 100644 --- a/api/core/workflow/nodes/llm/exc.py +++ b/api/core/workflow/nodes/llm/exc.py @@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError): class FileTypeNotSupportError(LLMNodeError): def __init__(self, *, type_name: str): super().__init__(f"{type_name} type is not supported by this model") + + +class UnsupportedPromptContentTypeError(LLMNodeError): + def __init__(self, *, type_name: str) -> None: + super().__init__(f"Prompt content type {type_name} is not supported.") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index fe0ed3e564..44532b41ce 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -19,7 +19,7 @@ from core.model_runtime.entities import ( PromptMessageContentType, TextPromptMessageContent, ) -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, 
PromptMessageContent, @@ -78,6 +78,7 @@ from .exc import ( ModelNotExistError, NoPromptFoundError, TemplateTypeNotSupportError, + UnsupportedPromptContentTypeError, VariableNotFoundError, ) @@ -246,56 +247,62 @@ class LLMNode(BaseNode[LLMNodeData]): return self._handle_invoke_result(invoke_result=invoke_result) - def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]: + def _handle_invoke_result( + self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None] + ) -> Generator[NodeEvent, None, None]: + # For blocking mode if isinstance(invoke_result, LLMResult): - content = invoke_result.message.content - if content is None: - message_text = "" - elif isinstance(content, str): - message_text = content - elif isinstance(content, list): - # Assuming the list contains PromptMessageContent objects with a "data" attribute - message_text = "".join( - item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content - ) - else: - message_text = str(content) - - yield ModelInvokeCompletedEvent( - text=message_text, - usage=invoke_result.usage, - finish_reason=None, - ) + event = self._handle_blocking_result(invoke_result=invoke_result) + yield event return - model = None + # For streaming mode + model = "" prompt_messages: list[PromptMessage] = [] full_text = "" - usage = None + usage = LLMUsage.empty_usage() finish_reason = None for result in invoke_result: - text = result.delta.message.content - full_text += text + contents = result.delta.message.content + if contents is None: + continue - yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) + if isinstance(contents, str): + yield RunStreamChunkEvent(chunk_content=contents, from_variable_selector=[self.node_id, "text"]) + full_text += contents + elif isinstance(contents, list): + for content in contents: + if isinstance(content, TextPromptMessageContent): + text_chunk = content.data + elif isinstance(content, ImagePromptMessageContent): + text_chunk = self._image_to_markdown(content) + else: + raise UnsupportedPromptContentTypeError(type_name=str(type(content))) + yield RunStreamChunkEvent(chunk_content=text_chunk, from_variable_selector=[self.node_id]) + full_text += text_chunk - if not model: + # Update the whole metadata + if not model and result.model: model = result.model - - if not prompt_messages: + if len(prompt_messages) == 0: prompt_messages = result.prompt_messages - - if not usage and result.delta.usage: + if usage.prompt_tokens == 0 and result.delta.usage: usage = result.delta.usage - - if not finish_reason and result.delta.finish_reason: + if finish_reason is None and result.delta.finish_reason: finish_reason = result.delta.finish_reason - if not usage: - usage = LLMUsage.empty_usage() - yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason) + def _image_to_markdown(self, content: ImagePromptMessageContent, /): + if content.url: + text_chunk = f"![]({content.url})" + elif content.base64_data: + # insert b64 image into markdown text + text_chunk = f"![]({content.data})" + else: + raise ValueError("Image content must have either a URL or base64 data") + return text_chunk + def _transform_chat_messages( self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: @@ -926,6 +933,31 @@ class LLMNode(BaseNode[LLMNodeData]): return prompt_messages + def 
_handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent: + contents = invoke_result.message.content + if contents is None: + message_text = "" + elif isinstance(contents, str): + message_text = contents + elif isinstance(contents, list): + # TODO: support multi modal content + message_text = "" + for item in contents: + if isinstance(item, TextPromptMessageContent): + message_text += item.data + elif isinstance(item, ImagePromptMessageContent): + message_text += self._image_to_markdown(item) + else: + message_text += str(item) + else: + message_text = str(contents) + + return ModelInvokeCompletedEvent( + text=message_text, + usage=invoke_result.usage, + finish_reason=None, + ) + def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): match role:
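The first hunk above replaces string concatenation in `_invoke_result_generator` (`assistant_message.content += chunk.delta.message.content`) with a list of typed content parts, so a provider that streams structured content no longer has to be forced through `str +=`. Below is a minimal sketch of that accumulation pattern, using hypothetical stand-ins (`Chunk`, `TextContent`) rather than Dify's real `LLMResultChunk` and `TextPromptMessageContent` classes:

```python
# Sketch only: Chunk and TextContent are simplified stand-ins, not Dify's classes.
from dataclasses import dataclass
from typing import Iterable, Union


@dataclass
class TextContent:
    data: str


@dataclass
class Chunk:
    # A streamed delta may carry a plain string, a list of typed parts, or nothing.
    content: Union[str, list, None]


def accumulate(chunks: Iterable[Chunk]) -> list:
    """Collect every delta as a typed part instead of concatenating strings."""
    parts: list = []
    for chunk in chunks:
        if isinstance(chunk.content, list):
            # Structured content is kept as-is.
            parts.extend(chunk.content)
        elif isinstance(chunk.content, str):
            # Plain text is wrapped in a text part.
            parts.append(TextContent(data=chunk.content))
    return parts


if __name__ == "__main__":
    stream = [Chunk("Hel"), Chunk("lo"), Chunk([TextContent(" world")]), Chunk(None)]
    print(accumulate(stream))
```

As in the patch, the final assistant message is then built once, after the stream ends, rather than mutated on every chunk.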
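In `node.py`, both the streaming and the blocking paths now flatten message content to plain text for the node's `text` output, rendering image parts as Markdown; the streaming path rejects unknown part types with the new `UnsupportedPromptContentTypeError`, while the blocking path falls back to `str()`. A hedged sketch of that normalization, again with simplified stand-in classes (`TextPart`, `ImagePart`) instead of the real content entities:

```python
# Sketch only: TextPart and ImagePart are simplified stand-ins for the real
# TextPromptMessageContent / ImagePromptMessageContent entities.
from dataclasses import dataclass
from typing import Iterable, Union


@dataclass
class TextPart:
    data: str


@dataclass
class ImagePart:
    url: str = ""
    base64_data: str = ""
    mime_type: str = "image/png"

    @property
    def data(self) -> str:
        # Prefer the URL, otherwise build a data URI from the base64 payload.
        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"


def image_to_markdown(part: ImagePart) -> str:
    # Mirrors the URL-first, base64-fallback branching of the new _image_to_markdown helper.
    if part.url:
        return f"![]({part.url})"
    if part.base64_data:
        return f"![]({part.data})"
    raise ValueError("Image content must have either a URL or base64 data")


def collect_text(deltas: Iterable[Union[str, list, None]]) -> str:
    """Normalize streamed deltas (str, list of parts, or None) into one text stream."""
    full_text = ""
    for content in deltas:
        if content is None:
            continue
        if isinstance(content, str):
            full_text += content
            continue
        for part in content:
            if isinstance(part, TextPart):
                full_text += part.data
            elif isinstance(part, ImagePart):
                full_text += image_to_markdown(part)
            else:
                # The real streaming path raises UnsupportedPromptContentTypeError here.
                raise TypeError(f"Prompt content type {type(part)} is not supported.")
    return full_text


if __name__ == "__main__":
    deltas = ["Hello ", [TextPart("world"), ImagePart(url="https://example.com/cat.png")], None]
    print(collect_text(deltas))  # -> Hello world![](https://example.com/cat.png)
```

In the sketch a plain `TypeError` stands in for `UnsupportedPromptContentTypeError`, and the helper names are illustrative only.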