
Commit

Merge branch 'feat/continue-on-error' into deploy/dev
Novice Lee authored and Novice Lee committed Dec 9, 2024
2 parents 26b924e + fe4e3ae commit 4166f49
Showing 4 changed files with 75 additions and 130 deletions.
144 changes: 73 additions & 71 deletions api/core/workflow/nodes/base/entities.py
@@ -20,86 +20,88 @@ class DefaultValueType(StrEnum):
 
 
 NumberType = Union[int, float]
-ObjectType = dict[str, Any]
 
 
 class DefaultValue(BaseModel):
-    value: Union[
-        str,
-        NumberType,
-        ObjectType,
-        list[NumberType],
-        list[str],
-        list[ObjectType],
-    ]
+    value: Any
     type: DefaultValueType
     key: str
 
+    @staticmethod
+    def _parse_json(value: str) -> Any:
+        """Unified JSON parsing handler"""
+        try:
+            return json.loads(value)
+        except json.JSONDecodeError:
+            raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
+
+    @staticmethod
+    def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
+        """Unified array type validation"""
+        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
+
+    @staticmethod
+    def _convert_number(value: str) -> float:
+        """Unified number conversion handler"""
+        try:
+            return float(value)
+        except ValueError:
+            raise DefaultValueTypeError(f"Cannot convert to number: {value}")
+
     @model_validator(mode="after")
-    def validate_value_type(self) -> Any:
-        value_type = self.type
-        value = self.value
-        if value_type is None:
+    def validate_value_type(self) -> "DefaultValue":
+        if self.type is None:
             raise DefaultValueTypeError("type field is required")
 
-        # validate string type
-        if value_type == DefaultValueType.STRING:
-            if not isinstance(value, str):
-                raise DefaultValueTypeError(f"Value must be string type for {value}")
-
-        # validate number type
-        elif value_type == DefaultValueType.NUMBER:
-            if not isinstance(value, NumberType):
-                raise DefaultValueTypeError(f"Value must be number type for {value}")
-
-        # validate object type
-        elif value_type == DefaultValueType.OBJECT:
-            if isinstance(value, str):
-                try:
-                    value = json.loads(value)
-                except json.JSONDecodeError:
-                    raise DefaultValueTypeError(f"Value must be object type for {value}")
-            if not isinstance(value, ObjectType):
-                raise DefaultValueTypeError(f"Value must be object type for {value}")
-
-        # validate array[number] type
-        elif value_type == DefaultValueType.ARRAY_NUMBER:
-            if isinstance(value, str):
-                try:
-                    value = json.loads(value)
-                except json.JSONDecodeError:
-                    raise DefaultValueTypeError(f"Value must be object type for {value}")
-            if not isinstance(value, list):
-                raise DefaultValueTypeError(f"Value must be array type for {value}")
-            if not all(isinstance(x, NumberType) for x in value):
-                raise DefaultValueTypeError(f"All elements must be numbers for {value}")
-
-        # validate array[string] type
-        elif value_type == DefaultValueType.ARRAY_STRING:
-            if isinstance(value, str):
-                try:
-                    value = json.loads(value)
-                except json.JSONDecodeError:
-                    raise DefaultValueTypeError(f"Value must be object type for {value}")
-            if not isinstance(value, list):
-                raise DefaultValueTypeError(f"Value must be array type for {value}")
-            if not all(isinstance(x, str) for x in value):
-                raise DefaultValueTypeError(f"All elements must be strings for {value}")
-
-        # validate array[object] type
-        elif value_type == DefaultValueType.ARRAY_OBJECT:
-            if isinstance(value, str):
-                try:
-                    value = json.loads(value)
-                except json.JSONDecodeError:
-                    raise DefaultValueTypeError(f"Value must be object type for {value}")
-            if not isinstance(value, list):
-                raise DefaultValueTypeError(f"Value must be array type for {value}")
-            if not all(isinstance(x, ObjectType) for x in value):
-                raise DefaultValueTypeError(f"All elements must be objects for {value}")
-        elif value_type == DefaultValueType.ARRAY_FILES:
-            # handle files type
-            pass
+        # Type validation configuration
+        type_validators = {
+            DefaultValueType.STRING: {
+                "type": str,
+                "converter": lambda x: x,
+            },
+            DefaultValueType.NUMBER: {
+                "type": NumberType,
+                "converter": self._convert_number,
+            },
+            DefaultValueType.OBJECT: {
+                "type": dict,
+                "converter": self._parse_json,
+            },
+            DefaultValueType.ARRAY_NUMBER: {
+                "type": list,
+                "element_type": NumberType,
+                "converter": self._parse_json,
+            },
+            DefaultValueType.ARRAY_STRING: {
+                "type": list,
+                "element_type": str,
+                "converter": self._parse_json,
+            },
+            DefaultValueType.ARRAY_OBJECT: {
+                "type": list,
+                "element_type": dict,
+                "converter": self._parse_json,
+            },
+        }
 
+        validator = type_validators.get(self.type)
+        if not validator:
+            if self.type == DefaultValueType.ARRAY_FILES:
+                # Handle files type
+                return self
+            raise DefaultValueTypeError(f"Unsupported type: {self.type}")
+
+        # Handle string input cases
+        if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
+            self.value = validator["converter"](self.value)
+
+        # Validate base type
+        if not isinstance(self.value, validator["type"]):
+            raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
+
+        # Validate array element types
+        if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
+            raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
+
         return self

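The refactor above replaces the per-type if/elif chain with a table of validators keyed by DefaultValueType: each entry names the expected Python type, an optional element type for arrays, and a converter applied when the incoming value is a string. A minimal usage sketch of the new behavior (not part of the commit; the keys and literal values are made up, and the import path is inferred from the file location):

    from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType

    # A JSON string supplied for a non-string type is parsed by _parse_json before validation.
    dv = DefaultValue(key="result", type=DefaultValueType.ARRAY_NUMBER, value="[1, 2, 3]")
    assert dv.value == [1, 2, 3]

    # A numeric string is coerced via _convert_number (note: always to float).
    dv = DefaultValue(key="retries", type=DefaultValueType.NUMBER, value="3")
    assert dv.value == 3.0

    # Mismatched element types raise DefaultValueTypeError, e.g.:
    # DefaultValue(key="result", type=DefaultValueType.ARRAY_NUMBER, value='["a", "b"]')
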
2 changes: 1 addition & 1 deletion api/core/workflow/nodes/base/exc.py
@@ -1,4 +1,4 @@
-class BaseNodeError(ValueError):
+class BaseNodeError(Exception):
     """Base class for node errors."""
 
     pass
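One practical effect of rebasing BaseNodeError on Exception, sketched below as an inference rather than a stated goal of the commit: Pydantic v2 (which entities.py uses via model_validator) converts ValueError raised inside validators into ValidationError but lets other exception types propagate, so subclasses such as DefaultValueTypeError now surface unchanged from model construction. The Demo model and messages here are illustrative only:

    from pydantic import BaseModel, ValidationError, model_validator


    class BaseNodeError(Exception):  # previously: class BaseNodeError(ValueError)
        """Base class for node errors."""


    class DefaultValueTypeError(BaseNodeError):
        pass


    class Demo(BaseModel):
        value: int

        @model_validator(mode="after")
        def check(self) -> "Demo":
            if self.value < 0:
                raise DefaultValueTypeError("value must be non-negative")
            return self


    try:
        Demo(value=-1)
    except DefaultValueTypeError:
        print("caught the node error directly")  # reached with Exception as the base
    except ValidationError:
        print("wrapped by pydantic")  # would be reached if the base were ValueError
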
1 change: 1 addition & 0 deletions api/tests/integration_tests/workflow/nodes/test_code.py
@@ -16,6 +16,7 @@
 from core.workflow.nodes.code.entities import CodeNodeData
 from models.enums import UserFrom
 from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
+from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
 
 CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))

58 changes: 0 additions & 58 deletions api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
@@ -480,64 +480,6 @@ def test_variable_pool_error_type_variable():
     assert error_type.value == "HTTPResponseCodeError"
 
 
-def test_continue_on_error_link_fail_branch():
-    success_code = """
-    def main() -> dict:
-        return {
-            "result": 1 / 1,
-        }
-    """
-    graph_config = {
-        "edges": [
-            *FAIL_BRANCH_EDGES,
-            {
-                "id": "start-source-code-target",
-                "source": "start",
-                "target": "code",
-                "sourceHandle": "source",
-            },
-            {
-                "id": "code-source-error-target",
-                "source": "code",
-                "target": "error",
-                "sourceHandle": "source",
-            },
-        ],
-        "nodes": [
-            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
-            {
-                "data": {"title": "success", "type": "answer", "answer": "http execute successful"},
-                "id": "success",
-            },
-            {
-                "data": {"title": "error", "type": "answer", "answer": "http execute failed"},
-                "id": "error",
-            },
-            ContinueOnErrorTestHelper.get_http_node(authorization_success=True),
-            {
-                "id": "code",
-                "data": {
-                    "outputs": {"result": {"type": "number"}},
-                    "title": "code",
-                    "variables": [],
-                    "code_language": "python3",
-                    "code": "\n".join([line[4:] for line in success_code.split("\n")]),
-                    "type": "code",
-                },
-            },
-        ],
-    }
-
-    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
-    events = list(graph_engine.run())
-    assert any(
-        isinstance(e, GraphRunSucceededEvent)
-        and e.outputs == {"answer": "http execute successful\nhttp execute failed"}
-        for e in events
-    )
-    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 2
-
-
 def test_no_node_in_fail_branch_continue_on_error():
     """Test HTTP node with fail-branch error strategy"""
     graph_config = {
