add end stream output test

takatost 2024-07-25 04:03:53 +08:00
parent 833584ba76
commit f4eb7cd037
6 changed files with 300 additions and 154 deletions

View File

@@ -29,6 +29,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
@@ -82,14 +83,21 @@ class GraphEngine:
yield GraphRunStartedEvent()
try:
# run graph
generator = self._run(start_node_id=self.graph.root_node_id)
if self.init_params.workflow_type == WorkflowType.CHAT:
answer_stream_processor = AnswerStreamProcessor(
stream_processor = AnswerStreamProcessor(
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
)
generator = answer_stream_processor.process(generator)
else:
stream_processor = EndStreamProcessor(
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
)
# run graph
generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id)
)
for item in generator:
yield item
@@ -151,6 +159,11 @@ class GraphEngine:
)
raise e
# It may not be necessary, but it is necessary. :)
if (self.graph.node_id_config_mapping[next_node_id]
.get("data", {}).get("type", "").lower() == NodeType.END.value):
break
previous_route_node_state = route_node_state
# get next node ids
@@ -160,11 +173,6 @@ class GraphEngine:
if len(edge_mappings) == 1:
next_node_id = edge_mappings[0].target_node_id
# It may not be necessary, but it is necessary. :)
if (self.graph.node_id_config_mapping[next_node_id]
.get("data", {}).get("type", "").lower() == NodeType.END.value):
break
else:
if any(edge.run_condition for edge in edge_mappings):
# if the edges have run conditions, determine which branch to take based on the run condition results

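Note: with this change, GraphEngine.run() selects a stream processor by workflow type: chatflows keep AnswerStreamProcessor, while plain workflows now route chunks through the new EndStreamProcessor. A minimal sketch of that dispatch, assuming the constructor signatures shown above (the helper function itself is illustrative, not part of this commit):

from models.workflow import WorkflowType
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor

def make_stream_processor(workflow_type, graph, variable_pool):
    # chatflows stream Answer node routes; workflows stream End node outputs
    if workflow_type == WorkflowType.CHAT:
        return AnswerStreamProcessor(graph=graph, variable_pool=variable_pool)
    return EndStreamProcessor(graph=graph, variable_pool=variable_pool)

# usage mirrors the engine: wrap the raw node-event generator
# generator = make_stream_processor(...).process(self._run(start_node_id=...))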
View File

@@ -66,6 +66,7 @@ class AnswerStreamProcessor:
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
@@ -179,14 +180,13 @@ class AnswerStreamProcessor:
return []
stream_out_answer_node_ids = []
for answer_node_id, position in self.route_position.items():
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
continue
# all dependencies of the answer node id are not in rest node ids
if all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
route_position = self.route_position[answer_node_id]
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
continue

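The loop above gates streaming on dependency completion: an answer node may stream only once none of its dependencies remain in rest_node_ids. A standalone sketch of that check with simplified stand-in data (a plain set and dict, not the real route-state entities):

rest_node_ids = {"llm2", "answer2"}          # nodes still pending
answer_dependencies = {"answer1": ["llm1"], "answer2": ["llm2"]}

def can_stream(answer_node_id: str) -> bool:
    # an answer node may stream once all of its dependencies have finished
    return all(dep_id not in rest_node_ids
               for dep_id in answer_dependencies[answer_node_id])

assert can_stream("answer1")        # llm1 already finished
assert not can_stream("answer2")    # llm2 still pending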
View File

@@ -31,49 +31,6 @@ class EndNode(BaseNode):
outputs=outputs
)
@classmethod
def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
"""
Extract generate nodes
:param graph: graph
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(EndNodeData, node_data)
return cls.extract_generate_nodes_from_node_data(graph, node_data)
@classmethod
def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
"""
Extract generate nodes from node data
:param graph: graph
:param node_data: node data object
:return:
"""
nodes = graph.get('nodes', [])
node_mapping = {node.get('id'): node for node in nodes}
variable_selectors = node_data.outputs
generate_nodes = []
for variable_selector in variable_selectors:
if not variable_selector.value_selector:
continue
node_id = variable_selector.value_selector[0]
if node_id != 'sys' and node_id in node_mapping:
node = node_mapping[node_id]
node_type = node.get('data', {}).get('type')
if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
generate_nodes.append(node_id)
# remove duplicates
generate_nodes = list(set(generate_nodes))
return generate_nodes
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""

View File

@@ -61,7 +61,9 @@ class EndStreamGeneratorRouter:
value_selectors.append(variable_selector.value_selector)
# remove duplicates
value_selectors = list(set(value_selectors))
value_selector_tuples = [tuple(item) for item in value_selectors]
unique_value_selector_tuples = list(set(value_selector_tuples))
value_selectors = [list(item) for item in unique_value_selector_tuples]
return value_selectors

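The replaced line was a latent bug: value selectors are lists, lists are unhashable, and set(value_selectors) raises TypeError. Round-tripping through tuples, as the new lines do, makes the deduplication valid. A minimal repro of both behaviors:

value_selectors = [["llm1", "text"], ["start", "query"], ["llm1", "text"]]

try:
    value_selectors = list(set(value_selectors))   # old code path
except TypeError as e:
    print(e)  # unhashable type: 'list'

# new code path: tuples are hashable, so deduplication works
value_selector_tuples = [tuple(item) for item in value_selectors]
unique_value_selector_tuples = list(set(value_selector_tuples))
value_selectors = [list(item) for item in unique_value_selector_tuples]
assert sorted(value_selectors) == [["llm1", "text"], ["start", "query"]]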
View File

@@ -1,6 +1,5 @@
import logging
from collections.abc import Generator
from typing import cast
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
@@ -9,7 +8,6 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
@@ -20,10 +18,7 @@ class EndStreamProcessor:
self.graph = graph
self.variable_pool = variable_pool
self.stream_param = graph.end_stream_param
self.end_streamed_variable_selectors: dict[str, list[str]] = {
end_node_id: [] for end_node_id in graph.end_stream_param.end_stream_variable_selector_mapping
}
self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
self.rest_node_ids = graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
@@ -33,43 +28,37 @@ class EndStreamProcessor:
for event in generator:
if isinstance(event, NodeRunStreamChunkEvent):
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_answer_node_ids
] = stream_out_end_node_ids
for _ in stream_out_answer_node_ids:
for _ in stream_out_end_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream events have finished
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[answer_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
else:
yield event
def reset(self) -> None:
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.end_streamed_variable_selectors = {}
self.end_streamed_variable_selectors: dict[str, list[str]] = {
end_node_id: [] for end_node_id in self.graph.end_stream_param.end_stream_variable_selector_mapping
}
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
@@ -113,59 +102,7 @@ class EndStreamProcessor:
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for answer_node_id, position in self.route_position.items():
# all dependencies of the answer node id are not in rest node ids
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue
route_position = self.route_position[answer_node_id]
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
for route_chunk in route_chunks:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
break
value = self.variable_pool.get(
value_selector
)
if value is None:
break
text = value.markdown
if text:
yield NodeRunStreamChunkEvent(
chunk_content=text,
from_variable_selector=value_selector,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
)
self.route_position[answer_node_id] += 1
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Get the end node ids that the given stream chunk should be streamed out for.
:param event: node run stream chunk event
@@ -178,30 +115,17 @@ class EndStreamProcessor:
if not stream_output_value_selector:
return []
stream_out_answer_node_ids = []
for answer_node_id, position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
stream_out_end_node_ids = []
for end_node_id, variable_selectors in self.end_streamed_variable_selectors.items():
if end_node_id not in self.rest_node_ids:
continue
# all dependencies of the answer node id are not in rest node ids
# all dependencies of the end node id are not in rest node ids
if all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
route_position = self.route_position[answer_node_id]
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
for dep_id in self.stream_param.end_dependencies[end_node_id]):
if stream_output_value_selector not in variable_selectors:
continue
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
stream_out_end_node_ids.append(end_node_id)
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
continue
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
continue
stream_out_answer_node_ids.append(answer_node_id)
return stream_out_answer_node_ids
return stream_out_end_node_ids

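Taken together, the new _get_stream_out_end_node_ids streams a chunk once per end node that is still pending, has all of its dependencies finished, and lists the chunk's source variable among its streamed selectors. A self-contained sketch with simplified stand-in data (in the real code, end_dependencies and the selector mapping come from graph.end_stream_param):

rest_node_ids = {"llm3", "end2"}                 # nodes still pending
end_dependencies = {"end1": ["start"], "end2": ["start"]}
end_streamed_variable_selectors = {
    "end1": [["llm2", "text"]],
    "end2": [["llm1", "text"], ["llm3", "text"]],
}

def get_stream_out_end_node_ids(chunk_selector: list[str]) -> list[str]:
    node_ids = []
    for end_node_id, selectors in end_streamed_variable_selectors.items():
        if end_node_id not in rest_node_ids:
            continue  # end node already finished or pruned
        # all dependencies of the end node id are not in rest node ids
        if all(dep_id not in rest_node_ids for dep_id in end_dependencies[end_node_id]):
            if chunk_selector in selectors:
                node_ids.append(end_node_id)
    return node_ids

assert get_stream_out_end_node_ids(["llm3", "text"]) == ["end2"]
assert get_stream_out_end_node_ids(["llm2", "text"]) == []   # end1 no longer pending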
View File

@@ -1,7 +1,7 @@
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
@@ -16,12 +16,267 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from models.workflow import WorkflowType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.llm_node import LLMNode
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@patch('extensions.ext_database.db.session.remove')
@patch('extensions.ext_database.db.session.close')
def test_run_parallel(mock_close, mock_remove):
def test_run_parallel_in_workflow(mock_close, mock_remove, mocker):
graph_config = {
"edges": [
{
"id": "1",
"source": "start",
"target": "llm1",
},
{
"id": "2",
"source": "llm1",
"target": "llm2",
},
{
"id": "3",
"source": "llm1",
"target": "llm3",
},
{
"id": "4",
"source": "llm2",
"target": "end1",
},
{
"id": "5",
"source": "llm3",
"target": "end2",
}
],
"nodes": [
{
"data": {
"type": "start",
"title": "start",
"variables": [{
"label": "query",
"max_length": 48,
"options": [],
"required": True,
"type": "text-input",
"variable": "query"
}]
},
"id": "start"
},
{
"data": {
"type": "llm",
"title": "llm1",
"context": {
"enabled": False,
"variable_selector": []
},
"model": {
"completion_params": {
"temperature": 0.7
},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai"
},
"prompt_template": [{
"role": "system",
"text": "say hi"
}, {
"role": "user",
"text": "{{#start.query#}}"
}],
"vision": {
"configs": {
"detail": "high"
},
"enabled": False
}
},
"id": "llm1"
},
{
"data": {
"type": "llm",
"title": "llm2",
"context": {
"enabled": False,
"variable_selector": []
},
"model": {
"completion_params": {
"temperature": 0.7
},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai"
},
"prompt_template": [{
"role": "system",
"text": "say bye"
}, {
"role": "user",
"text": "{{#start.query#}}"
}],
"vision": {
"configs": {
"detail": "high"
},
"enabled": False
}
},
"id": "llm2",
},
{
"data": {
"type": "llm",
"title": "llm3",
"context": {
"enabled": False,
"variable_selector": []
},
"model": {
"completion_params": {
"temperature": 0.7
},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai"
},
"prompt_template": [{
"role": "system",
"text": "say good morning"
}, {
"role": "user",
"text": "{{#start.query#}}"
}],
"vision": {
"configs": {
"detail": "high"
},
"enabled": False
}
},
"id": "llm3",
},
{
"data": {
"type": "end",
"title": "end1",
"outputs": [{
"value_selector": ["llm2", "text"],
"variable": "result2"
}, {
"value_selector": ["start", "query"],
"variable": "query"
}],
},
"id": "end1",
},
{
"data": {
"type": "end",
"title": "end2",
"outputs": [{
"value_selector": ["llm1", "text"],
"variable": "result1"
}, {
"value_selector": ["llm3", "text"],
"variable": "result3"
}],
},
"id": "end2",
}
],
}
graph = Graph.init(
graph_config=graph_config
)
variable_pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={
"query": "hi"
})
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="333",
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200
)
def llm_generator(self):
contents = [
'hi',
'bye',
'good morning'
]
yield RunStreamChunkEvent(
chunk_content=contents[int(self.node_id[-1]) - 1],
from_variable_selector=[self.node_id, 'text']
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
process_data={},
outputs={},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: 1,
NodeRunMetadataKey.TOTAL_PRICE: 1,
NodeRunMetadataKey.CURRENCY: 'USD'
}
)
)
print("")
with patch.object(LLMNode, '_run', new=llm_generator):
items = []
generator = graph_engine.run()
for item in generator:
print(type(item), item)
items.append(item)
if isinstance(item, NodeRunSucceededEvent):
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
assert not isinstance(item, NodeRunFailedEvent)
assert not isinstance(item, GraphRunFailedEvent)
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
'llm2', 'llm3', 'end1', 'end2'
]:
assert item.parallel_id is not None
assert len(items) == 17
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == 'start'
assert isinstance(items[2], NodeRunSucceededEvent)
assert items[2].route_node_state.node_id == 'start'
@patch('extensions.ext_database.db.session.remove')
@patch('extensions.ext_database.db.session.close')
def test_run_parallel_in_chatflow(mock_close, mock_remove):
graph_config = {
"edges": [
{
@@ -291,7 +546,7 @@ def test_run_branch(mock_close, mock_remove):
items = []
generator = graph_engine.run()
for item in generator:
print(type(item), item)
# print(type(item), item)
items.append(item)
assert len(items) == 10
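One detail worth calling out in the llm_generator stub above: the chunk is chosen by the trailing digit of the node id, so llm1/llm2/llm3 deterministically emit 'hi'/'bye'/'good morning'. A quick standalone check of that indexing:

contents = ['hi', 'bye', 'good morning']

for node_id, expected in [('llm1', 'hi'), ('llm2', 'bye'), ('llm3', 'good morning')]:
    # int(node_id[-1]) - 1 turns the trailing digit into a zero-based index
    assert contents[int(node_id[-1]) - 1] == expected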