Compare commits
5 Commits
main ... feat/suppo

3a02438131
f31fda62c9
4ce2819263
e945afb3cd
c43f388586
@@ -101,7 +101,7 @@ class ModelInstance:
     @overload
     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Sequence[PromptMessageTool] | None = None,
         stop: Optional[list[str]] = None,
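Note: widening `prompt_messages` from `list[PromptMessage]` to `Sequence[PromptMessage]` lets callers pass any read-only sequence without copying. A minimal sketch (class names and import path are taken from this diff; the variable names are illustrative):

```python
from collections.abc import Sequence

from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage

# Both satisfy Sequence[PromptMessage]; the tuple would be rejected by a
# type checker if the parameter were still annotated as list[PromptMessage].
as_list: Sequence[PromptMessage] = [UserPromptMessage(content="Summarize this document.")]
as_tuple: Sequence[PromptMessage] = (UserPromptMessage(content="Summarize this document."),)
```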
@@ -1,5 +1,5 @@
 from abc import ABC
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from enum import Enum, StrEnum
 from typing import Optional

@@ -119,6 +119,15 @@ class DocumentPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT


+CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
+    PromptMessageContentType.TEXT: TextPromptMessageContent,
+    PromptMessageContentType.IMAGE: ImagePromptMessageContent,
+    PromptMessageContentType.AUDIO: AudioPromptMessageContent,
+    PromptMessageContentType.VIDEO: VideoPromptMessageContent,
+    PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
+}
+
+
 class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
@@ -136,6 +145,23 @@ class PromptMessage(ABC, BaseModel):
         """
         return not self.content

+    @field_validator("content", mode="before")
+    @classmethod
+    def validate_content(cls, v):
+        if isinstance(v, list):
+            prompts = []
+            for prompt in v:
+                if isinstance(prompt, PromptMessageContent):
+                    if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
+                        prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
+                elif isinstance(prompt, dict):
+                    prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
+                else:
+                    raise ValueError(f"invalid prompt message {prompt}")
+                prompts.append(prompt)
+            return prompts
+        return v
+

 class UserPromptMessage(PromptMessage):
     """
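A rough sketch of what the new `validate_content` hook enables, assuming `PromptMessageContentType` values serialize to plain strings such as `"text"` and `"image"` and that the image content class accepts a `url` field (only the class names and the `data` field are confirmed by this diff): dicts inside `content` are promoted to their concrete `PromptMessageContent` subclasses via `CONTENT_TYPE_MAPPING` before the message itself is validated.

```python
from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    TextPromptMessageContent,
    UserPromptMessage,
)

# Hypothetical payload: plain dicts are looked up by their "type" key and
# re-validated as the matching concrete content class.
message = UserPromptMessage(
    content=[
        {"type": "text", "data": "What is in this picture?"},
        {"type": "image", "url": "https://example.com/sample.png"},  # field name assumed
    ]
)
assert isinstance(message.content[0], TextPromptMessageContent)
assert isinstance(message.content[1], ImagePromptMessageContent)
```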
@@ -12,7 +12,9 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageTool,
+    TextPromptMessageContent,
 )
 from core.model_runtime.entities.model_entities import (
     ModelType,
@@ -211,7 +213,7 @@ class LargeLanguageModel(AIModel):
     def _invoke_result_generator(
         self,
         model: str,
-        result: Generator,
+        result: Generator[LLMResultChunk, None, None],
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
@@ -228,7 +230,7 @@ class LargeLanguageModel(AIModel):
         :return: result generator
         """
         callbacks = callbacks or []
-        assistant_message = AssistantPromptMessage(content="")
+        message_content: list[PromptMessageContent] = []
         usage = None
         system_fingerprint = None
         real_model = model
@@ -250,7 +252,10 @@ class LargeLanguageModel(AIModel):
                     callbacks=callbacks,
                 )

-                assistant_message.content += chunk.delta.message.content
+                if isinstance(chunk.delta.message.content, list):
+                    message_content.extend(chunk.delta.message.content)
+                elif isinstance(chunk.delta.message.content, str):
+                    message_content.append(TextPromptMessageContent(data=chunk.delta.message.content))
                 real_model = chunk.model
                 if chunk.delta.usage:
                     usage = chunk.delta.usage
@@ -260,6 +265,7 @@ class LargeLanguageModel(AIModel):
         except Exception as e:
             raise self._transform_invoke_error(e)

+        assistant_message = AssistantPromptMessage(content=message_content)
         self._trigger_after_invoke_callbacks(
             model=model,
             result=LLMResult(
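Net effect of the three hunks above: the generator no longer concatenates strings; it collects every streamed delta into `message_content` and builds the final `AssistantPromptMessage` once the stream is exhausted. A standalone sketch of that accumulation pattern (the delta strings are made up; the classes come from this diff):

```python
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageContent,
    TextPromptMessageContent,
)

message_content: list[PromptMessageContent] = []
for delta in ("The answer ", "is 42."):  # stand-ins for chunk.delta.message.content
    # String deltas are wrapped; list deltas would be extended as-is.
    message_content.append(TextPromptMessageContent(data=delta))

assistant_message = AssistantPromptMessage(content=message_content)
```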
@@ -171,7 +171,7 @@ class BasePluginManager:
             line_data = None
             try:
                 line_data = json.loads(line)
-                rep = PluginDaemonBasicResponse[type](**line_data)  # type: ignore
+                rep = PluginDaemonBasicResponse[type](**line_data)
             except Exception:
                 # TODO modify this when line_data has code and message
                 if line_data and "error" in line_data:
@@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
 class FileTypeNotSupportError(LLMNodeError):
     def __init__(self, *, type_name: str):
         super().__init__(f"{type_name} type is not supported by this model")
+
+
+class UnsupportedPromptContentTypeError(LLMNodeError):
+    def __init__(self, *, type_name: str) -> None:
+        super().__init__(f"Prompt content type {type_name} is not supported.")
@@ -19,7 +19,7 @@ from core.model_runtime.entities import (
     PromptMessageContentType,
     TextPromptMessageContent,
 )
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessageContent,
@@ -78,6 +78,7 @@ from .exc import (
     ModelNotExistError,
     NoPromptFoundError,
     TemplateTypeNotSupportError,
+    UnsupportedPromptContentTypeError,
     VariableNotFoundError,
 )

@@ -246,56 +247,62 @@ class LLMNode(BaseNode[LLMNodeData]):

         return self._handle_invoke_result(invoke_result=invoke_result)

-    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
+    def _handle_invoke_result(
+        self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
+    ) -> Generator[NodeEvent, None, None]:
+        # For blocking mode
         if isinstance(invoke_result, LLMResult):
-            content = invoke_result.message.content
-            if content is None:
-                message_text = ""
-            elif isinstance(content, str):
-                message_text = content
-            elif isinstance(content, list):
-                # Assuming the list contains PromptMessageContent objects with a "data" attribute
-                message_text = "".join(
-                    item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
-                )
-            else:
-                message_text = str(content)
-
-            yield ModelInvokeCompletedEvent(
-                text=message_text,
-                usage=invoke_result.usage,
-                finish_reason=None,
-            )
+            event = self._handle_blocking_result(invoke_result=invoke_result)
+            yield event
             return

-        model = None
+        # For streaming mode
+        model = ""
         prompt_messages: list[PromptMessage] = []
         full_text = ""
-        usage = None
+        usage = LLMUsage.empty_usage()
         finish_reason = None
         for result in invoke_result:
-            text = result.delta.message.content
-            full_text += text
-
-            yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
+            contents = result.delta.message.content
+            if contents is None:
+                continue
+
+            if isinstance(contents, str):
+                yield RunStreamChunkEvent(chunk_content=contents, from_variable_selector=[self.node_id, "text"])
+                full_text += contents
+            elif isinstance(contents, list):
+                for content in contents:
+                    if isinstance(content, TextPromptMessageContent):
+                        text_chunk = content.data
+                    elif isinstance(content, ImagePromptMessageContent):
+                        text_chunk = self._image_to_markdown(content)
+                    else:
+                        raise UnsupportedPromptContentTypeError(type_name=str(type(content)))
+                    yield RunStreamChunkEvent(chunk_content=text_chunk, from_variable_selector=[self.node_id])
+                    full_text += text_chunk

-            if not model:
+            # Update the whole metadata
+            if not model and result.model:
                 model = result.model
-            if not prompt_messages:
+            if len(prompt_messages) == 0:
                 prompt_messages = result.prompt_messages
-            if not usage and result.delta.usage:
+            if usage.prompt_tokens == 0 and result.delta.usage:
                 usage = result.delta.usage
-            if not finish_reason and result.delta.finish_reason:
+            if finish_reason is None and result.delta.finish_reason:
                 finish_reason = result.delta.finish_reason

-        if not usage:
-            usage = LLMUsage.empty_usage()
-
         yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
+
+    def _image_to_markdown(self, content: ImagePromptMessageContent, /):
+        if content.url:
+            text_chunk = f"![]({content.url})"
+        elif content.base64_data:
+            # insert b64 image into markdown text
+            text_chunk = f"![](data:{content.mime_type};base64,{content.base64_data})"
+        else:
+            raise ValueError("Image content must have either a URL or base64 data")
+        return text_chunk
+
     def _transform_chat_messages(
         self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
     ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@@ -926,6 +933,31 @@ class LLMNode(BaseNode[LLMNodeData]):

         return prompt_messages

+    def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
+        contents = invoke_result.message.content
+        if contents is None:
+            message_text = ""
+        elif isinstance(contents, str):
+            message_text = contents
+        elif isinstance(contents, list):
+            # TODO: support multi modal content
+            message_text = ""
+            for item in contents:
+                if isinstance(item, TextPromptMessageContent):
+                    message_text += item.data
+                elif isinstance(item, ImagePromptMessageContent):
+                    message_text += self._image_to_markdown(item)
+                else:
+                    message_text += str(item)
+        else:
+            message_text = str(contents)
+
+        return ModelInvokeCompletedEvent(
+            text=message_text,
+            usage=invoke_result.usage,
+            finish_reason=None,
+        )
+
+
 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
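For reference, the blocking path above flattens mixed content into a single event text. A small sketch of the same rule outside the node (only `TextPromptMessageContent` is exercised here; the image branch depends on `_image_to_markdown` and is omitted):

```python
from core.model_runtime.entities.message_entities import TextPromptMessageContent

# Stand-in for invoke_result.message.content when the provider returns a content list.
contents = [
    TextPromptMessageContent(data="Here is the summary: "),
    TextPromptMessageContent(data="42."),
]

# Same flattening rule as _handle_blocking_result: text parts contribute their
# data; anything unrecognized would fall back to str(item).
message_text = "".join(
    item.data if isinstance(item, TextPromptMessageContent) else str(item) for item in contents
)
assert message_text == "Here is the summary: 42."
```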