Compare commits

...

2 Commits

Author SHA1 Message Date
Novice Lee
5f7771bc47 fix: iteration node use the main thread pool 2024-12-02 21:13:47 +08:00
Novice Lee
286741e139 fix: iteration node use the main thread pool 2024-12-02 21:13:39 +08:00
3 changed files with 10 additions and 5 deletions

View File

@@ -15,7 +15,7 @@ class ComfyUIProvider(BuiltinToolProviderController):
         try:
             ws.connect(ws_address)
-        except Exception as e:
+        except Exception:
             raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}")
         finally:
             ws.close()

View File

@@ -116,7 +116,7 @@ class IterationNode(BaseNode[IterationNodeData]):
         variable_pool.add([self.node_id, "item"], iterator_list_value[0])
         # init graph engine
-        from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
+        from core.workflow.graph_engine.graph_engine import GraphEngine
         graph_engine = GraphEngine(
             tenant_id=self.tenant_id,
@@ -162,7 +162,8 @@ class IterationNode(BaseNode[IterationNodeData]):
         if self.node_data.is_parallel:
             futures: list[Future] = []
             q = Queue()
-            thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
+            thread_pool = graph_engine.workflow_thread_pool_mapping[self.thread_pool_id]
+            thread_pool._max_workers = self.node_data.parallel_nums
             for index, item in enumerate(iterator_list_value):
                 future: Future = thread_pool.submit(
                     self._run_single_iter_parallel,
@@ -235,7 +236,10 @@ class IterationNode(BaseNode[IterationNodeData]):
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
                     outputs={"output": jsonable_encoder(outputs)},
-                    metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
+                    metadata={
+                        NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
+                        "total_tokens": graph_engine.graph_runtime_state.total_tokens,
+                    },
                 )
             )
         except IterationNodeError as e:
@@ -258,6 +262,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     error=str(e),
+                    metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
                 )
             )
         finally:

View File

@@ -10,10 +10,10 @@ from collections.abc import Generator, Mapping
 from datetime import datetime
 from hashlib import sha256
 from typing import Any, Optional, Union
+from zoneinfo import available_timezones

 from flask import Response, stream_with_context
 from flask_restful import fields
-from zoneinfo import available_timezones

 from configs import dify_config
 from core.app.features.rate_limiting.rate_limit import RateLimitGenerator