From d38dc68df4aefc77357deaa0635465a771973854 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?=
Date: Thu, 5 Dec 2024 17:33:34 +0100
Subject: [PATCH 1/3] fix: Close call properly after hanging up
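
Previously, `in_progress` was a computed property derived from the message
history (the call counted as in progress when the most recent CALL/HANGUP
action was CALL), so a call whose goodbye prompt failed to play was never
marked as ended. The state now carries a plain `in_progress` flag:
`on_call_connected` sets it, `_handle_hangup` clears it, and the
phone-silence watchdog in `_in_audio` re-arms itself only while it is set.
`handle_play_text` also reports whether the text could actually be played,
so `_handle_goodbye` can hang up immediately instead of waiting for a
play-completed event that will never arrive.

A minimal sketch of the watchdog pattern this enables (hypothetical names,
not the application code):

    import asyncio

    async def watchdog(state: dict, timeout_sec: float) -> None:
        # Re-arm the silence timeout for as long as the call is live
        while state["in_progress"]:
            await asyncio.sleep(timeout_sec)
            print("silence timeout triggered")

    async def main() -> None:
        state = {"in_progress": True}
        task = asyncio.create_task(watchdog(state, 0.1))
        await asyncio.sleep(0.25)     # timeouts fire while the flag is set
        state["in_progress"] = False  # hangup clears the flag...
        await task                    # ...and the watchdog exits on its own

    asyncio.run(main())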

---
 app/helpers/call_events.py | 16 +++++++++-
 app/helpers/call_llm.py    | 64 +++++++++++++-------------------------
 app/helpers/call_utils.py  | 15 +++++++--
 app/main.py                |  1 +
 app/models/call.py         | 25 ++-------------
 5 files changed, 52 insertions(+), 69 deletions(-)

diff --git a/app/helpers/call_events.py b/app/helpers/call_events.py
index c7738ca5..211a484a 100644
--- a/app/helpers/call_events.py
+++ b/app/helpers/call_events.py
@@ -106,6 +106,7 @@ async def on_call_connected(
 ) -> None:
     logger.info("Call connected, asking for language")
     call.recognition_retry = 0  # Reset recognition retry counter
+    call.in_progress = True
     call.messages.append(
         MessageModel(
             action=MessageActionEnum.CALL,
@@ -171,6 +172,7 @@ async def on_recognize_error(
     call: CallStateModel,
     client: CallAutomationClient,
     contexts: set[CallContextEnum] | None,
+    post_callback: Callable[[CallStateModel], Awaitable[None]],
 ) -> None:
     # Retry IVR recognition
     if contexts and CallContextEnum.IVR_LANG_SELECT in contexts:
@@ -190,6 +192,7 @@ async def on_recognize_error(
         await _handle_goodbye(
             call=call,
             client=client,
+            post_callback=post_callback,
         )
         return
 
@@ -199,6 +202,7 @@ async def on_recognize_error(
         await _handle_goodbye(
             call=call,
             client=client,
+            post_callback=post_callback,
         )
         return
 
@@ -217,8 +221,9 @@ async def on_recognize_error(
 async def _handle_goodbye(
     call: CallStateModel,
     client: CallAutomationClient,
+    post_callback: Callable[[CallStateModel], Awaitable[None]],
 ) -> None:
-    await handle_play_text(
+    res = await handle_play_text(
         call=call,
         client=client,
         context=CallContextEnum.GOODBYE,
@@ -226,6 +231,14 @@ async def _handle_goodbye(
         text=await CONFIG.prompts.tts.goodbye(call),
     )
 
+    if not res:
+        logger.info("Failed to play goodbye prompt, ending call now")
+        await _handle_hangup(
+            call=call,
+            client=client,
+            post_callback=post_callback,
+        )
+
 
 @tracer.start_as_current_span("on_play_completed")
 async def on_play_completed(
@@ -393,6 +406,7 @@ async def _handle_hangup(
     post_callback: Callable[[CallStateModel], Awaitable[None]],
 ) -> None:
     await handle_hangup(client=client, call=call)
+    call.in_progress = False
     call.messages.append(
         MessageModel(
             action=MessageActionEnum.HANGUP,
diff --git a/app/helpers/call_llm.py b/app/helpers/call_llm.py
index 5ef34181..005cba66 100644
--- a/app/helpers/call_llm.py
+++ b/app/helpers/call_llm.py
@@ -130,6 +130,7 @@ async def _timeout_callback() -> None:
                 call=call,
                 client=automation_client,
                 contexts=None,
+                post_callback=post_callback,
             )
         )
 
@@ -171,7 +172,10 @@ async def _response_callback() -> None:
             )
         )
 
-        # Add recognition to the call history
+        # Clear the recognition buffer
+        recognizer_buffer.clear()
+
+        # Store recognition task
         nonlocal last_response
         last_response = await scheduler.spawn(
             _out_answer(
@@ -183,11 +187,12 @@ async def _response_callback() -> None:
             )
         )
 
-        # Clear the recognition buffer
-        recognizer_buffer.clear()
+        # Wait for the response to be processed
+        await last_response.wait()
 
     await _in_audio(
         bits_per_sample=audio_bits_per_sample,
+        call=call,
         channels=audio_channels,
         clear_audio_callback=_clear_audio_callback,
         in_stream=audio_stream,
@@ -552,6 +557,7 @@ async def _content_callback(
 # TODO: Refacto and simplify
 async def _in_audio(  # noqa: PLR0913
     bits_per_sample: int,
+    call: CallStateModel,
     channels: int,
     clear_audio_callback: Callable[[], Awaitable[None]],
     in_stream: asyncio.Queue[bytes],
@@ -561,59 +567,33 @@ async def _in_audio(  # noqa: PLR0913
     timeout_callback: Callable[[], Awaitable[None]],
 ) -> None:
     clear_tts_task: asyncio.Task | None = None
-    no_voice_task: asyncio.Task | None = None
+    flush_task: asyncio.Task | None = None
     vad = Vad(
         # Aggressiveness mode (0, 1, 2, or 3)
         # Sets the VAD operating mode. A more aggressive (higher mode) VAD is more restrictive in reporting speech. Put in other words the probability of being speech when the VAD returns 1 is increased with increasing mode. As a consequence also the missed detection rate goes up.
         mode=3,
     )
 
-    async def _timeout_callback() -> None:
-        """
-        Alert the user that the call is about to be cut off.
-        """
-        timeout_sec = await phone_silence_timeout_sec()
-
-        while True:
-            logger.debug(
-                "Wait foor %i sec before cutting off the call",
-                timeout_sec,
-            )
-
-            # Wait for the timeout
-            await asyncio.sleep(timeout_sec)
-
-            logger.info("Phone silence timeout triggered")
-
-            # Execute the callback
-            await timeout_callback()
-
     async def _flush_callback() -> None:
         """
         Flush the audio buffer if no audio is detected for a while.
         """
+        # Wait for the timeout
         nonlocal clear_tts_task
-
         timeout_ms = await vad_silence_timeout_ms()
-
-        # Wait for the timeout
         await asyncio.sleep(timeout_ms / 1000)
-
-        # Cancel the TTS clear task if any
         if clear_tts_task:
             clear_tts_task.cancel()
             clear_tts_task = None
-        logger.debug("Flushing audio buffer after %i ms", timeout_ms)
-
-        # Commit the buffer
         await response_callback()
 
-    async def _no_voice_callback() -> None:
-        await asyncio.gather(
-            _flush_callback(),
-            _timeout_callback(),
-        )
+        # Wait for the timeout, if any
+        timeout_sec = await phone_silence_timeout_sec()
+        while call.in_progress:
+            await asyncio.sleep(timeout_sec)
+            logger.info("Silence triggered after %i sec", timeout_sec)
+            await timeout_callback()
 
     async def _clear_tts_callback() -> None:
         """
@@ -663,17 +643,17 @@ async def _clear_tts_callback() -> None:
         ):
             in_empty = True
             # Start timeout if not already started
-            if not no_voice_task:
-                no_voice_task = asyncio.create_task(_no_voice_callback())
+            if not flush_task:
+                flush_task = asyncio.create_task(_flush_callback())
 
         if in_empty:
             # Continue to the next audio packet
             continue
 
         # Voice detected, cancel the timeout if any
-        if no_voice_task:
-            no_voice_task.cancel()
-            no_voice_task = None
+        if flush_task:
+            flush_task.cancel()
+            flush_task = None
 
         # Start the TTS clear task
         if not clear_tts_task:
diff --git a/app/helpers/call_utils.py b/app/helpers/call_utils.py
index 3b017279..5f605dc4 100644
--- a/app/helpers/call_utils.py
+++ b/app/helpers/call_utils.py
@@ -91,11 +91,13 @@ async def _handle_play_text(
     context: ContextEnum | None,
     style: MessageStyleEnum,
     text: str,
-) -> None:
+) -> bool:
     """
     Play a text to a call participant.
 
     If `context` is provided, it will be used to track the operation.
+
+    Returns `True` if the text was played, `False` otherwise.
""" logger.info("Playing TTS: %s", text) try: @@ -109,6 +111,7 @@ async def _handle_play_text( text=text, ), ) + return True except ResourceNotFoundError: logger.debug("Call hung up before playing") except HttpResponseError as e: @@ -116,6 +119,7 @@ async def _handle_play_text( logger.debug("Call hung up before playing") else: raise e + return False async def handle_media( @@ -152,11 +156,13 @@ async def handle_play_text( # noqa: PLR0913 context: ContextEnum | None = None, store: bool = True, style: MessageStyleEnum = MessageStyleEnum.NONE, -) -> None: +) -> bool: """ Play a text to a call participant. If `store` is `True`, the text will be stored in the call messages. + + Returns `True` if the text was played, `False` otherwise. """ # Split text in chunks chunks = await _chunk_before_tts( @@ -168,13 +174,16 @@ async def handle_play_text( # noqa: PLR0913 # Play each chunk for chunk in chunks: - await _handle_play_text( + res = await _handle_play_text( call=call, client=client, context=context, style=style, text=chunk, ) + if not res: + return False + return True async def handle_clear_queue( diff --git a/app/main.py b/app/main.py index 8b364732..ba55c272 100644 --- a/app/main.py +++ b/app/main.py @@ -735,6 +735,7 @@ async def _communicationservices_event_worker( call=call, client=automation_client, contexts=operation_contexts, + post_callback=_trigger_post_event, ) case "Microsoft.Communication.PlayCompleted": # Media played diff --git a/app/models/call.py b/app/models/call.py index cd93ce28..b8df7dfa 100644 --- a/app/models/call.py +++ b/app/models/call.py @@ -14,7 +14,6 @@ from app.helpers.monitoring import tracer from app.helpers.pydantic_types.phone_numbers import PhoneNumber from app.models.message import ( - ActionEnum as MessageActionEnum, MessageModel, PersonaEnum as MessagePersonaEnum, StyleEnum as MessageStyleEnum, @@ -34,36 +33,16 @@ class CallGetModel(BaseModel): call_id: UUID = Field(default_factory=uuid4, frozen=True) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), frozen=True) # Editable fields - initiate: CallInitiateModel = Field(frozen=True) claim: dict[ str, Any ] = {} # Place after "initiate" as it depends on it for validation + in_progress: bool = False + initiate: CallInitiateModel = Field(frozen=True) messages: list[MessageModel] = [] next: NextModel | None = None reminders: list[ReminderModel] = [] synthesis: SynthesisModel | None = None - @computed_field - @property - def in_progress(self) -> bool: - """ - Check if the call is in progress. - - The call is in progress if the most recent message action status (CALL or HANGUP) is CALL. Otherwise, it is not in progress. 
- """ - # Reverse - inverted_messages = self.messages.copy() - inverted_messages.reverse() - # Search for the first action we want - for message in inverted_messages: - match message.action: - case MessageActionEnum.CALL: - return True - case MessageActionEnum.HANGUP: - return False - # Otherwise, we assume the call is completed - return False - @field_validator("claim") @classmethod def _validate_claim( From 8bb5a0a30a1b2e593d25e2cc7a0a07b6bd271df8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Thu, 5 Dec 2024 17:33:49 +0100 Subject: [PATCH 2/3] dev: Disable prompt debug log --- app/helpers/config_models/prompts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/helpers/config_models/prompts.py b/app/helpers/config_models/prompts.py index 9fac74e4..bf9d4d4b 100644 --- a/app/helpers/config_models/prompts.py +++ b/app/helpers/config_models/prompts.py @@ -457,7 +457,7 @@ def _format( [line.strip() for line in formatted_prompt.splitlines()] ) - self.logger.debug("Formatted prompt: %s", formatted_prompt) + # self.logger.debug("Formatted prompt: %s", formatted_prompt) return formatted_prompt def _messages( From 9819c96a8c6a68a120f954201793dcf6f0524cb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Thu, 5 Dec 2024 17:34:15 +0100 Subject: [PATCH 3/3] perf: All AWS handle actions are wrapper with scheduler --- app/helpers/call_llm.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/app/helpers/call_llm.py b/app/helpers/call_llm.py index 005cba66..69fbe261 100644 --- a/app/helpers/call_llm.py +++ b/app/helpers/call_llm.py @@ -254,6 +254,7 @@ async def _tts_callback(text: str, style: MessageStyleEnum) -> None: call=call, client=client, post_callback=post_callback, + scheduler=scheduler, tts_callback=_tts_callback, use_tools=_iterations_remaining > 0, ) @@ -319,20 +320,24 @@ def _clear_tasks() -> None: ) soft_timeout_triggered = True # Never store the error message in the call history, it has caused hallucinations in the LLM - await handle_play_text( - call=call, - client=client, - store=False, - text=await CONFIG.prompts.tts.timeout_loading(call), + await scheduler.spawn( + handle_play_text( + call=call, + client=client, + store=False, + text=await CONFIG.prompts.tts.timeout_loading(call), + ) ) elif loading_task.done(): # Do not play timeout prompt plus loading, it can be frustrating for the user loading_task = _loading_task() - await handle_media( - call=call, - client=client, - sound_url=CONFIG.prompts.sounds.loading(), - ) # Play loading sound + await scheduler.spawn( + handle_media( + call=call, + client=client, + sound_url=CONFIG.prompts.sounds.loading(), + ) + ) # Wait to not block the event loop for other requests await asyncio.sleep(1) @@ -377,10 +382,11 @@ def _clear_tasks() -> None: # TODO: Refacto, this function is too long @tracer.start_as_current_span("call_execute_llm_chat") -async def _execute_llm_chat( # noqa: PLR0911, PLR0912, PLR0915 +async def _execute_llm_chat( # noqa: PLR0913, PLR0911, PLR0912, PLR0915 call: CallStateModel, client: CallAutomationClient, post_callback: Callable[[CallStateModel], Awaitable[None]], + scheduler: aiojobs.Scheduler, tts_callback: Callable[[str, MessageStyleEnum], Awaitable[None]], use_tools: bool, ) -> tuple[bool, bool, CallStateModel]: @@ -433,6 +439,7 @@ async def _content_callback( tts_callback = _tts_callback( automation_client=client, call=call, + scheduler=scheduler, ) # Build plugins @@ -663,6 +670,7 @@ 

---
 app/helpers/call_llm.py | 40 ++++++++++++++++++++++++----------------
 1 file changed, 24 insertions(+), 16 deletions(-)

diff --git a/app/helpers/call_llm.py b/app/helpers/call_llm.py
index 005cba66..69fbe261 100644
--- a/app/helpers/call_llm.py
+++ b/app/helpers/call_llm.py
@@ -254,6 +254,7 @@ async def _tts_callback(text: str, style: MessageStyleEnum) -> None:
                 call=call,
                 client=client,
                 post_callback=post_callback,
+                scheduler=scheduler,
                 tts_callback=_tts_callback,
                 use_tools=_iterations_remaining > 0,
             )
@@ -319,20 +320,24 @@ def _clear_tasks() -> None:
             )
             soft_timeout_triggered = True
             # Never store the error message in the call history, it has caused hallucinations in the LLM
-            await handle_play_text(
-                call=call,
-                client=client,
-                store=False,
-                text=await CONFIG.prompts.tts.timeout_loading(call),
+            await scheduler.spawn(
+                handle_play_text(
+                    call=call,
+                    client=client,
+                    store=False,
+                    text=await CONFIG.prompts.tts.timeout_loading(call),
+                )
             )
         elif loading_task.done():  # Do not play timeout prompt plus loading, it can be frustrating for the user
             loading_task = _loading_task()
-            await handle_media(
-                call=call,
-                client=client,
-                sound_url=CONFIG.prompts.sounds.loading(),
-            )  # Play loading sound
+            await scheduler.spawn(
+                handle_media(
+                    call=call,
+                    client=client,
+                    sound_url=CONFIG.prompts.sounds.loading(),
+                )
+            )
 
         # Wait to not block the event loop for other requests
         await asyncio.sleep(1)
@@ -377,10 +382,11 @@ def _clear_tasks() -> None:
 
 # TODO: Refacto, this function is too long
 @tracer.start_as_current_span("call_execute_llm_chat")
-async def _execute_llm_chat(  # noqa: PLR0911, PLR0912, PLR0915
+async def _execute_llm_chat(  # noqa: PLR0913, PLR0911, PLR0912, PLR0915
     call: CallStateModel,
     client: CallAutomationClient,
     post_callback: Callable[[CallStateModel], Awaitable[None]],
+    scheduler: aiojobs.Scheduler,
     tts_callback: Callable[[str, MessageStyleEnum], Awaitable[None]],
     use_tools: bool,
 ) -> tuple[bool, bool, CallStateModel]:
@@ -433,6 +439,7 @@ async def _content_callback(
     tts_callback = _tts_callback(
         automation_client=client,
         call=call,
+        scheduler=scheduler,
     )
 
     # Build plugins
@@ -663,6 +670,7 @@ async def _clear_tts_callback() -> None:
 def _tts_callback(
     automation_client: CallAutomationClient,
     call: CallStateModel,
+    scheduler: aiojobs.Scheduler,
 ) -> Callable[[str, MessageStyleEnum], Awaitable[None]]:
     """
     Send back the TTS to the user.
@@ -673,16 +681,16 @@ async def wrapper(
         text: str,
         style: MessageStyleEnum,
     ) -> None:
-        await asyncio.gather(
+        # First, play the TTS to the user
+        await scheduler.spawn(
             handle_play_text(
                 call=call,
                 client=automation_client,
                 style=style,
                 text=text,
-            ),  # First, play the TTS to the user
-            _db.call_aset(
-                call
-            ),  # Second, save in DB allowing (1) user to cut off the Assistant and (2) SMS answers to be in order
+            )
         )
+        # Second, save in DB, allowing (1) the user to cut off the assistant and (2) SMS answers to stay in order
+        await _db.call_aset(call)
 
     return wrapper