Compare commits

...

5 Commits

Author SHA1 Message Date
-LAN-
3a02438131
feat: enhance handling of prompt message content and add error for unsupported types
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-18 12:51:43 +08:00
-LAN-
f31fda62c9
Merge branch 'main' into feat/support-image-generate-for-gemini
2025-03-17 16:40:52 +08:00
-LAN-
4ce2819263
Merge branch 'main' into feat/support-image-generate-for-gemini
2025-03-17 16:36:25 +08:00
-LAN-
e945afb3cd
feat: enhance prompt message validation and add content type mapping
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-17 16:34:33 +08:00
-LAN-
c43f388586
fix: update datetime usage to use UTC consistently across workflow and task modules
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-17 15:16:04 +08:00
6 changed files with 110 additions and 41 deletions

View File

@@ -101,7 +101,7 @@ class ModelInstance:
     @overload
     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Sequence[PromptMessageTool] | None = None,
         stop: Optional[list[str]] = None,
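
Illustrative note (an assumption about intent, not stated in the commits): list is invariant under static type checkers, so a list[UserPromptMessage] does not satisfy a list[PromptMessage] parameter, while Sequence[PromptMessage] accepts it. The classes below are the ones visible in this compare; the helper function is hypothetical.

    from collections.abc import Sequence

    from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage


    def count_messages(prompt_messages: Sequence[PromptMessage]) -> int:
        # Hypothetical helper: only needs read access, so Sequence is enough.
        return len(prompt_messages)


    user_only: list[UserPromptMessage] = [UserPromptMessage(content="hello")]
    count_messages(user_only)  # accepted; a list[PromptMessage] annotation would be flagged by mypy/pyright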

View File

@@ -1,5 +1,5 @@
 from abc import ABC
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from enum import Enum, StrEnum
 from typing import Optional

@@ -119,6 +119,15 @@ class DocumentPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT


+CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
+    PromptMessageContentType.TEXT: TextPromptMessageContent,
+    PromptMessageContentType.IMAGE: ImagePromptMessageContent,
+    PromptMessageContentType.AUDIO: AudioPromptMessageContent,
+    PromptMessageContentType.VIDEO: VideoPromptMessageContent,
+    PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
+}
+
+
 class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
@@ -136,6 +145,23 @@ class PromptMessage(ABC, BaseModel):
         """
         return not self.content

+    @field_validator("content", mode="before")
+    @classmethod
+    def validate_content(cls, v):
+        if isinstance(v, list):
+            prompts = []
+            for prompt in v:
+                if isinstance(prompt, PromptMessageContent):
+                    if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
+                        prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
+                elif isinstance(prompt, dict):
+                    prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
+                else:
+                    raise ValueError(f"invalid prompt message {prompt}")
+                prompts.append(prompt)
+            return prompts
+        return v
+

 class UserPromptMessage(PromptMessage):
     """

View File

@@ -12,7 +12,9 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageTool,
+    TextPromptMessageContent,
 )
 from core.model_runtime.entities.model_entities import (
     ModelType,
@@ -211,7 +213,7 @@ class LargeLanguageModel(AIModel):
     def _invoke_result_generator(
         self,
         model: str,
-        result: Generator,
+        result: Generator[LLMResultChunk, None, None],
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
@@ -228,7 +230,7 @@ class LargeLanguageModel(AIModel):
         :return: result generator
         """
         callbacks = callbacks or []
-        assistant_message = AssistantPromptMessage(content="")
+        message_content: list[PromptMessageContent] = []
         usage = None
         system_fingerprint = None
         real_model = model
@@ -250,7 +252,10 @@
                     callbacks=callbacks,
                 )

-                assistant_message.content += chunk.delta.message.content
+                if isinstance(chunk.delta.message.content, list):
+                    message_content.extend(chunk.delta.message.content)
+                elif isinstance(chunk.delta.message.content, str):
+                    message_content.append(TextPromptMessageContent(data=chunk.delta.message.content))
                 real_model = chunk.model
                 if chunk.delta.usage:
                     usage = chunk.delta.usage
@@ -260,6 +265,7 @@
         except Exception as e:
             raise self._transform_invoke_error(e)

+        assistant_message = AssistantPromptMessage(content=message_content)
         self._trigger_after_invoke_callbacks(
             model=model,
             result=LLMResult(
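
A rough sketch of the new accumulation behaviour in _invoke_result_generator, assuming the entity constructors accept the fields shown elsewhere in this compare; the chunk payloads here are fabricated for illustration.

    from core.model_runtime.entities.message_entities import (
        AssistantPromptMessage,
        ImagePromptMessageContent,
        PromptMessageContent,
        TextPromptMessageContent,
    )

    message_content: list[PromptMessageContent] = []

    # A plain-text delta (str) is wrapped into TextPromptMessageContent...
    message_content.append(TextPromptMessageContent(data="Here is the generated picture: "))

    # ...while a multi-modal delta (list) is extended as-is, e.g. an image part from Gemini.
    message_content.extend(
        [ImagePromptMessageContent(format="png", mime_type="image/png", url="https://example.com/generated.png")]
    )

    # Once the stream is exhausted, the collected parts form a single assistant message,
    # replacing the previous string concatenation into AssistantPromptMessage(content="").
    assistant_message = AssistantPromptMessage(content=message_content)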

View File

@@ -171,7 +171,7 @@ class BasePluginManager:
             line_data = None
             try:
                 line_data = json.loads(line)
-                rep = PluginDaemonBasicResponse[type](**line_data)  # type: ignore
+                rep = PluginDaemonBasicResponse[type](**line_data)
             except Exception:
                 # TODO modify this when line_data has code and message
                 if line_data and "error" in line_data:

View File

@@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
 class FileTypeNotSupportError(LLMNodeError):
     def __init__(self, *, type_name: str):
         super().__init__(f"{type_name} type is not supported by this model")
+
+
+class UnsupportedPromptContentTypeError(LLMNodeError):
+    def __init__(self, *, type_name: str) -> None:
+        super().__init__(f"Prompt content type {type_name} is not supported.")

View File

@@ -19,7 +19,7 @@ from core.model_runtime.entities import (
     PromptMessageContentType,
     TextPromptMessageContent,
 )
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessageContent,
@@ -78,6 +78,7 @@ from .exc import (
     ModelNotExistError,
     NoPromptFoundError,
     TemplateTypeNotSupportError,
+    UnsupportedPromptContentTypeError,
     VariableNotFoundError,
 )
@@ -246,56 +247,62 @@ class LLMNode(BaseNode[LLMNodeData]):

         return self._handle_invoke_result(invoke_result=invoke_result)

-    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
+    def _handle_invoke_result(
+        self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
+    ) -> Generator[NodeEvent, None, None]:
+        # For blocking mode
         if isinstance(invoke_result, LLMResult):
-            content = invoke_result.message.content
-            if content is None:
-                message_text = ""
-            elif isinstance(content, str):
-                message_text = content
-            elif isinstance(content, list):
-                # Assuming the list contains PromptMessageContent objects with a "data" attribute
-                message_text = "".join(
-                    item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
-                )
-            else:
-                message_text = str(content)
-
-            yield ModelInvokeCompletedEvent(
-                text=message_text,
-                usage=invoke_result.usage,
-                finish_reason=None,
-            )
+            event = self._handle_blocking_result(invoke_result=invoke_result)
+            yield event
             return

-        model = None
+        # For streaming mode
+        model = ""
         prompt_messages: list[PromptMessage] = []
         full_text = ""
-        usage = None
+        usage = LLMUsage.empty_usage()
         finish_reason = None
         for result in invoke_result:
-            text = result.delta.message.content
-            full_text += text
-
-            yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
-
-            if not model:
+            contents = result.delta.message.content
+            if contents is None:
+                continue
+
+            if isinstance(contents, str):
+                yield RunStreamChunkEvent(chunk_content=contents, from_variable_selector=[self.node_id, "text"])
+                full_text += contents
+            elif isinstance(contents, list):
+                for content in contents:
+                    if isinstance(content, TextPromptMessageContent):
+                        text_chunk = content.data
+                    elif isinstance(content, ImagePromptMessageContent):
+                        text_chunk = self._image_to_markdown(content)
+                    else:
+                        raise UnsupportedPromptContentTypeError(type_name=str(type(content)))
+                    yield RunStreamChunkEvent(chunk_content=text_chunk, from_variable_selector=[self.node_id])
+                    full_text += text_chunk
+
+            # Update the whole metadata
+            if not model and result.model:
                 model = result.model
-
-            if not prompt_messages:
+            if len(prompt_messages) == 0:
                 prompt_messages = result.prompt_messages
-
-            if not usage and result.delta.usage:
+            if usage.prompt_tokens == 0 and result.delta.usage:
                 usage = result.delta.usage
-
-            if not finish_reason and result.delta.finish_reason:
+            if finish_reason is None and result.delta.finish_reason:
                 finish_reason = result.delta.finish_reason

-        if not usage:
-            usage = LLMUsage.empty_usage()
-
         yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)

+    def _image_to_markdown(self, content: ImagePromptMessageContent, /):
+        if content.url:
+            text_chunk = f"![]({content.url})"
+        elif content.base64_data:
+            # insert b64 image into markdown text
+            text_chunk = f"![]({content.data})"
+        else:
+            raise ValueError("Image content must have either a URL or base64 data")
+        return text_chunk
+
     def _transform_chat_messages(
         self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
     ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@@ -926,6 +933,31 @@ class LLMNode(BaseNode[LLMNodeData]):

         return prompt_messages

+    def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
+        contents = invoke_result.message.content
+        if contents is None:
+            message_text = ""
+        elif isinstance(contents, str):
+            message_text = contents
+        elif isinstance(contents, list):
+            # TODO: support multi modal content
+            message_text = ""
+            for item in contents:
+                if isinstance(item, TextPromptMessageContent):
+                    message_text += item.data
+                elif isinstance(item, ImagePromptMessageContent):
+                    message_text += self._image_to_markdown(item)
+                else:
+                    message_text += str(item)
+        else:
+            message_text = str(contents)
+
+        return ModelInvokeCompletedEvent(
+            text=message_text,
+            usage=invoke_result.usage,
+            finish_reason=None,
+        )
+

 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
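
For reference, a sketch of what _image_to_markdown produces for its two branches. The field values are invented, and the assumption that content.data expands a base64 payload into a data: URI comes from the "insert b64 image into markdown text" comment rather than from code shown in this compare.

    from core.model_runtime.entities.message_entities import ImagePromptMessageContent

    remote = ImagePromptMessageContent(format="png", mime_type="image/png", url="https://example.com/gen.png")
    # URL branch -> "![](https://example.com/gen.png)"

    inline = ImagePromptMessageContent(format="png", mime_type="image/png", base64_data="iVBORw0KGgo...")
    # base64 branch -> "![](<content.data>)", presumably a data:image/png;base64,... URI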