From d9d5d35a7726beaf09f2feb41b49a5815878b0bd Mon Sep 17 00:00:00 2001 From: yihong Date: Sat, 7 Dec 2024 16:28:15 +0800 Subject: [PATCH] fix: issue #10596 by making the iteration node outputs right (#11394) Signed-off-by: yihong0618 Signed-off-by: -LAN- Co-authored-by: -LAN- --- api/core/app/entities/queue_entities.py | 14 +--- .../nodes/iteration/iteration_node.py | 76 +++++++++++-------- 2 files changed, 47 insertions(+), 43 deletions(-) diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 15543638fc..5e9b6517ba 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import Enum, StrEnum from typing import Any, Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.workflow.entities.node_entities import NodeRunMetadataKey @@ -113,18 +113,6 @@ class QueueIterationNextEvent(AppQueueEvent): output: Optional[Any] = None # output for the current iteration duration: Optional[float] = None - @field_validator("output", mode="before") - @classmethod - def set_output(cls, v): - """ - Set output - """ - if v is None: - return None - if isinstance(v, int | float | str | bool | dict | list): - return v - raise ValueError("output must be a valid type") - class QueueIterationCompletedEvent(AppQueueEvent): """ diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index bba6ac20d3..74ec95deaa 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast from flask import Flask, current_app from configs import dify_config -from core.model_runtime.utils.encoders import jsonable_encoder +from core.variables import IntegerVariable from core.workflow.entities.node_entities import ( NodeRunMetadataKey, NodeRunResult, @@ -155,18 +155,19 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_node_data=self.node_data, index=0, pre_iteration_output=None, + duration=None, ) iter_run_map: dict[str, float] = {} outputs: list[Any] = [None] * len(iterator_list_value) try: if self.node_data.is_parallel: futures: list[Future] = [] - q = Queue() + q: Queue = Queue() thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) for index, item in enumerate(iterator_list_value): future: Future = thread_pool.submit( self._run_single_iter_parallel, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore q, iterator_list_value, inputs, @@ -181,6 +182,7 @@ class IterationNode(BaseNode[IterationNodeData]): future.add_done_callback(thread_pool.task_done_callback) futures.append(future) succeeded_count = 0 + empty_count = 0 while True: try: event = q.get(timeout=1) @@ -208,17 +210,22 @@ class IterationNode(BaseNode[IterationNodeData]): else: for _ in range(len(iterator_list_value)): yield from self._run_single_iter( - iterator_list_value, - variable_pool, - inputs, - outputs, - start_at, - graph_engine, - iteration_graph, - iter_run_map, + iterator_list_value=iterator_list_value, + variable_pool=variable_pool, + inputs=inputs, + outputs=outputs, + start_at=start_at, + graph_engine=graph_engine, + iteration_graph=iteration_graph, + iter_run_map=iter_run_map, ) if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs = [output for output in outputs if output is not None] + + # Flatten the list of lists + if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): + outputs = [item for sublist in outputs for item in sublist] + yield IterationRunSucceededEvent( iteration_id=self.id, iteration_node_id=self.node_id, @@ -226,7 +233,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, + outputs={"output": outputs}, steps=len(iterator_list_value), metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, ) @@ -234,7 +241,7 @@ class IterationNode(BaseNode[IterationNodeData]): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": jsonable_encoder(outputs)}, + outputs={"output": outputs}, metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map}, ) ) @@ -248,7 +255,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, + outputs={"output": outputs}, steps=len(iterator_list_value), metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=str(e), @@ -280,7 +287,7 @@ class IterationNode(BaseNode[IterationNodeData]): :param node_data: node data :return: """ - variable_mapping = { + variable_mapping: dict[str, Sequence[str]] = { f"{node_id}.input_selector": node_data.iterator_selector, } @@ -308,7 +315,7 @@ class IterationNode(BaseNode[IterationNodeData]): sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=graph_config, config=sub_node_config ) - sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) + sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) except NotImplementedError: sub_node_variable_mapping = {} @@ -329,8 +336,12 @@ class IterationNode(BaseNode[IterationNodeData]): return variable_mapping def _handle_event_metadata( - self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str - ) -> NodeRunStartedEvent | BaseNodeEvent: + self, + *, + event: BaseNodeEvent | InNodeEvent, + iter_run_index: int, + parallel_mode_run_id: str | None, + ) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent: """ add iteration metadata to event. """ @@ -355,6 +366,7 @@ class IterationNode(BaseNode[IterationNodeData]): def _run_single_iter( self, + *, iterator_list_value: list[str], variable_pool: VariablePool, inputs: dict[str, list], @@ -373,12 +385,12 @@ class IterationNode(BaseNode[IterationNodeData]): try: rst = graph_engine.run() # get current iteration index - current_index = variable_pool.get([self.node_id, "index"]).value + index_variable = variable_pool.get([self.node_id, "index"]) + if not isinstance(index_variable, IntegerVariable): + raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found") + current_index = index_variable.value iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}" next_index = int(current_index) + 1 - - if current_index is None: - raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found") for event in rst: if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: event.in_iteration_id = self.node_id @@ -391,7 +403,9 @@ class IterationNode(BaseNode[IterationNodeData]): continue if isinstance(event, NodeRunSucceededEvent): - yield self._handle_event_metadata(event, current_index, parallel_mode_run_id) + yield self._handle_event_metadata( + event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id + ) elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): # iteration run failed @@ -404,7 +418,7 @@ class IterationNode(BaseNode[IterationNodeData]): parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, + outputs={"output": outputs}, steps=len(iterator_list_value), metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=event.error, @@ -417,7 +431,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, + outputs={"output": outputs}, steps=len(iterator_list_value), metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=event.error, @@ -429,9 +443,11 @@ class IterationNode(BaseNode[IterationNodeData]): ) ) return - else: - event = cast(InNodeEvent, event) - metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id) + elif isinstance(event, InNodeEvent): + # event = cast(InNodeEvent, event) + metadata_event = self._handle_event_metadata( + event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id + ) if isinstance(event, NodeRunFailedEvent): if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: yield NodeInIterationFailedEvent( @@ -513,7 +529,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_node_data=self.node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, - pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, + pre_iteration_output=current_iteration_output or None, duration=duration, ) @@ -551,7 +567,7 @@ class IterationNode(BaseNode[IterationNodeData]): index: int, item: Any, iter_run_map: dict[str, float], - ) -> Generator[NodeEvent | InNodeEvent, None, None]: + ): """ run single iteration in parallel mode """