Skip to content

Commit

Permalink
Merge branch 'feat/continue-on-error' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
Novice Lee authored and Novice Lee committed Dec 4, 2024
2 parents 0dcb03f + 560b7fe commit 1ed9dcd
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 11 deletions.
1 change: 1 addition & 0 deletions api/core/workflow/entities/node_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum):
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field


class NodeRunResult(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion api/core/workflow/graph_engine/entities/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):

class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: int = Field(..., description="exception count")
exceptions_count: Optional[int] = Field(description="exception count", default=0)


class GraphRunPartialSucceededEvent(BaseGraphEvent):
Expand Down
27 changes: 23 additions & 4 deletions api/core/workflow/graph_engine/graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def run(self) -> Generator[GraphEngineEvent, None, None]:
return
except Exception as e:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(error=str(e), handle_exceptions=handle_exceptions)
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
self._release_thread()
raise e

Expand Down Expand Up @@ -314,6 +314,11 @@ def _run(
break

if len(edge_mappings) == 1:
if (
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
):
break
edge = edge_mappings[0]

if edge.run_condition:
Expand All @@ -332,7 +337,6 @@ def _run(
next_node_id = edge.target_node_id
else:
final_node_id = None

if any(edge.run_condition for edge in edge_mappings):
# if nodes has run conditions, get node id which branch to take based on the run condition results
condition_edge_mappings = {}
Expand Down Expand Up @@ -611,6 +615,7 @@ def _run_node(
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
Expand Down Expand Up @@ -808,15 +813,29 @@ def _handle_continue_on_error(
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
},
}

if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
return NodeRunResult(
**node_error_args,
outputs=node_instance.node_data.default_value_dict,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)

return NodeRunResult(**node_error_args, outputs=None, edge_source_handle=FailBranchSourceHandle.FAILED)
return NodeRunResult(
**node_error_args,
outputs={
"error_message": error_result.error,
"error_type": error_result.error_type,
},
edge_source_handle=FailBranchSourceHandle.FAILED,
)


class GraphRunFailedError(Exception):
Expand Down
4 changes: 3 additions & 1 deletion api/core/workflow/nodes/base/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class DefaultValueType(StrEnum):
ARRAY_NUMBER = "Array[Number]"
ARRAY_STRING = "Array[String]"
ARRAY_OBJECT = "Array[Object]"
ARRAY_FILES = "Array[Files]"
ARRAY_FILES = "Array[File]"


NumberType = Union[int, float]
Expand Down Expand Up @@ -101,6 +101,8 @@ def validate_value_type(self) -> Any:
# handle files type
pass

return self


class BaseNodeData(ABC, BaseModel):
title: str
Expand Down
38 changes: 34 additions & 4 deletions api/services/workflow_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.nodes import NodeType
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
Expand Down Expand Up @@ -237,8 +238,34 @@ def run_draft_workflow_node(

if not node_run_result:
raise ValueError("Node run failed with no run result")

run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)

node_run_result = NodeRunResult(
**node_error_args,
outputs={
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node_instance = e.node_instance
Expand All @@ -260,7 +287,6 @@ def run_draft_workflow_node(
workflow_node_execution.created_by = account.id
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)

if run_succeeded and node_run_result:
# create workflow node execution
inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
Expand All @@ -277,7 +303,11 @@ def run_draft_workflow_node(
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
workflow_node_execution.error = node_run_result.error
else:
# create workflow node execution
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def test_variable_pool_error_type_variable():
list(graph_engine.run())
error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"])
error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"])
assert error_message.value == "Request failed with status code 404"
assert error_message != None
assert error_type.value == "HTTPResponseCodeError"


Expand Down Expand Up @@ -554,3 +554,25 @@ def main() -> dict:
for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 2


def test_no_node_in_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES[:-1],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
ContinueOnErrorTestHelper.get_http_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0

0 comments on commit 1ed9dcd

Please sign in to comment.