feat: enhance handling of prompt message content and add error for unsupported types
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
f31fda62c9
commit
3a02438131
@ -12,7 +12,9 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageTool,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
ModelType,
|
||||
@ -211,7 +213,7 @@ class LargeLanguageModel(AIModel):
|
||||
def _invoke_result_generator(
|
||||
self,
|
||||
model: str,
|
||||
result: Generator,
|
||||
result: Generator[LLMResultChunk, None, None],
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
@ -228,7 +230,7 @@ class LargeLanguageModel(AIModel):
|
||||
:return: result generator
|
||||
"""
|
||||
callbacks = callbacks or []
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
message_content: list[PromptMessageContent] = []
|
||||
usage = None
|
||||
system_fingerprint = None
|
||||
real_model = model
|
||||
@ -250,7 +252,10 @@ class LargeLanguageModel(AIModel):
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
assistant_message.content += chunk.delta.message.content
|
||||
if isinstance(chunk.delta.message.content, list):
|
||||
message_content.extend(chunk.delta.message.content)
|
||||
elif isinstance(chunk.delta.message.content, str):
|
||||
message_content.append(TextPromptMessageContent(data=chunk.delta.message.content))
|
||||
real_model = chunk.model
|
||||
if chunk.delta.usage:
|
||||
usage = chunk.delta.usage
|
||||
@ -260,6 +265,7 @@ class LargeLanguageModel(AIModel):
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
assistant_message = AssistantPromptMessage(content=message_content)
|
||||
self._trigger_after_invoke_callbacks(
|
||||
model=model,
|
||||
result=LLMResult(
|
||||
|
@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
|
||||
class FileTypeNotSupportError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str):
|
||||
super().__init__(f"{type_name} type is not supported by this model")
|
||||
|
||||
|
||||
class UnsupportedPromptContentTypeError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str) -> None:
|
||||
super().__init__(f"Prompt content type {type_name} is not supported.")
|
||||
|
@ -19,7 +19,7 @@ from core.model_runtime.entities import (
|
||||
PromptMessageContentType,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageContent,
|
||||
@ -78,6 +78,7 @@ from .exc import (
|
||||
ModelNotExistError,
|
||||
NoPromptFoundError,
|
||||
TemplateTypeNotSupportError,
|
||||
UnsupportedPromptContentTypeError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
@ -246,56 +247,62 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
|
||||
return self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
|
||||
def _handle_invoke_result(
|
||||
self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
# For blocking mode
|
||||
if isinstance(invoke_result, LLMResult):
|
||||
content = invoke_result.message.content
|
||||
if content is None:
|
||||
message_text = ""
|
||||
elif isinstance(content, str):
|
||||
message_text = content
|
||||
elif isinstance(content, list):
|
||||
# Assuming the list contains PromptMessageContent objects with a "data" attribute
|
||||
message_text = "".join(
|
||||
item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
|
||||
)
|
||||
else:
|
||||
message_text = str(content)
|
||||
|
||||
yield ModelInvokeCompletedEvent(
|
||||
text=message_text,
|
||||
usage=invoke_result.usage,
|
||||
finish_reason=None,
|
||||
)
|
||||
event = self._handle_blocking_result(invoke_result=invoke_result)
|
||||
yield event
|
||||
return
|
||||
|
||||
model = None
|
||||
# For streaming mode
|
||||
model = ""
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
full_text = ""
|
||||
usage = None
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
contents = result.delta.message.content
|
||||
if contents is None:
|
||||
continue
|
||||
|
||||
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
|
||||
if isinstance(contents, str):
|
||||
yield RunStreamChunkEvent(chunk_content=contents, from_variable_selector=[self.node_id, "text"])
|
||||
full_text += contents
|
||||
elif isinstance(contents, list):
|
||||
for content in contents:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
text_chunk = content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
text_chunk = self._image_to_markdown(content)
|
||||
else:
|
||||
raise UnsupportedPromptContentTypeError(type_name=str(type(content)))
|
||||
yield RunStreamChunkEvent(chunk_content=text_chunk, from_variable_selector=[self.node_id])
|
||||
full_text += text_chunk
|
||||
|
||||
if not model:
|
||||
# Update the whole metadata
|
||||
if not model and result.model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
if len(prompt_messages) == 0:
|
||||
prompt_messages = result.prompt_messages
|
||||
|
||||
if not usage and result.delta.usage:
|
||||
if usage.prompt_tokens == 0 and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
|
||||
if not finish_reason and result.delta.finish_reason:
|
||||
if finish_reason is None and result.delta.finish_reason:
|
||||
finish_reason = result.delta.finish_reason
|
||||
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
|
||||
|
||||
def _image_to_markdown(self, content: ImagePromptMessageContent, /):
|
||||
if content.url:
|
||||
text_chunk = f""
|
||||
elif content.base64_data:
|
||||
# insert b64 image into markdown text
|
||||
text_chunk = f""
|
||||
else:
|
||||
raise ValueError("Image content must have either a URL or base64 data")
|
||||
return text_chunk
|
||||
|
||||
def _transform_chat_messages(
|
||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
@ -926,6 +933,31 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
|
||||
contents = invoke_result.message.content
|
||||
if contents is None:
|
||||
message_text = ""
|
||||
elif isinstance(contents, str):
|
||||
message_text = contents
|
||||
elif isinstance(contents, list):
|
||||
# TODO: support multi modal content
|
||||
message_text = ""
|
||||
for item in contents:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
message_text += item.data
|
||||
elif isinstance(item, ImagePromptMessageContent):
|
||||
message_text += self._image_to_markdown(item)
|
||||
else:
|
||||
message_text += str(item)
|
||||
else:
|
||||
message_text = str(contents)
|
||||
|
||||
return ModelInvokeCompletedEvent(
|
||||
text=message_text,
|
||||
usage=invoke_result.usage,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
|
||||
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
|
||||
match role:
|
||||
|
Loading…
Reference in New Issue
Block a user