Compare commits

...

2 Commits

Author SHA1 Message Date
Novice Lee
5f7771bc47 fix: iteration node use the main thread pool 2024-12-02 21:13:47 +08:00
Novice Lee
286741e139 fix: iteration node use the main thread pool 2024-12-02 21:13:39 +08:00
3 changed files with 10 additions and 5 deletions

View File

@ -15,7 +15,7 @@ class ComfyUIProvider(BuiltinToolProviderController):
try:
ws.connect(ws_address)
except Exception as e:
except Exception:
raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}")
finally:
ws.close()

View File

@ -116,7 +116,7 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
@ -162,7 +162,8 @@ class IterationNode(BaseNode[IterationNodeData]):
if self.node_data.is_parallel:
futures: list[Future] = []
q = Queue()
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
thread_pool = graph_engine.workflow_thread_pool_mapping[self.thread_pool_id]
thread_pool._max_workers = self.node_data.parallel_nums
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
@ -235,7 +236,10 @@ class IterationNode(BaseNode[IterationNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": jsonable_encoder(outputs)},
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
metadata={
NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
},
)
)
except IterationNodeError as e:
@ -258,6 +262,7 @@ class IterationNode(BaseNode[IterationNodeData]):
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
)
)
finally:

View File

@ -10,10 +10,10 @@ from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import Any, Optional, Union
from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask_restful import fields
from zoneinfo import available_timezones
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator