fix(workflow): fix node link to previous node issue

This commit is contained in:
takatost 2024-08-20 23:28:11 +08:00
parent 617ea4b3b8
commit 1d88b62e25

View File

@ -1,6 +1,7 @@
import uuid
from collections.abc import Mapping
from typing import Any, Optional, cast
from cycler import V
from pydantic import BaseModel, Field
@ -152,6 +153,12 @@ class Graph(BaseModel):
if not root_node_id or root_node_id not in root_node_ids:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
# Check whether it is connected to the previous node
cls._check_connected_to_previous_node(
route=[root_node_id],
edge_mapping=edge_mapping
)
# fetch all node ids from root node
node_ids = [root_node_id]
@ -267,6 +274,30 @@ class Graph(BaseModel):
node_id=graph_edge.target_node_id
)
@classmethod
def _check_connected_to_previous_node(
cls,
route: list[str],
edge_mapping: dict[str, list[GraphEdge]]
) -> None:
"""
Check whether it is connected to the previous node
"""
new_route = list(route)
for graph_edge in edge_mapping.get(new_route[-1], []):
if not graph_edge.target_node_id:
continue
if graph_edge.target_node_id in new_route:
raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.")
new_route.append(graph_edge.target_node_id)
cls._check_connected_to_previous_node(
route=new_route,
edge_mapping=edge_mapping,
)
@classmethod
def _recursively_add_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],