fix: loop node metadata

This commit is contained in:
arkunzz 2025-03-04 11:52:17 +08:00 committed by Wood
parent 712b95380e
commit 1aa0d1c532

View File

@ -108,6 +108,10 @@ class LoopNode(BaseNode[LoopNodeData]):
for i in range(loop_count):
# Run workflow
rst = graph_engine.run()
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"loop {self.node_id} current index not found")
current_index = current_index_variable.value
check_break_result = False
@ -123,30 +127,7 @@ class LoopNode(BaseNode[LoopNodeData]):
continue
if isinstance(event, NodeRunSucceededEvent):
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.LOOP_ID not in metadata:
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerSegment):
total_tokens = graph_engine.graph_runtime_state.total_tokens
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Invalid index variable type: {type(index_variable)}",
metadata={NodeRunMetadataKey.TOTAL_TOKENS: total_tokens},
)
)
return
metadata = {
**metadata,
NodeRunMetadataKey.LOOP_ID: self.node_id,
NodeRunMetadataKey.LOOP_INDEX: index_variable.value,
}
event.route_node_state.node_run_result.metadata = metadata
yield event
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
# Check if all variables in break conditions exist
exists_variable = False
@ -220,7 +201,7 @@ class LoopNode(BaseNode[LoopNodeData]):
)
return
else:
yield cast(InNodeEvent, event)
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
# Remove all nodes outputs from variable pool
for node_id in loop_graph.node_ids:
@ -230,11 +211,7 @@ class LoopNode(BaseNode[LoopNodeData]):
break
# Move to next loop
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"loop {self.node_id} current index not found")
next_index = current_index_variable.value + 1
next_index = current_index + 1
variable_pool.add([self.node_id, "index"], next_index)
yield LoopRunNextEvent(
@ -298,6 +275,30 @@ class LoopNode(BaseNode[LoopNodeData]):
# Clean up
variable_pool.remove([self.node_id, "index"])
def _handle_event_metadata(
self,
*,
event: BaseNodeEvent | InNodeEvent,
iter_run_index: int,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
"""
if not isinstance(event, BaseNodeEvent):
return event
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.LOOP_ID not in metadata:
metadata = {
**metadata,
NodeRunMetadataKey.LOOP_ID: self.node_id,
NodeRunMetadataKey.LOOP_INDEX: iter_run_index
}
event.route_node_state.node_run_result.metadata = metadata
return event
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,