feat: support retry for workflow nodes

Novice Lee committed 2024-12-17 16:50:07 +08:00
parent a725b8bb6e
commit b99f1a09f4
21 changed files with 646 additions and 142 deletions
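In short: LLM, Code, Tool, and HTTP Request nodes gain a per-node retry_config. A minimal sketch of a node config that enables it — field names taken from the RetryConfig model added below, with the surrounding dict shape following this commit's test helper:

node = {
    "id": "node",
    "data": {
        "type": "code",
        "title": "code",
        # RetryConfig fields (see the base node entities hunk below)
        "retry_config": {
            "retry_enabled": True,   # retry is opt-in; off by default
            "max_retries": 2,        # extra attempts after the first failure
            "retry_interval": 1000,  # wait between attempts, in milliseconds
        },
    },
}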

View File

@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -328,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):

View File

@ -18,6 +18,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -287,6 +288,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):

View File

@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@ -420,6 +422,35 @@ class WorkflowBasedAppRunner(AppRunner):
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
elif isinstance(event, NodeRunRetryEvent):
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.error,
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
)
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""

View File

@ -43,6 +43,7 @@ class QueueEvent(StrEnum):
ERROR = "error"
PING = "ping"
STOP = "stop"
RETRY = "retry"
class AppQueueEvent(BaseModel):
@ -313,6 +314,36 @@ class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None
class QueueNodeRetryEvent(AppQueueEvent):
"""QueueNodeRetryEvent entity"""
event: QueueEvent = QueueEvent.RETRY
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
retry_index: int # retry index
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity

View File

@ -52,6 +52,7 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
}
class NodeRetryStreamResponse(StreamResponse):
"""
NodeRetryStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
retry_index: int = 0
event: StreamEvent = StreamEvent.NODE_RETRY
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
return {
"event": self.event.value,
"task_id": self.task_id,
"workflow_run_id": self.workflow_run_id,
"data": {
"id": self.data.id,
"node_id": self.data.node_id,
"node_type": self.data.node_type,
"title": self.data.title,
"index": self.data.index,
"predecessor_node_id": self.data.predecessor_node_id,
"inputs": None,
"process_data": None,
"outputs": None,
"status": self.data.status,
"error": None,
"elapsed_time": self.data.elapsed_time,
"execution_metadata": None,
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"retry_index": self.data.retry_index,
},
}
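For reference, a serialized node_retry stream event would look roughly like this — a sketch with illustrative values, field set taken from the Data model above:

{
    "event": "node_retry",
    "task_id": "<task id>",
    "workflow_run_id": "<run id>",
    "data": {
        "id": "<execution id>",
        "node_id": "node",
        "node_type": "code",
        "title": "code",
        "index": 2,
        "status": "retry",
        "error": "division by zero",
        "elapsed_time": 0.01,
        "created_at": 1734421807,
        "finished_at": 1734421807,
        "retry_index": 1,
    },
}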
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity

View File

@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -26,6 +27,7 @@ from core.app.entities.task_entities import (
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
@ -423,6 +425,63 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_retried(
self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
"""
Handle a retried workflow node execution
:param event: queue node retry event
:return:
"""
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = created_at
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.error = event.error
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
with Session(db.engine, expire_on_commit=False) as session:
failed_execution = (
session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
.first()
)
# guard: .first() may return None if no running execution is found
node_run_index = failed_execution.index if failed_execution else 1
workflow_node_execution.index = node_run_index
session.add(workflow_node_execution)
session.commit()
session.refresh(workflow_node_execution)
return workflow_node_execution
#################################################
# to stream responses #
#################################################
@ -587,6 +646,51 @@ class WorkflowCycleManage:
),
)
def _workflow_node_retry_to_stream_response(
self,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeRetryStreamResponse]:
"""
Workflow node retry to stream response.
:param event: queue node retry event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
),
)
def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:

View File

@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)
retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:

View File

@ -97,6 +97,12 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunRetryEvent(BaseNodeEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
start_at: datetime = Field(..., description="retry start time")
###########################################
# Parallel Branch Events
###########################################

View File

@ -5,6 +5,7 @@ import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, cast
from flask import Flask, current_app
@ -24,6 +25,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@ -575,7 +577,7 @@ class GraphEngine:
def _run_node(
self,
node_instance: BaseNode,
node_instance: BaseNode[BaseNodeData],
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
@ -601,36 +603,111 @@ class GraphEngine:
)
db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval
retries = 0
should_continue_retry = True
while should_continue_retry and retries <= max_retries:
try:
# run node
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
try:
# run node
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_retry and retries < max_retries:
retries += 1
yield NodeRunRetryEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
error=run_result.error or "Unknown error",
retry_index=retries,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
start_at=retry_start_at,
)
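# retry_interval is configured in milliseconds; divide by 1000 for time.sleep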
time.sleep(retry_interval / 1000)
continue
route_node_state.set_finished(run_result=run_result)
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
@ -639,21 +716,23 @@ class GraphEngine:
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
parallel_start_node_id
)
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
@ -664,108 +743,59 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
yield NodeRunSucceededEvent(
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""

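Stripped of event plumbing, the retry control flow added to _run_node reduces to the following pattern — a simplified sketch, not the engine code itself; run_once is a stand-in for node_instance.run():

import time

def run_with_retry(run_once, max_retries: int, retry_interval_ms: int) -> bool:
    # mirrors the engine loop: the first run plus up to max_retries re-attempts
    retries = 0
    while retries <= max_retries:
        ok, error = run_once()  # stand-in for executing the node once
        if ok:
            return True
        if retries < max_retries:
            retries += 1  # the engine emits a NodeRunRetryEvent at this point
            time.sleep(retry_interval_ms / 1000)  # interval is in milliseconds
            continue
        return False  # out of retries; the failure propagates
    return False

With max_retries=2 the node runs at most three times, which is why the tests below expect exactly two NodeRunRetryEvents before a terminal success or failure.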
View File

@ -106,12 +106,21 @@ class DefaultValue(BaseModel):
return self
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self):

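A quick illustration of how these defaults behave — a sketch, assuming RetryConfig lives alongside BaseNodeData in the base node entities module:

from core.workflow.nodes.base.entities import RetryConfig  # assumed path

cfg = RetryConfig()                 # retry is opt-in: disabled by default
assert cfg.retry_enabled is False and cfg.max_retries == 0

cfg = RetryConfig(retry_enabled=True, max_retries=3, retry_interval=500)
total_attempts = 1 + cfg.max_retries       # the first run plus the retries
sleep_seconds = cfg.retry_interval / 1000  # stored in ms, slept in seconds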
View File

@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
@ -143,3 +143,12 @@ class BaseNode(Generic[GenericNodeData]):
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
@property
def should_retry(self) -> bool:
"""judge if should retry
Returns:
bool: if should retry
"""
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE

View File

@ -35,3 +35,5 @@ class FailBranchSourceHandle(StrEnum):
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
# TODO Remove code node
RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]

View File

@ -1,4 +1,10 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from .event import (
ModelInvokeCompletedEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
RunStreamChunkEvent,
)
from .types import NodeEvent
__all__ = [
@ -6,5 +12,6 @@ __all__ = [
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"RunStreamChunkEvent",
]

View File

@ -1,3 +1,5 @@
from datetime import datetime
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
@ -26,3 +28,11 @@ class ModelInvokeCompletedEvent(BaseModel):
text: str
usage: LLMUsage
finish_reason: str | None = None
class RunRetryEvent(BaseModel):
"""Node Run Retry event"""
error: str = Field(..., description="error")
retry_index: int = Field(..., description="Retry attempt number")
start_at: datetime = Field(..., description="Retry start time")

View File

@ -45,6 +45,7 @@ class Executor:
headers: dict[str, str]
auth: HttpRequestNodeAuthorization
timeout: HttpRequestNodeTimeout
max_retries: int
boundary: str
@ -54,6 +55,7 @@ class Executor:
node_data: HttpRequestNodeData,
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@ -73,6 +75,7 @@ class Executor:
self.files = None
self.data = None
self.json = None
self.max_retries = max_retries
# init template
self.variable_pool = variable_pool
@ -207,6 +210,7 @@ class Executor:
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True,
"max_retries": self.max_retries,
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:

View File

@ -56,10 +56,15 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
def _run(self) -> NodeRunResult:
process_data = {}
try:
executor_config = {
"node_data": self.node_data,
"timeout": self._get_request_timeout(self.node_data),
"variable_pool": self.graph_runtime_state.variable_pool,
}
if self.should_retry:
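# node-level retry is handled by the graph engine; zero out the
# HTTP executor's own retry budget so failures are not retried twice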
executor_config["max_retries"] = 0
http_executor = Executor(
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
**executor_config,
)
process_data["request"] = http_executor.to_log()

View File

@ -29,6 +29,7 @@ workflow_run_for_list_fields = {
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
"retry_index": fields.Integer,
}
advanced_chat_workflow_run_for_list_fields = {
@ -45,6 +46,7 @@ advanced_chat_workflow_run_for_list_fields = {
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
"retry_index": fields.Integer,
}
advanced_chat_workflow_run_pagination_fields = {

View File

@ -0,0 +1,33 @@
"""add retry_index field to node-execution model
Revision ID: 348cb0a93d53
Revises: cf8f4fc45278
Create Date: 2024-12-16 01:23:13.093432
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '348cb0a93d53'
down_revision = 'cf8f4fc45278'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_column('retry_index')
# ### end Alembic commands ###
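Assuming the repository's usual Flask-Migrate flow, the new column lands with:

flask db upgrade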

View File

@ -527,6 +527,7 @@ class WorkflowNodeExecutionStatus(Enum):
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
RETRY = "retry"
@classmethod
def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
@ -637,6 +638,7 @@ class WorkflowNodeExecution(db.Model):
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
finished_at = db.Column(db.DateTime)
retry_index = db.Column(db.Integer, server_default=db.text("0"))
@property
def created_by_account(self):

View File

@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunStreamChunkEvent,
)
@ -14,7 +13,9 @@ from models.workflow import WorkflowType
class ContinueOnErrorTestHelper:
@staticmethod
def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
def get_code_node(
code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict | None = None
):
"""Helper method to create a code node configuration"""
node = {
"id": "node",
@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper:
"code_language": "python3",
"code": "\n".join([line[4:] for line in code.split("\n")]),
"type": "code",
**(retry_config or {}),
},
}
if default_value:

View File

@ -0,0 +1,114 @@
from core.workflow.graph_engine.entities.event import (
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunRetryEvent,
)
from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]
def test_retry_default_value_partial_success():
"""retry default value node with partial success status"""
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_code_node(
error_code,
"default-value",
[{"key": "result", "type": "number", "value": 132123}],
{"retry_config": {"max_retries": 2, "retry_interval": 1, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert events[-1].outputs == {"answer": "132123"}
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
assert len(events) == 11
def test_retry_success():
"""retry node with success status"""
success_code = """
count = 0
def main():
global count
count += 1
if count == 1:
raise Exception("First attempt fails")
if count == 2:
return {"result": "success"}
"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_code_node(
success_code,
None,
None,
{"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
assert len(events) == 9
def test_retry_failed():
"""retry failed with success status"""
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_code_node(
error_code,
None,
None,
{"retry_config": {"max_retries": 2, "retry_interval": 1, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
assert len(events) == 8