fix: issue #10596 by making the iteration node outputs right (#11394)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
yihong 2024-12-07 16:28:15 +08:00 committed by GitHub
parent 9277156b6c
commit d9d5d35a77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 43 deletions

View File

@ -2,7 +2,7 @@ from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
from pydantic import BaseModel, field_validator
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import NodeRunMetadataKey
@ -113,18 +113,6 @@ class QueueIterationNextEvent(AppQueueEvent):
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None
@field_validator("output", mode="before")
@classmethod
def set_output(cls, v):
"""
Set output
"""
if v is None:
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError("output must be a valid type")
class QueueIterationCompletedEvent(AppQueueEvent):
"""

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.variables import IntegerVariable
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
@ -155,18 +155,19 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_node_data=self.node_data,
index=0,
pre_iteration_output=None,
duration=None,
)
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self.node_data.is_parallel:
futures: list[Future] = []
q = Queue()
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
current_app._get_current_object(),
current_app._get_current_object(), # type: ignore
q,
iterator_list_value,
inputs,
@ -181,6 +182,7 @@ class IterationNode(BaseNode[IterationNodeData]):
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
succeeded_count = 0
empty_count = 0
while True:
try:
event = q.get(timeout=1)
@ -208,17 +210,22 @@ class IterationNode(BaseNode[IterationNodeData]):
else:
for _ in range(len(iterator_list_value)):
yield from self._run_single_iter(
iterator_list_value,
variable_pool,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
iter_run_map,
iterator_list_value=iterator_list_value,
variable_pool=variable_pool,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@ -226,7 +233,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
)
@ -234,7 +241,7 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
)
)
@ -248,7 +255,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
@ -280,7 +287,7 @@ class IterationNode(BaseNode[IterationNodeData]):
:param node_data: node data
:return:
"""
variable_mapping = {
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": node_data.iterator_selector,
}
@ -308,7 +315,7 @@ class IterationNode(BaseNode[IterationNodeData]):
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
)
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:
sub_node_variable_mapping = {}
@ -329,8 +336,12 @@ class IterationNode(BaseNode[IterationNodeData]):
return variable_mapping
def _handle_event_metadata(
self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
) -> NodeRunStartedEvent | BaseNodeEvent:
self,
*,
event: BaseNodeEvent | InNodeEvent,
iter_run_index: int,
parallel_mode_run_id: str | None,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
"""
@ -355,6 +366,7 @@ class IterationNode(BaseNode[IterationNodeData]):
def _run_single_iter(
self,
*,
iterator_list_value: list[str],
variable_pool: VariablePool,
inputs: dict[str, list],
@ -373,12 +385,12 @@ class IterationNode(BaseNode[IterationNodeData]):
try:
rst = graph_engine.run()
# get current iteration index
current_index = variable_pool.get([self.node_id, "index"]).value
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerVariable):
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
current_index = index_variable.value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1
if current_index is None:
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
@ -391,7 +403,9 @@ class IterationNode(BaseNode[IterationNodeData]):
continue
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
yield self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
@ -404,7 +418,7 @@ class IterationNode(BaseNode[IterationNodeData]):
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
@ -417,7 +431,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
@ -429,9 +443,11 @@ class IterationNode(BaseNode[IterationNodeData]):
)
)
return
else:
event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
elif isinstance(event, InNodeEvent):
# event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
@ -513,7 +529,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
pre_iteration_output=current_iteration_output or None,
duration=duration,
)
@ -551,7 +567,7 @@ class IterationNode(BaseNode[IterationNodeData]):
index: int,
item: Any,
iter_run_map: dict[str, float],
) -> Generator[NodeEvent | InNodeEvent, None, None]:
):
"""
run single iteration in parallel mode
"""