Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>

parent 9277156b6c
commit d9d5d35a77
@@ -2,7 +2,7 @@ from datetime import datetime
 from enum import Enum, StrEnum
 from typing import Any, Optional
 
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.workflow.entities.node_entities import NodeRunMetadataKey
@@ -113,18 +113,6 @@ class QueueIterationNextEvent(AppQueueEvent):
     output: Optional[Any] = None  # output for the current iteration
     duration: Optional[float] = None
 
-    @field_validator("output", mode="before")
-    @classmethod
-    def set_output(cls, v):
-        """
-        Set output
-        """
-        if v is None:
-            return None
-        if isinstance(v, int | float | str | bool | dict | list):
-            return v
-        raise ValueError("output must be a valid type")
-
 
 class QueueIterationCompletedEvent(AppQueueEvent):
     """
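Note: the deleted validator only re-implemented what the `Optional[Any]` annotation already permits. A minimal sketch (a standalone model, not the Dify class) showing the field accepts the same values without a `field_validator`:

    from typing import Any, Optional

    from pydantic import BaseModel

    class Event(BaseModel):
        # Optional[Any] admits None plus any Python object as-is,
        # so an isinstance allowlist adds no extra safety here.
        output: Optional[Any] = None

    Event(output={"answer": [1, 2, 3]})  # accepted
    Event(output=None)                   # accepted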
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
 from flask import Flask, current_app
 
 from configs import dify_config
-from core.model_runtime.utils.encoders import jsonable_encoder
+from core.variables import IntegerVariable
 from core.workflow.entities.node_entities import (
     NodeRunMetadataKey,
     NodeRunResult,
@@ -155,18 +155,19 @@ class IterationNode(BaseNode[IterationNodeData]):
             iteration_node_data=self.node_data,
             index=0,
             pre_iteration_output=None,
+            duration=None,
         )
         iter_run_map: dict[str, float] = {}
         outputs: list[Any] = [None] * len(iterator_list_value)
         try:
             if self.node_data.is_parallel:
                 futures: list[Future] = []
-                q = Queue()
+                q: Queue = Queue()
                 thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
                 for index, item in enumerate(iterator_list_value):
                     future: Future = thread_pool.submit(
                         self._run_single_iter_parallel,
-                        current_app._get_current_object(),
+                        current_app._get_current_object(),  # type: ignore
                         q,
                         iterator_list_value,
                         inputs,
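Note: `current_app` is a thread-local proxy, so the unwrapped application object must be handed to each worker explicitly; the `# type: ignore` is needed because `_get_current_object()` is not part of Flask's typed public surface. A minimal sketch of the pattern, independent of Dify's classes:

    from concurrent.futures import ThreadPoolExecutor

    from flask import Flask, current_app

    app = Flask(__name__)

    def worker(flask_app: Flask) -> str:
        # Push a fresh app context; using the proxy from the parent thread
        # here would raise "Working outside of application context".
        with flask_app.app_context():
            return current_app.name

    with app.app_context():
        with ThreadPoolExecutor(max_workers=2) as pool:
            future = pool.submit(worker, current_app._get_current_object())  # type: ignore
            assert future.result() == app.name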
@@ -181,6 +182,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                     future.add_done_callback(thread_pool.task_done_callback)
                     futures.append(future)
                 succeeded_count = 0
+                empty_count = 0
                 while True:
                     try:
                         event = q.get(timeout=1)
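Note: the new `empty_count` pairs with the `q.get(timeout=1)` poll below it; each `queue.Empty` timeout presumably bumps the counter so a stalled producer cannot spin the consumer forever. A sketch of that polling pattern with a hypothetical cutoff:

    from queue import Empty, Queue

    def drain(q: "Queue[object]", expected: int, max_empty_polls: int = 10) -> list[object]:
        events: list[object] = []
        empty_count = 0
        while len(events) < expected:
            try:
                events.append(q.get(timeout=1))
                empty_count = 0  # reset on progress
            except Empty:
                empty_count += 1
                if empty_count >= max_empty_polls:  # hypothetical threshold
                    break
        return events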
@@ -208,17 +210,22 @@ class IterationNode(BaseNode[IterationNodeData]):
             else:
                 for _ in range(len(iterator_list_value)):
                     yield from self._run_single_iter(
-                        iterator_list_value,
-                        variable_pool,
-                        inputs,
-                        outputs,
-                        start_at,
-                        graph_engine,
-                        iteration_graph,
-                        iter_run_map,
+                        iterator_list_value=iterator_list_value,
+                        variable_pool=variable_pool,
+                        inputs=inputs,
+                        outputs=outputs,
+                        start_at=start_at,
+                        graph_engine=graph_engine,
+                        iteration_graph=iteration_graph,
+                        iter_run_map=iter_run_map,
                     )
             if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
                 outputs = [output for output in outputs if output is not None]
+
+            # Flatten the list of lists
+            if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
+                outputs = [item for sublist in outputs for item in sublist]
+
             yield IterationRunSucceededEvent(
                 iteration_id=self.id,
                 iteration_node_id=self.node_id,
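Note: the added flattening guard only rewrites `outputs` when every element is itself a list, so mixed or scalar outputs pass through untouched:

    outputs = [["a", "b"], ["c"], []]
    if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
        outputs = [item for sublist in outputs for item in sublist]
    assert outputs == ["a", "b", "c"]

    mixed = [["a"], "b"]  # one non-list element: left as-is by the guard
    assert not all(isinstance(output, list) for output in mixed)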
@@ -226,7 +233,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 iteration_node_data=self.node_data,
                 start_at=start_at,
                 inputs=inputs,
-                outputs={"output": jsonable_encoder(outputs)},
+                outputs={"output": outputs},
                 steps=len(iterator_list_value),
                 metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
             )
@@ -234,7 +241,7 @@ class IterationNode(BaseNode[IterationNodeData]):
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    outputs={"output": jsonable_encoder(outputs)},
+                    outputs={"output": outputs},
                     metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
                 )
             )
@@ -248,7 +255,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 iteration_node_data=self.node_data,
                 start_at=start_at,
                 inputs=inputs,
-                outputs={"output": jsonable_encoder(outputs)},
+                outputs={"output": outputs},
                 steps=len(iterator_list_value),
                 metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
                 error=str(e),
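Note: this hunk, like the two above it, stops JSON-encoding `outputs` at each yield site; the raw objects now travel in the event and are encoded once at the serialization boundary. A sketch of what the dropped call did, using FastAPI's `jsonable_encoder` on the assumption that Dify's vendored copy behaves the same way:

    from datetime import datetime

    from fastapi.encoders import jsonable_encoder

    outputs = [{"finished_at": datetime(2024, 1, 1)}]

    # Before: encoded eagerly, so consumers always saw plain JSON types.
    assert jsonable_encoder(outputs) == [{"finished_at": "2024-01-01T00:00:00"}]

    # After: the datetime object itself is yielded; whoever renders the
    # event is responsible for encoding it exactly once.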
@@ -280,7 +287,7 @@ class IterationNode(BaseNode[IterationNodeData]):
         :param node_data: node data
         :return:
         """
-        variable_mapping = {
+        variable_mapping: dict[str, Sequence[str]] = {
             f"{node_id}.input_selector": node_data.iterator_selector,
         }
 
@@ -308,7 +315,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
                     graph_config=graph_config, config=sub_node_config
                 )
-                sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
+                sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
             except NotImplementedError:
                 sub_node_variable_mapping = {}
 
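Note: both `Sequence[str]` hunks widen the mapping's value type: `Sequence` is a read-only view that admits lists and tuples alike, which makes the signature easier to satisfy. A small sketch:

    from collections.abc import Sequence

    def register(mapping: dict[str, Sequence[str]]) -> None:
        for key, selector in mapping.items():
            assert all(isinstance(part, str) for part in selector)

    register({"node.input_selector": ["node", "output"]})   # list is fine
    register({"node.input_selector": ("node", "output")})   # so is a tuple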
@@ -329,8 +336,12 @@ class IterationNode(BaseNode[IterationNodeData]):
         return variable_mapping
 
     def _handle_event_metadata(
-        self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
-    ) -> NodeRunStartedEvent | BaseNodeEvent:
+        self,
+        *,
+        event: BaseNodeEvent | InNodeEvent,
+        iter_run_index: int,
+        parallel_mode_run_id: str | None,
+    ) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
         """
         add iteration metadata to event.
         """
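Note: the rewritten signature uses a bare `*` to make every argument keyword-only, so call sites (updated later in this diff) cannot silently swap positional arguments. A standalone illustration:

    def handle_event_metadata(*, event: str, iter_run_index: int, parallel_mode_run_id: str | None) -> str:
        # Parameters after the bare `*` must be passed by keyword.
        return f"{event}:{iter_run_index}:{parallel_mode_run_id}"

    handle_event_metadata(event="node_run", iter_run_index=0, parallel_mode_run_id=None)  # ok
    # handle_event_metadata("node_run", 0, None)  # TypeError: takes 0 positional arguments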
@@ -355,6 +366,7 @@ class IterationNode(BaseNode[IterationNodeData]):
 
     def _run_single_iter(
         self,
+        *,
         iterator_list_value: list[str],
         variable_pool: VariablePool,
         inputs: dict[str, list],
@@ -373,12 +385,12 @@ class IterationNode(BaseNode[IterationNodeData]):
         try:
             rst = graph_engine.run()
             # get current iteration index
-            current_index = variable_pool.get([self.node_id, "index"]).value
+            index_variable = variable_pool.get([self.node_id, "index"])
+            if not isinstance(index_variable, IntegerVariable):
+                raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
+            current_index = index_variable.value
             iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
             next_index = int(current_index) + 1
-
-            if current_index is None:
-                raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
             for event in rst:
                 if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
                     event.in_iteration_id = self.node_id
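Note: the post-hoc `None` check was effectively unreachable (the `.value` access on the previous line would typically have failed first); it is replaced by an upfront `isinstance` test, which both validates at runtime and narrows the type for mypy. A sketch with a stand-in for `IntegerVariable`:

    from dataclasses import dataclass

    @dataclass
    class IntegerVariable:  # stand-in for core.variables.IntegerVariable
        value: int

    class IterationIndexNotFoundError(ValueError):
        pass

    def current_index_of(variable: object) -> int:
        if not isinstance(variable, IntegerVariable):
            raise IterationIndexNotFoundError("iteration current index not found")
        return variable.value  # mypy now knows this is an int

    assert current_index_of(IntegerVariable(value=3)) == 3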
@@ -391,7 +403,9 @@ class IterationNode(BaseNode[IterationNodeData]):
                         continue
 
                 if isinstance(event, NodeRunSucceededEvent):
-                    yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
+                    yield self._handle_event_metadata(
+                        event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
+                    )
                 elif isinstance(event, BaseGraphEvent):
                     if isinstance(event, GraphRunFailedEvent):
                         # iteration run failed
@@ -404,7 +418,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                                 parallel_mode_run_id=parallel_mode_run_id,
                                 start_at=start_at,
                                 inputs=inputs,
-                                outputs={"output": jsonable_encoder(outputs)},
+                                outputs={"output": outputs},
                                 steps=len(iterator_list_value),
                                 metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
                                 error=event.error,
@@ -417,7 +431,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                             iteration_node_data=self.node_data,
                             start_at=start_at,
                             inputs=inputs,
-                            outputs={"output": jsonable_encoder(outputs)},
+                            outputs={"output": outputs},
                             steps=len(iterator_list_value),
                             metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
                             error=event.error,
@@ -429,9 +443,11 @@ class IterationNode(BaseNode[IterationNodeData]):
                             )
                         )
                     return
-                else:
-                    event = cast(InNodeEvent, event)
-                    metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
+                elif isinstance(event, InNodeEvent):
+                    # event = cast(InNodeEvent, event)
+                    metadata_event = self._handle_event_metadata(
+                        event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
+                    )
                     if isinstance(event, NodeRunFailedEvent):
                         if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
                             yield NodeInIterationFailedEvent(
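Note: the bare `else:` plus `cast` asserted the event type without checking it; `elif isinstance(event, InNodeEvent):` performs the check at runtime, and any event matching no branch now simply falls through. A sketch:

    class InNodeEvent: ...
    class UnknownEvent: ...

    for event in (InNodeEvent(), UnknownEvent()):
        if isinstance(event, InNodeEvent):
            handled = event  # checked at runtime; no cast() needed
        # A bare `else: event = cast(InNodeEvent, event)` would have
        # mislabelled UnknownEvent without any runtime error.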
@@ -513,7 +529,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 iteration_node_data=self.node_data,
                 index=next_index,
                 parallel_mode_run_id=parallel_mode_run_id,
-                pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
+                pre_iteration_output=current_iteration_output or None,
                 duration=duration,
             )
 
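Note: `current_iteration_output or None` keeps the old ternary's behaviour of collapsing every falsy output to `None`; it only drops the eager encoding:

    assert ([] or None) is None        # empty output becomes None, as before
    assert (0 or None) is None
    assert ({"k": 1} or None) == {"k": 1}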
@@ -551,7 +567,7 @@ class IterationNode(BaseNode[IterationNodeData]):
         index: int,
         item: Any,
         iter_run_map: dict[str, float],
-    ) -> Generator[NodeEvent | InNodeEvent, None, None]:
+    ):
         """
        run single iteration in parallel mode
         """
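Note: the dropped return annotation followed the `Generator[YieldType, SendType, ReturnType]` convention; without it, type checkers fall back to inferring the yield type. A minimal sketch of the annotated form:

    from collections.abc import Generator

    def run_single_iter_parallel() -> Generator[str, None, None]:
        # yields str values, accepts no send() value, returns None
        yield "NodeRunStartedEvent"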