This commit is contained in:
takatost 2024-07-24 00:24:24 +08:00
parent e9bfedab9b
commit ec7760795f
5 changed files with 379 additions and 35 deletions

View File

@ -7,6 +7,8 @@ from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
from core.workflow.nodes.end.entities import EndStreamParam
class GraphEdge(BaseModel):
@ -52,6 +54,10 @@ class Graph(BaseModel):
...,
description="answer stream generate routes"
)
end_stream_param: EndStreamParam = Field(
...,
description="end stream param"
)
@classmethod
def init(cls,
@ -166,6 +172,12 @@ class Graph(BaseModel):
reverse_edge_mapping=reverse_edge_mapping
)
# init end stream param
end_stream_param = EndStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping
)
# init graph
graph = cls(
root_node_id=root_node_id,
@ -175,7 +187,8 @@ class Graph(BaseModel):
reverse_edge_mapping=reverse_edge_mapping,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
answer_stream_generate_routes=answer_stream_generate_routes
answer_stream_generate_routes=answer_stream_generate_routes,
end_stream_param=end_stream_param
)
return graph

View File

@ -167,37 +167,3 @@ class AnswerStreamGeneratorRouter:
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
)
@classmethod
def _fetch_answer_dependencies(cls,
current_node_id: str,
answer_node_id: str,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]]
) -> None:
"""
Fetch answer dependencies
:param current_node_id: current node id
:param answer_node_id: answer node id
:param answer_node_ids: answer node ids
:param reverse_edge_mapping: reverse edge mapping
:param answer_dependencies: answer dependencies
:return:
"""
for edge in reverse_edge_mapping.get(current_node_id, []):
source_node_id = edge.source_node_id
if source_node_id == answer_node_id:
continue
if source_node_id in answer_node_ids:
# is answer node
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._fetch_answer_dependencies(
current_node_id=source_node_id,
answer_node_id=answer_node_id,
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
)

View File

@ -0,0 +1,142 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
class EndStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
) -> EndStreamParam:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selector of end nodes
end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
for end_node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.END.value:
continue
# get generate route for stream output
stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors
# fetch end dependencies
end_node_ids = list(end_stream_variable_selectors_mapping.keys())
end_dependencies = cls._fetch_ends_dependencies(
end_node_ids=end_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping
)
return EndStreamParam(
end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
end_dependencies=end_dependencies
)
@classmethod
def extract_stream_variable_selector_from_node_data(cls,
node_id_config_mapping: dict[str, dict],
node_data: EndNodeData) -> list[list[str]]:
"""
Extract stream variable selector from node data
:param node_id_config_mapping: node id config mapping
:param node_data: node data object
:return:
"""
variable_selectors = node_data.outputs
value_selectors = []
for variable_selector in variable_selectors:
if not variable_selector.value_selector:
continue
node_id = variable_selector.value_selector[0]
if node_id != 'sys' and node_id in node_id_config_mapping:
node = node_id_config_mapping[node_id]
node_type = node.get('data', {}).get('type')
if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
value_selectors.append(variable_selector.value_selector)
# remove duplicates
value_selectors = list(set(value_selectors))
return value_selectors
@classmethod
def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
-> list[list[str]]:
"""
Extract stream variable selector from node config
:param node_id_config_mapping: node id config mapping
:param config: node config
:return:
"""
node_data = EndNodeData(**config.get("data", {}))
return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
@classmethod
def _fetch_ends_dependencies(cls,
end_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict]
) -> dict[str, list[str]]:
"""
Fetch end dependencies
:param end_node_ids: end node ids
:param reverse_edge_mapping: reverse edge mapping
:param node_id_config_mapping: node id config mapping
:return:
"""
end_dependencies: dict[str, list[str]] = {}
for end_node_id in end_node_ids:
if end_dependencies.get(end_node_id) is None:
end_dependencies[end_node_id] = []
cls._recursive_fetch_end_dependencies(
current_node_id=end_node_id,
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies
)
return end_dependencies
@classmethod
def _recursive_fetch_end_dependencies(cls,
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]],
# type: ignore[name-defined]
end_dependencies: dict[str, list[str]]
) -> None:
"""
Recursive fetch end dependencies
:param current_node_id: current node id
:param end_node_id: end node id
:param node_id_config_mapping: node id config mapping
:param reverse_edge_mapping: reverse edge mapping
:param end_dependencies: end dependencies
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
if source_node_type in (
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER,
):
end_dependencies[end_node_id].append(source_node_id)
else:
cls._recursive_fetch_end_dependencies(
current_node_id=source_node_id,
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies
)

View File

@ -0,0 +1,207 @@
import logging
from collections.abc import Generator
from typing import cast
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
class EndStreamProcessor:
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.stream_param = graph.end_stream_param
self.end_streamed_variable_selectors: dict[str, list[str]] = {
end_node_id: [] for end_node_id in graph.end_stream_param.end_stream_variable_selector_mapping
}
self.rest_node_ids = graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStreamChunkEvent):
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_answer_node_ids
for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[answer_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
else:
yield event
def reset(self) -> None:
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for answer_node_id, position in self.route_position.items():
# all depends on answer node id not in rest node ids
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue
route_position = self.route_position[answer_node_id]
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
for route_chunk in route_chunks:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
break
value = self.variable_pool.get(
value_selector
)
if value is None:
break
text = value.markdown
if text:
yield NodeRunStreamChunkEvent(
chunk_content=text,
from_variable_selector=value_selector,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
)
self.route_position[answer_node_id] += 1
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.from_variable_selector:
return []
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_answer_node_ids = []
for answer_node_id, position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
continue
# all depends on answer node id not in rest node ids
if all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
route_position = self.route_position[answer_node_id]
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
continue
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
continue
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
continue
stream_out_answer_node_ids.append(answer_node_id)
return stream_out_answer_node_ids

View File

@ -1,3 +1,5 @@
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
@ -7,3 +9,17 @@ class EndNodeData(BaseNodeData):
END Node Data.
"""
outputs: list[VariableSelector]
class EndStreamParam(BaseModel):
"""
EndStreamParam entity
"""
end_dependencies: dict[str, list[str]] = Field(
...,
description="end dependencies (end node id -> dependent node ids)"
)
end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
...,
description="end stream variable selector mapping (end node id -> stream variable selectors)"
)