Skip to content

Commit

Permalink
fixed startable flows
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Aug 13, 2023
1 parent 50ca17c commit eb7a2cf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
12 changes: 7 additions & 5 deletions rasa/cdu/command_generator/llm_command_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,7 @@ def predict_commands(
if flows is None or tracker is None:
# cannot do anything if there are no flows or no tracker
return []
flows_without_patterns = FlowsList(
[f for f in flows.underlying_flows if not f.is_handling_pattern()]
)
flow_prompt = self.render_template(message, tracker, flows_without_patterns)
flow_prompt = self.render_template(message, tracker, flows)
structlogger.info(
"llm_command_generator.predict_commands.prompt_rendered", prompt=flow_prompt
)
Expand Down Expand Up @@ -265,6 +262,9 @@ def allowed_values_for_slot(self, slot: Slot) -> Optional[str]:
def render_template(
self, message: Message, tracker: DialogueStateTracker, flows: FlowsList
) -> str:
flows_without_patterns = FlowsList(
[f for f in flows.underlying_flows if not f.is_handling_pattern()]
)
flow_stack = FlowStack.from_tracker(tracker)
top_flow = flow_stack.top_flow(flows) if flow_stack is not None else None
current_step = (
Expand Down Expand Up @@ -297,7 +297,9 @@ def render_template(
current_conversation += f"\nUSER: {latest_user_message}"

inputs = {
"available_flows": self.create_template_inputs(flows, tracker),
"available_flows": self.create_template_inputs(
flows_without_patterns, tracker
),
"current_conversation": current_conversation,
"flow_slots": flow_slots,
"current_flow": top_flow.id if top_flow is not None else None,
Expand Down
13 changes: 13 additions & 0 deletions rasa/cdu/command_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,13 +274,26 @@ def clean_up_commands(

clean_commands: List[Command] = []

startable_flow_ids = [
f.id for f in all_flows.underlying_flows if not f.is_handling_pattern()
]

for command in commands:
if isinstance(command, StartFlowCommand) and command.flow in flows_on_the_stack:
structlogger.debug(
"command_executor.skip_command.already_started_flow", command=command
)
continue

if (
isinstance(command, StartFlowCommand)
and command.flow not in startable_flow_ids
):
structlogger.debug(
"command_executor.skip_command.start_invalid_flow_id", command=command
)
continue

if (
isinstance(command, SetSlotCommand)
and tracker.get_slot(command.name) == command.value
Expand Down

0 comments on commit eb7a2cf

Please sign in to comment.