Skip to content

Commit

Permalink
fix: tool
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Mar 14, 2024
1 parent 13a7248 commit d85b5b9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
11 changes: 9 additions & 2 deletions api/core/workflow/nodes/tool/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pydantic import BaseModel, validator

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector

ToolParameterValue = Union[str, int, float, bool]

Expand All @@ -16,15 +15,23 @@ class ToolEntity(BaseModel):
tool_configurations: dict[str, ToolParameterValue]

class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(VariableSelector):
class ToolInput(BaseModel):
variable: str
variable_type: Literal['selector', 'static']
value_selector: Optional[list[str]]
value: Optional[str]

@validator('value')
def check_value(cls, value, values, **kwargs):
if values['variable_type'] == 'static' and value is None:
raise ValueError('value is required for static variable')
return value

@validator('value_selector')
def check_value_selector(cls, value_selector, values, **kwargs):
if values['variable_type'] == 'selector' and value_selector is None:
raise ValueError('value_selector is required for selector variable')
return value_selector

"""
Tool Node Schema
Expand Down
3 changes: 2 additions & 1 deletion api/core/workflow/nodes/tool/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters,
error=f'Failed to invoke tool: {str(e)}'
error=f'Failed to invoke tool: {str(e)}',
)

# convert tool messages
Expand All @@ -56,6 +56,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
'text': plain_text,
'files': files
},
inputs=parameters
)

def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
Expand Down

0 comments on commit d85b5b9

Please sign in to comment.