Skip to content

Commit

Permalink
Merge branch 'feat/continue-on-error' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
Novice Lee authored and Novice Lee committed Dec 8, 2024
2 parents b8f3097 + 97fdea0 commit 4dfe73a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 16 deletions.
2 changes: 1 addition & 1 deletion api/core/app/apps/workflow/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def _process_stream_response(
"error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
"conversation_id": None,
"trace_manager": trace_manager,
"exceptions_count": event.exceptions_count,
"exceptions_count": event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
}
workflow_run = self._handle_workflow_run_failed(**handle_args)

Expand Down
14 changes: 6 additions & 8 deletions api/core/workflow/graph_engine/entities/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,20 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non
edge_configs = graph_config.get("edges")
if edge_configs is None:
edge_configs = []
# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")

edge_configs = cast(list, edge_configs)
node_configs = cast(list, node_configs)

# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
fail_branch_source_node_id = [
edge["source"] for edge in edge_configs if edge.get("sourceHandle") == "fail-branch"
node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
]
for edge_config in edge_configs:
source_node_id = edge_config.get("source")
Expand Down Expand Up @@ -111,13 +116,6 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)

# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")

node_configs = cast(list, node_configs)

# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}
Expand Down
17 changes: 11 additions & 6 deletions api/core/workflow/graph_engine/graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
handle_exceptions = []

try:
if self.init_params.workflow_type == WorkflowType.CHAT:
Expand All @@ -144,7 +145,6 @@ def run(self) -> Generator[GraphEngineEvent, None, None]:
)

# run graph
handle_exceptions = []
generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
)
Expand Down Expand Up @@ -184,22 +184,21 @@ def run(self) -> Generator[GraphEngineEvent, None, None]:
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
return
# count exceptions to determine partial success
exceptions_count = len(handle_exceptions)
if exceptions_count > 0:
if len(handle_exceptions) > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count, outputs=self.graph_runtime_state.outputs
exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs
)
else:
# trigger graph run success event
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
self._release_thread()
except GraphRunFailedError as e:
yield GraphRunFailedEvent(error=e.error, exceptions_count=exceptions_count)
yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions))
self._release_thread()
return
except Exception as e:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(error=str(e), exceptions_count=exceptions_count)
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
self._release_thread()
raise e

Expand Down Expand Up @@ -376,6 +375,7 @@ def _run(
edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
handle_exceptions=handle_exceptions,
)

for item in parallel_generator:
Expand All @@ -401,6 +401,7 @@ def _run(
edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
handle_exceptions=handle_exceptions,
)

for item in parallel_generator:
Expand All @@ -422,6 +423,7 @@ def _run_parallel_branches(
edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent | str, None, None]:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
Expand Down Expand Up @@ -465,6 +467,7 @@ def _run_parallel_branches(
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
"parent_parallel_start_node_id": parallel_start_node_id,
"handle_exceptions": handle_exceptions,
},
)

Expand Down Expand Up @@ -508,6 +511,7 @@ def _run_parallel_node(
parallel_start_node_id: str,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> None:
"""
Run parallel nodes
Expand All @@ -529,6 +533,7 @@ def _run_parallel_node(
in_parallel_id=parallel_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)

for item in generator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _run(self):
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
}
outputs = {"class_name": category_name}
outputs = {"class_name": category_name, "class_id": category_id}

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
Expand Down

0 comments on commit 4dfe73a

Please sign in to comment.