add end stream output test
This commit is contained in:
parent
833584ba76
commit
f4eb7cd037
@ -29,6 +29,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
|||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||||
|
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||||
|
|
||||||
# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||||
@ -82,14 +83,21 @@ class GraphEngine:
|
|||||||
yield GraphRunStartedEvent()
|
yield GraphRunStartedEvent()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# run graph
|
|
||||||
generator = self._run(start_node_id=self.graph.root_node_id)
|
|
||||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||||
answer_stream_processor = AnswerStreamProcessor(
|
stream_processor = AnswerStreamProcessor(
|
||||||
graph=self.graph,
|
graph=self.graph,
|
||||||
variable_pool=self.graph_runtime_state.variable_pool
|
variable_pool=self.graph_runtime_state.variable_pool
|
||||||
)
|
)
|
||||||
generator = answer_stream_processor.process(generator)
|
else:
|
||||||
|
stream_processor = EndStreamProcessor(
|
||||||
|
graph=self.graph,
|
||||||
|
variable_pool=self.graph_runtime_state.variable_pool
|
||||||
|
)
|
||||||
|
|
||||||
|
# run graph
|
||||||
|
generator = stream_processor.process(
|
||||||
|
self._run(start_node_id=self.graph.root_node_id)
|
||||||
|
)
|
||||||
|
|
||||||
for item in generator:
|
for item in generator:
|
||||||
yield item
|
yield item
|
||||||
@ -151,6 +159,11 @@ class GraphEngine:
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
# It may not be necessary, but it is necessary. :)
|
||||||
|
if (self.graph.node_id_config_mapping[next_node_id]
|
||||||
|
.get("data", {}).get("type", "").lower() == NodeType.END.value):
|
||||||
|
break
|
||||||
|
|
||||||
previous_route_node_state = route_node_state
|
previous_route_node_state = route_node_state
|
||||||
|
|
||||||
# get next node ids
|
# get next node ids
|
||||||
@ -160,11 +173,6 @@ class GraphEngine:
|
|||||||
|
|
||||||
if len(edge_mappings) == 1:
|
if len(edge_mappings) == 1:
|
||||||
next_node_id = edge_mappings[0].target_node_id
|
next_node_id = edge_mappings[0].target_node_id
|
||||||
|
|
||||||
# It may not be necessary, but it is necessary. :)
|
|
||||||
if (self.graph.node_id_config_mapping[next_node_id]
|
|
||||||
.get("data", {}).get("type", "").lower() == NodeType.END.value):
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
if any(edge.run_condition for edge in edge_mappings):
|
if any(edge.run_condition for edge in edge_mappings):
|
||||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||||
|
@ -66,6 +66,7 @@ class AnswerStreamProcessor:
|
|||||||
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
||||||
self.route_position[answer_node_id] = 0
|
self.route_position[answer_node_id] = 0
|
||||||
self.rest_node_ids = self.graph.node_ids.copy()
|
self.rest_node_ids = self.graph.node_ids.copy()
|
||||||
|
self.current_stream_chunk_generating_node_ids = {}
|
||||||
|
|
||||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
||||||
finished_node_id = event.route_node_state.node_id
|
finished_node_id = event.route_node_state.node_id
|
||||||
@ -179,14 +180,13 @@ class AnswerStreamProcessor:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
stream_out_answer_node_ids = []
|
stream_out_answer_node_ids = []
|
||||||
for answer_node_id, position in self.route_position.items():
|
for answer_node_id, route_position in self.route_position.items():
|
||||||
if answer_node_id not in self.rest_node_ids:
|
if answer_node_id not in self.rest_node_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# all depends on answer node id not in rest node ids
|
# all depends on answer node id not in rest node ids
|
||||||
if all(dep_id not in self.rest_node_ids
|
if all(dep_id not in self.rest_node_ids
|
||||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
|
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
|
||||||
route_position = self.route_position[answer_node_id]
|
|
||||||
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
|
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -31,49 +31,6 @@ class EndNode(BaseNode):
|
|||||||
outputs=outputs
|
outputs=outputs
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
|
|
||||||
"""
|
|
||||||
Extract generate nodes
|
|
||||||
:param graph: graph
|
|
||||||
:param config: node config
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
|
||||||
node_data = cast(EndNodeData, node_data)
|
|
||||||
|
|
||||||
return cls.extract_generate_nodes_from_node_data(graph, node_data)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
|
|
||||||
"""
|
|
||||||
Extract generate nodes from node data
|
|
||||||
:param graph: graph
|
|
||||||
:param node_data: node data object
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
nodes = graph.get('nodes', [])
|
|
||||||
node_mapping = {node.get('id'): node for node in nodes}
|
|
||||||
|
|
||||||
variable_selectors = node_data.outputs
|
|
||||||
|
|
||||||
generate_nodes = []
|
|
||||||
for variable_selector in variable_selectors:
|
|
||||||
if not variable_selector.value_selector:
|
|
||||||
continue
|
|
||||||
|
|
||||||
node_id = variable_selector.value_selector[0]
|
|
||||||
if node_id != 'sys' and node_id in node_mapping:
|
|
||||||
node = node_mapping[node_id]
|
|
||||||
node_type = node.get('data', {}).get('type')
|
|
||||||
if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
|
|
||||||
generate_nodes.append(node_id)
|
|
||||||
|
|
||||||
# remove duplicates
|
|
||||||
generate_nodes = list(set(generate_nodes))
|
|
||||||
|
|
||||||
return generate_nodes
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
|
@ -61,7 +61,9 @@ class EndStreamGeneratorRouter:
|
|||||||
value_selectors.append(variable_selector.value_selector)
|
value_selectors.append(variable_selector.value_selector)
|
||||||
|
|
||||||
# remove duplicates
|
# remove duplicates
|
||||||
value_selectors = list(set(value_selectors))
|
value_selector_tuples = [tuple(item) for item in value_selectors]
|
||||||
|
unique_value_selector_tuples = list(set(value_selector_tuples))
|
||||||
|
value_selectors = [list(item) for item in unique_value_selector_tuples]
|
||||||
|
|
||||||
return value_selectors
|
return value_selectors
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
@ -9,7 +8,6 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -20,10 +18,7 @@ class EndStreamProcessor:
|
|||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
self.stream_param = graph.end_stream_param
|
self.stream_param = graph.end_stream_param
|
||||||
self.end_streamed_variable_selectors: dict[str, list[str]] = {
|
self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
|
||||||
end_node_id: [] for end_node_id in graph.end_stream_param.end_stream_variable_selector_mapping
|
|
||||||
}
|
|
||||||
|
|
||||||
self.rest_node_ids = graph.node_ids.copy()
|
self.rest_node_ids = graph.node_ids.copy()
|
||||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||||
|
|
||||||
@ -33,43 +28,37 @@ class EndStreamProcessor:
|
|||||||
for event in generator:
|
for event in generator:
|
||||||
if isinstance(event, NodeRunStreamChunkEvent):
|
if isinstance(event, NodeRunStreamChunkEvent):
|
||||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||||
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
|
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||||
event.route_node_state.node_id
|
event.route_node_state.node_id
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
|
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
|
||||||
self.current_stream_chunk_generating_node_ids[
|
self.current_stream_chunk_generating_node_ids[
|
||||||
event.route_node_state.node_id
|
event.route_node_state.node_id
|
||||||
] = stream_out_answer_node_ids
|
] = stream_out_end_node_ids
|
||||||
|
|
||||||
for _ in stream_out_answer_node_ids:
|
for _ in stream_out_end_node_ids:
|
||||||
yield event
|
yield event
|
||||||
elif isinstance(event, NodeRunSucceededEvent):
|
elif isinstance(event, NodeRunSucceededEvent):
|
||||||
yield event
|
yield event
|
||||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||||
# update self.route_position after all stream event finished
|
|
||||||
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
|
|
||||||
self.route_position[answer_node_id] += 1
|
|
||||||
|
|
||||||
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
||||||
|
|
||||||
# remove unreachable nodes
|
# remove unreachable nodes
|
||||||
self._remove_unreachable_nodes(event)
|
self._remove_unreachable_nodes(event)
|
||||||
|
|
||||||
# generate stream outputs
|
|
||||||
yield from self._generate_stream_outputs_when_node_finished(event)
|
|
||||||
else:
|
else:
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.route_position = {}
|
self.end_streamed_variable_selectors = {}
|
||||||
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
self.end_streamed_variable_selectors: dict[str, list[str]] = {
|
||||||
self.route_position[answer_node_id] = 0
|
end_node_id: [] for end_node_id in self.graph.end_stream_param.end_stream_variable_selector_mapping
|
||||||
|
}
|
||||||
self.rest_node_ids = self.graph.node_ids.copy()
|
self.rest_node_ids = self.graph.node_ids.copy()
|
||||||
|
self.current_stream_chunk_generating_node_ids = {}
|
||||||
|
|
||||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
||||||
finished_node_id = event.route_node_state.node_id
|
finished_node_id = event.route_node_state.node_id
|
||||||
|
|
||||||
if finished_node_id not in self.rest_node_ids:
|
if finished_node_id not in self.rest_node_ids:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -113,59 +102,7 @@ class EndStreamProcessor:
|
|||||||
|
|
||||||
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
|
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
|
||||||
|
|
||||||
def _generate_stream_outputs_when_node_finished(self,
|
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
||||||
event: NodeRunSucceededEvent
|
|
||||||
) -> Generator[GraphEngineEvent, None, None]:
|
|
||||||
"""
|
|
||||||
Generate stream outputs.
|
|
||||||
:param event: node run succeeded event
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for answer_node_id, position in self.route_position.items():
|
|
||||||
# all depends on answer node id not in rest node ids
|
|
||||||
if (event.route_node_state.node_id != answer_node_id
|
|
||||||
and (answer_node_id not in self.rest_node_ids
|
|
||||||
or not all(dep_id not in self.rest_node_ids
|
|
||||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
|
|
||||||
continue
|
|
||||||
|
|
||||||
route_position = self.route_position[answer_node_id]
|
|
||||||
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
|
|
||||||
|
|
||||||
for route_chunk in route_chunks:
|
|
||||||
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
|
|
||||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
|
||||||
yield NodeRunStreamChunkEvent(
|
|
||||||
chunk_content=route_chunk.text,
|
|
||||||
route_node_state=event.route_node_state,
|
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
|
||||||
value_selector = route_chunk.value_selector
|
|
||||||
if not value_selector:
|
|
||||||
break
|
|
||||||
|
|
||||||
value = self.variable_pool.get(
|
|
||||||
value_selector
|
|
||||||
)
|
|
||||||
|
|
||||||
if value is None:
|
|
||||||
break
|
|
||||||
|
|
||||||
text = value.markdown
|
|
||||||
|
|
||||||
if text:
|
|
||||||
yield NodeRunStreamChunkEvent(
|
|
||||||
chunk_content=text,
|
|
||||||
from_variable_selector=value_selector,
|
|
||||||
route_node_state=event.route_node_state,
|
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.route_position[answer_node_id] += 1
|
|
||||||
|
|
||||||
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
|
||||||
"""
|
"""
|
||||||
Is stream out support
|
Is stream out support
|
||||||
:param event: queue text chunk event
|
:param event: queue text chunk event
|
||||||
@ -178,30 +115,17 @@ class EndStreamProcessor:
|
|||||||
if not stream_output_value_selector:
|
if not stream_output_value_selector:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
stream_out_answer_node_ids = []
|
stream_out_end_node_ids = []
|
||||||
for answer_node_id, position in self.route_position.items():
|
for end_node_id, variable_selectors in self.end_streamed_variable_selectors.items():
|
||||||
if answer_node_id not in self.rest_node_ids:
|
if end_node_id not in self.rest_node_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# all depends on answer node id not in rest node ids
|
# all depends on end node id not in rest node ids
|
||||||
if all(dep_id not in self.rest_node_ids
|
if all(dep_id not in self.rest_node_ids
|
||||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
|
for dep_id in self.stream_param.end_dependencies[end_node_id]):
|
||||||
route_position = self.route_position[answer_node_id]
|
if stream_output_value_selector not in variable_selectors:
|
||||||
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
|
stream_out_end_node_ids.append(end_node_id)
|
||||||
|
|
||||||
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
|
return stream_out_end_node_ids
|
||||||
continue
|
|
||||||
|
|
||||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
|
||||||
value_selector = route_chunk.value_selector
|
|
||||||
|
|
||||||
# check chunk node id is before current node id or equal to current node id
|
|
||||||
if value_selector != stream_output_value_selector:
|
|
||||||
continue
|
|
||||||
|
|
||||||
stream_out_answer_node_ids.append(answer_node_id)
|
|
||||||
|
|
||||||
return stream_out_answer_node_ids
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, SystemVariable, UserFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
BaseNodeEvent,
|
BaseNodeEvent,
|
||||||
@ -16,12 +16,267 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||||
from models.workflow import WorkflowType
|
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||||
|
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||||
|
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||||
|
|
||||||
|
|
||||||
@patch('extensions.ext_database.db.session.remove')
|
@patch('extensions.ext_database.db.session.remove')
|
||||||
@patch('extensions.ext_database.db.session.close')
|
@patch('extensions.ext_database.db.session.close')
|
||||||
def test_run_parallel(mock_close, mock_remove):
|
def test_run_parallel_in_workflow(mock_close, mock_remove, mocker):
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "1",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "2",
|
||||||
|
"source": "llm1",
|
||||||
|
"target": "llm2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "3",
|
||||||
|
"source": "llm1",
|
||||||
|
"target": "llm3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "4",
|
||||||
|
"source": "llm2",
|
||||||
|
"target": "end1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "5",
|
||||||
|
"source": "llm3",
|
||||||
|
"target": "end2",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "start",
|
||||||
|
"title": "start",
|
||||||
|
"variables": [{
|
||||||
|
"label": "query",
|
||||||
|
"max_length": 48,
|
||||||
|
"options": [],
|
||||||
|
"required": True,
|
||||||
|
"type": "text-input",
|
||||||
|
"variable": "query"
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
"title": "llm1",
|
||||||
|
"context": {
|
||||||
|
"enabled": False,
|
||||||
|
"variable_selector": []
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"completion_params": {
|
||||||
|
"temperature": 0.7
|
||||||
|
},
|
||||||
|
"mode": "chat",
|
||||||
|
"name": "gpt-4o",
|
||||||
|
"provider": "openai"
|
||||||
|
},
|
||||||
|
"prompt_template": [{
|
||||||
|
"role": "system",
|
||||||
|
"text": "say hi"
|
||||||
|
}, {
|
||||||
|
"role": "user",
|
||||||
|
"text": "{{#start.query#}}"
|
||||||
|
}],
|
||||||
|
"vision": {
|
||||||
|
"configs": {
|
||||||
|
"detail": "high"
|
||||||
|
},
|
||||||
|
"enabled": False
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "llm1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
"title": "llm2",
|
||||||
|
"context": {
|
||||||
|
"enabled": False,
|
||||||
|
"variable_selector": []
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"completion_params": {
|
||||||
|
"temperature": 0.7
|
||||||
|
},
|
||||||
|
"mode": "chat",
|
||||||
|
"name": "gpt-4o",
|
||||||
|
"provider": "openai"
|
||||||
|
},
|
||||||
|
"prompt_template": [{
|
||||||
|
"role": "system",
|
||||||
|
"text": "say bye"
|
||||||
|
}, {
|
||||||
|
"role": "user",
|
||||||
|
"text": "{{#start.query#}}"
|
||||||
|
}],
|
||||||
|
"vision": {
|
||||||
|
"configs": {
|
||||||
|
"detail": "high"
|
||||||
|
},
|
||||||
|
"enabled": False
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "llm2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
"title": "llm3",
|
||||||
|
"context": {
|
||||||
|
"enabled": False,
|
||||||
|
"variable_selector": []
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"completion_params": {
|
||||||
|
"temperature": 0.7
|
||||||
|
},
|
||||||
|
"mode": "chat",
|
||||||
|
"name": "gpt-4o",
|
||||||
|
"provider": "openai"
|
||||||
|
},
|
||||||
|
"prompt_template": [{
|
||||||
|
"role": "system",
|
||||||
|
"text": "say good morning"
|
||||||
|
}, {
|
||||||
|
"role": "user",
|
||||||
|
"text": "{{#start.query#}}"
|
||||||
|
}],
|
||||||
|
"vision": {
|
||||||
|
"configs": {
|
||||||
|
"detail": "high"
|
||||||
|
},
|
||||||
|
"enabled": False
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "llm3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "end",
|
||||||
|
"title": "end1",
|
||||||
|
"outputs": [{
|
||||||
|
"value_selector": ["llm2", "text"],
|
||||||
|
"variable": "result2"
|
||||||
|
}, {
|
||||||
|
"value_selector": ["start", "query"],
|
||||||
|
"variable": "query"
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
"id": "end1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "end",
|
||||||
|
"title": "end2",
|
||||||
|
"outputs": [{
|
||||||
|
"value_selector": ["llm1", "text"],
|
||||||
|
"variable": "result1"
|
||||||
|
}, {
|
||||||
|
"value_selector": ["llm3", "text"],
|
||||||
|
"variable": "result3"
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
"id": "end2",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
variable_pool = VariablePool(system_variables={
|
||||||
|
SystemVariable.FILES: [],
|
||||||
|
SystemVariable.USER_ID: 'aaa'
|
||||||
|
}, user_inputs={
|
||||||
|
"query": "hi"
|
||||||
|
})
|
||||||
|
|
||||||
|
graph_engine = GraphEngine(
|
||||||
|
tenant_id="111",
|
||||||
|
app_id="222",
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
|
workflow_id="333",
|
||||||
|
user_id="444",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.WEB_APP,
|
||||||
|
call_depth=0,
|
||||||
|
graph=graph,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
max_execution_steps=500,
|
||||||
|
max_execution_time=1200
|
||||||
|
)
|
||||||
|
|
||||||
|
def llm_generator(self):
|
||||||
|
contents = [
|
||||||
|
'hi',
|
||||||
|
'bye',
|
||||||
|
'good morning'
|
||||||
|
]
|
||||||
|
|
||||||
|
yield RunStreamChunkEvent(
|
||||||
|
chunk_content=contents[int(self.node_id[-1]) - 1],
|
||||||
|
from_variable_selector=[self.node_id, 'text']
|
||||||
|
)
|
||||||
|
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
inputs={},
|
||||||
|
process_data={},
|
||||||
|
outputs={},
|
||||||
|
metadata={
|
||||||
|
NodeRunMetadataKey.TOTAL_TOKENS: 1,
|
||||||
|
NodeRunMetadataKey.TOTAL_PRICE: 1,
|
||||||
|
NodeRunMetadataKey.CURRENCY: 'USD'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
print("")
|
||||||
|
|
||||||
|
with patch.object(LLMNode, '_run', new=llm_generator):
|
||||||
|
items = []
|
||||||
|
generator = graph_engine.run()
|
||||||
|
for item in generator:
|
||||||
|
print(type(item), item)
|
||||||
|
items.append(item)
|
||||||
|
if isinstance(item, NodeRunSucceededEvent):
|
||||||
|
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
|
||||||
|
|
||||||
|
assert not isinstance(item, NodeRunFailedEvent)
|
||||||
|
assert not isinstance(item, GraphRunFailedEvent)
|
||||||
|
|
||||||
|
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
|
||||||
|
'llm2', 'llm3', 'end1', 'end2'
|
||||||
|
]:
|
||||||
|
assert item.parallel_id is not None
|
||||||
|
|
||||||
|
assert len(items) == 17
|
||||||
|
assert isinstance(items[0], GraphRunStartedEvent)
|
||||||
|
assert isinstance(items[1], NodeRunStartedEvent)
|
||||||
|
assert items[1].route_node_state.node_id == 'start'
|
||||||
|
assert isinstance(items[2], NodeRunSucceededEvent)
|
||||||
|
assert items[2].route_node_state.node_id == 'start'
|
||||||
|
|
||||||
|
|
||||||
|
@patch('extensions.ext_database.db.session.remove')
|
||||||
|
@patch('extensions.ext_database.db.session.close')
|
||||||
|
def test_run_parallel_in_chatflow(mock_close, mock_remove):
|
||||||
graph_config = {
|
graph_config = {
|
||||||
"edges": [
|
"edges": [
|
||||||
{
|
{
|
||||||
@ -291,7 +546,7 @@ def test_run_branch(mock_close, mock_remove):
|
|||||||
items = []
|
items = []
|
||||||
generator = graph_engine.run()
|
generator = graph_engine.run()
|
||||||
for item in generator:
|
for item in generator:
|
||||||
print(type(item), item)
|
# print(type(item), item)
|
||||||
items.append(item)
|
items.append(item)
|
||||||
|
|
||||||
assert len(items) == 10
|
assert len(items) == 10
|
||||||
|
Loading…
Reference in New Issue
Block a user