Skip to content

Commit

Permalink
Fix/refactor invoke result handling in question classifier node (#12015)
Browse files Browse the repository at this point in the history
Signed-off-by: -LAN- <[email protected]>
  • Loading branch information
laipz8200 authored Dec 23, 2024
1 parent af2888d commit c3c8527
Showing 1 changed file with 18 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import json
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.llm_generator.output_parser.errors import OutputParserError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
Expand Down Expand Up @@ -96,27 +94,28 @@ def _run(self):
jinja2_variables=[],
)

# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)

result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break

category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)

for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break

category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json:
Expand All @@ -127,10 +126,6 @@ def _run(self):
if category_id_result in category_ids:
category_name = classes_map[category_id_result]
category_id = category_id_result

except OutputParserError:
logging.exception(f"Failed to parse result text: {result_text}")
try:
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
Expand All @@ -154,7 +149,7 @@ def _run(self):
},
llm_usage=usage,
)
except Exception as e:
except ValueError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
Expand Down

0 comments on commit c3c8527

Please sign in to comment.