From 2c110d02abe8132929fd7002357dbdd22e19f002 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 19:22:17 +0100 Subject: [PATCH 01/17] ux: Lower voice cutting delay --- README.md | 2 +- app/helpers/features.py | 2 +- cicd/bicep/app.bicep | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 38e14a8..44941f6 100644 --- a/README.md +++ b/README.md @@ -498,7 +498,7 @@ Conversation options are represented as features. They can be configured from Ap | `recognition_retry_max` | The maximum number of retries for voice recognition. | `int` | 2 | | `recording_enabled` | Whether call recording is enabled. | `bool` | false | | `slow_llm_for_chat` | Whether to use the slow LLM for chat. | `bool` | false | -| `vad_cutoff_timeout_ms` | The cutoff timeout for voice activity detection in secs. | `int` | 300 | +| `vad_cutoff_timeout_ms` | The cutoff timeout for voice activity detection in secs. | `int` | 150 | | `vad_silence_timeout_ms` | The timeout for phone silence in secs. | `int` | 500 | | `vad_threshold` | The threshold for voice activity detection. 
| `float` | 0.5 | diff --git a/app/helpers/features.py b/app/helpers/features.py index 2df8fa1..a44981f 100644 --- a/app/helpers/features.py +++ b/app/helpers/features.py @@ -67,7 +67,7 @@ async def vad_silence_timeout_ms() -> int: async def vad_cutoff_timeout_ms() -> int: return await _default( - default=300, + default=150, key="vad_cutoff_timeout_ms", type_res=int, ) diff --git a/cicd/bicep/app.bicep b/cicd/bicep/app.bicep index f7ffafe..6389ac3 100644 --- a/cicd/bicep/app.bicep +++ b/cicd/bicep/app.bicep @@ -907,7 +907,7 @@ resource configValues 'Microsoft.AppConfiguration/configurationStores/keyValues@ recognition_retry_max: 2 recording_enabled: false slow_llm_for_chat: false - vad_cutoff_timeout_ms: 300 + vad_cutoff_timeout_ms: 150 vad_silence_timeout_ms: 500 vad_threshold: '0.5' }): { From 7731fff6eead11a8b87d005b52513d6245f92ca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 19:22:48 +0100 Subject: [PATCH 02/17] refacto: Features parsing --- app/helpers/features.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/app/helpers/features.py b/app/helpers/features.py index a44981f..7753223 100644 --- a/app/helpers/features.py +++ b/app/helpers/features.py @@ -1,4 +1,3 @@ -from contextlib import suppress from typing import TypeVar, cast from azure.appconfiguration.aio import AzureAppConfigurationClient @@ -172,13 +171,16 @@ def _parse(value: str, type_res: type[T]) -> T | None: Supported types: bool, int, float, str. 
""" - with suppress(ValueError): - if type_res is bool: - return cast(T, value.lower() == "true") - if type_res is int: - return cast(T, int(value)) - if type_res is float: - return cast(T, float(value)) - if type_res is str: - return cast(T, str(value)) - raise ValueError(f"Unsupported type: {type_res}") + # Try parse + if type_res is bool: + return cast(T, value.lower() == "true") + if type_res is int: + return cast(T, int(value)) + if type_res is float: + return cast(T, float(value)) + if type_res is str: + return cast(T, str(value)) + + # Unsupported type + logger.error("Unsupported feature type: %s", type_res) + return From 8e9ffe7ea68c14a36d1d3c65093388bbe09fb37c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 19:23:34 +0100 Subject: [PATCH 03/17] quality: Code quality --- app/helpers/features.py | 30 ++++++++++++++++++++++-------- app/persistence/icache.py | 7 ++++++- app/persistence/memory.py | 7 ++++++- app/persistence/redis.py | 7 ++++++- 4 files changed, 40 insertions(+), 11 deletions(-) diff --git a/app/helpers/features.py b/app/helpers/features.py index 7753223..86af770 100644 --- a/app/helpers/features.py +++ b/app/helpers/features.py @@ -9,10 +9,9 @@ from app.helpers.http import azure_transport from app.helpers.identity import credential from app.helpers.logging import logger -from app.persistence.icache import ICache from app.persistence.memory import MemoryCache -_cache: ICache = MemoryCache(MemoryModel(max_size=100)) +_cache = MemoryCache(MemoryModel()) T = TypeVar("T", bool, int, float, str) @@ -104,7 +103,11 @@ async def _default( """ Get a setting from the App Configuration service with a default value. 
""" - return (await _get(key=key, type_res=type_res)) or default + res = await _get(key=key, type_res=type_res) + if res: + return res + logger.warning("Setting %s not found, using default: %s", key, default) + return default async def _get( @@ -118,7 +121,11 @@ async def _get( cache_key = _cache_key(key) cached = await _cache.get(cache_key) if cached: - return _parse(value=cached.decode(), type_res=type_res) + return _parse( + type_res=type_res, + value=cached.decode(), + ) + # Try live try: async with await _use_client() as client: @@ -126,17 +133,24 @@ async def _get( # Return default if not found if not setting: return + res = setting.value except ResourceNotFoundError: - logger.warning("Setting %s not found", key) return + + logger.debug("Setting %s refreshed: %s", key, res) + # Update cache await _cache.set( key=cache_key, ttl_sec=CONFIG.app_configuration.ttl_sec, - value=setting.value, + value=res, + ) + + # Return the type + return _parse( + type_res=type_res, + value=res, ) - # Return - return _parse(value=setting.value, type_res=type_res) @async_lru_cache() diff --git a/app/persistence/icache.py b/app/persistence/icache.py index d347864..6f681cf 100644 --- a/app/persistence/icache.py +++ b/app/persistence/icache.py @@ -17,7 +17,12 @@ async def get(self, key: str) -> bytes | None: @abstractmethod @tracer.start_as_current_span("cache_set") - async def set(self, key: str, value: str | bytes | None, ttl_sec: int) -> bool: + async def set( + self, + key: str, + ttl_sec: int, + value: str | bytes | None, + ) -> bool: pass @abstractmethod diff --git a/app/persistence/memory.py b/app/persistence/memory.py index 95ccc43..aedc8ec 100644 --- a/app/persistence/memory.py +++ b/app/persistence/memory.py @@ -53,7 +53,12 @@ async def get(self, key: str) -> bytes | None: self._cache.move_to_end(sha_key, last=False) return res - async def set(self, key: str, value: str | bytes | None, ttl_sec: int) -> bool: + async def set( + self, + key: str, + ttl_sec: int, + value: str 
| bytes | None, + ) -> bool: """ Set a value in the cache. """ diff --git a/app/persistence/redis.py b/app/persistence/redis.py index 5e2ad29..b6b7551 100644 --- a/app/persistence/redis.py +++ b/app/persistence/redis.py @@ -92,7 +92,12 @@ async def get(self, key: str) -> bytes | None: logger.exception("Error getting value") return res - async def set(self, key: str, value: str | bytes | None, ttl_sec: int) -> bool: + async def set( + self, + key: str, + ttl_sec: int, + value: str | bytes | None, + ) -> bool: """ Set a value in the cache. From 096e6d3809dd819cbe4b34e18ea99d7763e7891a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 20:56:15 +0100 Subject: [PATCH 04/17] perf: Higher async function cache hit --- app/helpers/cache.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/app/helpers/cache.py b/app/helpers/cache.py index cfe0f68..e9443b1 100644 --- a/app/helpers/cache.py +++ b/app/helpers/cache.py @@ -1,4 +1,3 @@ -import asyncio import functools from collections import OrderedDict from collections.abc import AsyncGenerator, Awaitable @@ -33,11 +32,7 @@ def decorator(func): @functools.wraps(func) async def wrapper(*args, **kwargs) -> Awaitable: # Create a cache key from event loop, args and kwargs, using frozenset for kwargs to ensure hashability - key = ( - asyncio.get_event_loop(), - args, - frozenset(kwargs.items()), - ) + key = (args, frozenset(kwargs.items())) if key in cache: # Move the recently accessed key to the end (most recently used) From ae4e841606cc7b868a703783ca38499e9768c507 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 20:58:01 +0100 Subject: [PATCH 05/17] perf: Defer feature loading --- app/helpers/call_events.py | 8 +- app/helpers/call_llm.py | 20 +++-- app/helpers/call_utils.py | 5 +- app/helpers/features.py | 59 +++++++++---- app/helpers/llm_tools.py | 5 +- app/helpers/llm_worker.py | 8 +- app/main.py | 161 
+++++++++++++++++++++-------------- app/persistence/cosmos_db.py | 28 ++++-- app/persistence/istore.py | 18 +++- tests/conftest.py | 10 ++- tests/store.py | 118 ++++++++++++++++--------- 11 files changed, 289 insertions(+), 151 deletions(-) diff --git a/app/helpers/call_events.py b/app/helpers/call_events.py index 4964d77..6c2ae68 100644 --- a/app/helpers/call_events.py +++ b/app/helpers/call_events.py @@ -125,6 +125,7 @@ async def on_call_connected( _handle_recording( call=call, client=client, + scheduler=scheduler, server_call_id=server_call_id, ), # Second, start recording the call ) @@ -234,7 +235,7 @@ async def on_automation_recognize_error( logger.info( "Timeout, retrying language selection (%s/%s)", call.recognition_retry, - await recognition_retry_max(), + await recognition_retry_max(scheduler), ) await _handle_ivr_language( call=call, @@ -320,7 +321,7 @@ async def _pre_recognize_error( Returns True if the call should continue, False if it should end. """ # Voice retries are exhausted, end call - if call.recognition_retry >= await recognition_retry_max(): + if call.recognition_retry >= await recognition_retry_max(scheduler): logger.info("Timeout, ending call") return False @@ -792,6 +793,7 @@ async def _handle_ivr_language( async def _handle_recording( call: CallStateModel, client: CallAutomationClient, + scheduler: Scheduler, server_call_id: str, ) -> None: """ @@ -799,7 +801,7 @@ async def _handle_recording( Feature activation is checked before starting the recording. 
""" - if not await recording_enabled(): + if not await recording_enabled(scheduler): return assert CONFIG.communication_services.recording_container_url diff --git a/app/helpers/call_llm.py b/app/helpers/call_llm.py index 00ae850..7308768 100644 --- a/app/helpers/call_llm.py +++ b/app/helpers/call_llm.py @@ -65,6 +65,7 @@ async def load_llm_chat( # noqa: PLR0913, PLR0915 stt_complete_gate = asyncio.Event() # Gate to wait for the recognition aec = EchoCancellationStream( sample_rate=audio_sample_rate, + scheduler=scheduler, ) audio_reference: asyncio.Queue[bytes] = asyncio.Queue() @@ -271,10 +272,11 @@ async def _response_callback(_retry: bool = False) -> None: # Detect VAD _process_audio_for_vad( call=call, - stop_callback=_stop_callback, echo_cancellation=aec, out_stream=stt_stream, response_callback=_response_callback, + scheduler=scheduler, + stop_callback=_stop_callback, timeout_callback=_timeout_callback, ), ) @@ -354,10 +356,10 @@ def _loading_task() -> asyncio.Task: # Timeouts soft_timeout_triggered = False soft_timeout_task = asyncio.create_task( - asyncio.sleep(await answer_soft_timeout_sec()) + asyncio.sleep(await answer_soft_timeout_sec(scheduler)) ) hard_timeout_task = asyncio.create_task( - asyncio.sleep(await answer_hard_timeout_sec()) + asyncio.sleep(await answer_hard_timeout_sec(scheduler)) ) def _clear_tasks() -> None: @@ -387,7 +389,7 @@ def _clear_tasks() -> None: if hard_timeout_task.done(): logger.warning( "Hard timeout of %ss reached", - await answer_hard_timeout_sec(), + await answer_hard_timeout_sec(scheduler), ) # Clean up _clear_tasks() @@ -399,7 +401,7 @@ def _clear_tasks() -> None: if soft_timeout_task.done() and not soft_timeout_triggered: logger.warning( "Soft timeout of %ss reached", - await answer_soft_timeout_sec(), + await answer_soft_timeout_sec(scheduler), ) soft_timeout_triggered = True # Never store the error message in the call history, it has caused hallucinations in the LLM @@ -548,6 +550,7 @@ async def 
_content_callback(buffer: str) -> None: async for delta in completion_stream( max_tokens=160, # Lowest possible value for 90% of the cases, if not sufficient, retry will be triggered, 100 tokens ~= 75 words, 20 words ~= 1 sentence, 6 sentences ~= 160 tokens messages=call.messages, + scheduler=scheduler, system=system, tools=tools, ): @@ -659,6 +662,7 @@ async def _process_audio_for_vad( # noqa: PLR0913 echo_cancellation: EchoCancellationStream, out_stream: PushAudioInputStream, response_callback: Callable[[], Awaitable[None]], + scheduler: Scheduler, stop_callback: Callable[[], Awaitable[None]], timeout_callback: Callable[[], Awaitable[None]], ) -> None: @@ -682,7 +686,7 @@ async def _wait_for_silence() -> None: """ # Wait before flushing nonlocal stop_task - timeout_ms = await vad_silence_timeout_ms() + timeout_ms = await vad_silence_timeout_ms(scheduler) await asyncio.sleep(timeout_ms / 1000) # Cancel the clear TTS task @@ -695,7 +699,7 @@ async def _wait_for_silence() -> None: await response_callback() # Wait for silence and trigger timeout - timeout_sec = await phone_silence_timeout_sec() + timeout_sec = await phone_silence_timeout_sec(scheduler) while True: # Stop this time if the call played a message timeout_start = datetime.now(UTC) @@ -724,7 +728,7 @@ async def _wait_for_stop() -> None: """ Stop the TTS if user speaks for too long. 
""" - timeout_ms = await vad_cutoff_timeout_ms() + timeout_ms = await vad_cutoff_timeout_ms(scheduler) # Wait before clearing the TTS queue await asyncio.sleep(timeout_ms / 1000) diff --git a/app/helpers/call_utils.py b/app/helpers/call_utils.py index 1076706..3fc162b 100644 --- a/app/helpers/call_utils.py +++ b/app/helpers/call_utils.py @@ -644,15 +644,18 @@ class EchoCancellationStream: _packet_size: int _reference_queue: asyncio.Queue[bytes] = asyncio.Queue() _sample_rate: int + _scheduler: Scheduler def __init__( self, sample_rate: int, + scheduler: Scheduler, max_delay_ms: int = 200, packet_duration_ms: int = 20, ): self._packet_duration_ms = packet_duration_ms self._sample_rate = sample_rate + self._scheduler = scheduler max_delay_samples = int(max_delay_ms / 1000 * self._sample_rate) self._bot_voice_buffer = np.zeros(max_delay_samples, dtype=np.float32) @@ -703,7 +706,7 @@ async def _rms_speech_detection(self, voice: np.ndarray) -> bool: # Calculate Root Mean Square (RMS) rms = np.sqrt(np.mean(voice**2)) # Get VAD threshold, divide by 10 to more usability from user side, as RMS is in range 0-1 and a detection of 0.1 is a good maximum threshold - threshold = await vad_threshold() / 10 + threshold = await vad_threshold(self._scheduler) / 10 return rms >= threshold async def _process_one(self, input_pcm: bytes) -> None: diff --git a/app/helpers/features.py b/app/helpers/features.py index 86af770..d9b4f8e 100644 --- a/app/helpers/features.py +++ b/app/helpers/features.py @@ -1,5 +1,6 @@ from typing import TypeVar, cast +from aiojobs import Scheduler from azure.appconfiguration.aio import AzureAppConfigurationClient from azure.core.exceptions import ResourceNotFoundError @@ -15,82 +16,92 @@ T = TypeVar("T", bool, int, float, str) -async def answer_hard_timeout_sec() -> int: +async def answer_hard_timeout_sec(scheduler: Scheduler) -> int: return await _default( default=180, key="answer_hard_timeout_sec", + scheduler=scheduler, type_res=int, ) -async def 
answer_soft_timeout_sec() -> int: +async def answer_soft_timeout_sec(scheduler: Scheduler) -> int: return await _default( default=120, key="answer_soft_timeout_sec", + scheduler=scheduler, type_res=int, ) -async def callback_timeout_hour() -> int: +async def callback_timeout_hour(scheduler: Scheduler) -> int: return await _default( default=24, key="callback_timeout_hour", + scheduler=scheduler, type_res=int, ) -async def phone_silence_timeout_sec() -> int: +async def phone_silence_timeout_sec(scheduler: Scheduler) -> int: return await _default( default=20, key="phone_silence_timeout_sec", + scheduler=scheduler, type_res=int, ) -async def vad_threshold() -> float: +async def vad_threshold(scheduler: Scheduler) -> float: return await _default( default=0.5, key="vad_threshold", + scheduler=scheduler, type_res=float, ) -async def vad_silence_timeout_ms() -> int: +async def vad_silence_timeout_ms(scheduler: Scheduler) -> int: return await _default( default=500, key="vad_silence_timeout_ms", + scheduler=scheduler, type_res=int, ) -async def vad_cutoff_timeout_ms() -> int: +async def vad_cutoff_timeout_ms(scheduler: Scheduler) -> int: return await _default( default=150, key="vad_cutoff_timeout_ms", + scheduler=scheduler, type_res=int, ) -async def recording_enabled() -> bool: +async def recording_enabled(scheduler: Scheduler) -> bool: return await _default( default=False, key="recording_enabled", + scheduler=scheduler, type_res=bool, ) -async def slow_llm_for_chat() -> bool: +async def slow_llm_for_chat(scheduler: Scheduler) -> bool: return await _default( default=True, key="slow_llm_for_chat", + scheduler=scheduler, type_res=bool, ) -async def recognition_retry_max() -> int: +async def recognition_retry_max(scheduler: Scheduler) -> int: return await _default( default=3, key="recognition_retry_max", + scheduler=scheduler, type_res=int, ) @@ -98,20 +109,29 @@ async def recognition_retry_max() -> int: async def _default( default: T, key: str, + scheduler: Scheduler, 
type_res: type[T], ) -> T: """ Get a setting from the App Configuration service with a default value. """ - res = await _get(key=key, type_res=type_res) + # Get the setting + res = await _get( + key=key, + scheduler=scheduler, + type_res=type_res, + ) if res: return res - logger.warning("Setting %s not found, using default: %s", key, default) + + # Return default + logger.info("Feature %s not found, using default: %s", key, default) return default async def _get( key: str, + scheduler: Scheduler, type_res: type[T], ) -> T | None: """ @@ -126,6 +146,15 @@ async def _get( value=cached.decode(), ) + # Defer the update + await scheduler.spawn(_refresh(cache_key, key)) + return + + +async def _refresh( + cache_key: str, + key: str, +) -> T | None: # Try live try: async with await _use_client() as client: @@ -146,12 +175,6 @@ async def _get( value=res, ) - # Return the type - return _parse( - type_res=type_res, - value=res, - ) - @async_lru_cache() async def _use_client() -> AzureAppConfigurationClient: diff --git a/app/helpers/llm_tools.py b/app/helpers/llm_tools.py index 037dc80..8a45275 100644 --- a/app/helpers/llm_tools.py +++ b/app/helpers/llm_tools.py @@ -98,7 +98,7 @@ async def new_claim( # Store the last message and use it at first message of the new claim self.call = await _db.call_create( - CallStateModel( + call=CallStateModel( initiate=self.call.initiate.model_copy(), voice_id=self.call.voice_id, messages=[ @@ -111,7 +111,8 @@ async def new_claim( # Reinsert the last message, using more will add the user message asking to create the new claim and the assistant can loop on it sometimes self.call.messages[-1], ], - ) + ), + scheduler=self.scheduler, ) return "Claim, reminders and messages reset" diff --git a/app/helpers/llm_worker.py b/app/helpers/llm_worker.py index 3bfe22e..4433981 100644 --- a/app/helpers/llm_worker.py +++ b/app/helpers/llm_worker.py @@ -5,6 +5,7 @@ from typing import TypeVar import tiktoken +from aiojobs import Scheduler from json_repair 
import repair_json from openai import ( APIConnectionError, @@ -88,6 +89,7 @@ class MaximumTokensReachedError(Exception): async def completion_stream( max_tokens: int, messages: list[MessageModel], + scheduler: Scheduler, system: list[ChatCompletionSystemMessageParam], tools: list[ChatCompletionToolParam] | None = None, ) -> AsyncGenerator[ChoiceDelta, None]: @@ -110,7 +112,9 @@ async def completion_stream( async for attempt in retryed: with attempt: async for chunck in _completion_stream_worker( - is_fast=not await slow_llm_for_chat(), # Let configuration decide + is_fast=not await slow_llm_for_chat( + scheduler + ), # Let configuration decide max_tokens=max_tokens, messages=messages, system=system, @@ -130,7 +134,7 @@ async def completion_stream( async for attempt in retryed: with attempt: async for chunck in _completion_stream_worker( - is_fast=await slow_llm_for_chat(), # Let configuration decide + is_fast=await slow_llm_for_chat(scheduler), # Let configuration decide max_tokens=max_tokens, messages=messages, system=system, diff --git a/app/main.py b/app/main.py index 18d7e94..df4fb3d 100644 --- a/app/main.py +++ b/app/main.py @@ -285,7 +285,11 @@ async def report_single_get(call_id: UUID) -> HTMLResponse: Returns a single call with a web interface. """ - call = await _db.call_get(call_id) + async with get_scheduler() as scheduler: + call = await _db.call_get( + call_id=call_id, + scheduler=scheduler, + ) if not call: return HTMLResponse( content=f"Call {call_id} not found", @@ -345,21 +349,29 @@ async def call_get(call_id_or_phone_number: str) -> CallGetModel: Returns a single call object `CallGetModel`, in JSON format. 
""" - # First, try to get by call ID - with suppress(ValueError): - call_id = UUID(call_id_or_phone_number) - call = await _db.call_get(call_id) - if call: - return TypeAdapter(CallGetModel).dump_python(call) - - # Second, try to get by phone number - phone_number = PhoneNumber(call_id_or_phone_number) - call = await _db.call_search_one(phone_number=phone_number) - if not call: - raise HTTPException( - detail=f"Call {call_id_or_phone_number} not found", - status_code=HTTPStatus.NOT_FOUND, + async with get_scheduler() as scheduler: + # First, try to get by call ID + with suppress(ValueError): + call_id = UUID(call_id_or_phone_number) + call = await _db.call_get( + call_id=call_id, + scheduler=scheduler, + ) + if call: + return TypeAdapter(CallGetModel).dump_python(call) + + # Second, try to get by phone number + phone_number = PhoneNumber(call_id_or_phone_number) + call = await _db.call_search_one( + phone_number=phone_number, + scheduler=scheduler, ) + if not call: + raise HTTPException( + detail=f"Call {call_id_or_phone_number} not found", + status_code=HTTPStatus.NOT_FOUND, + ) + return TypeAdapter(CallGetModel).dump_python(call) @@ -483,11 +495,15 @@ async def sms_event( # Enrich span span_attribute(SpanAttributes.CALL_PHONE_NUMBER, phone_number) - # Get call - call = await _db.call_search_one(phone_number) - if not call: - logger.warning("Call for phone number %s not found", phone_number) - return + async with get_scheduler() as scheduler: + # Get call + call = await _db.call_search_one( + phone_number=phone_number, + scheduler=scheduler, + ) + if not call: + logger.warning("Call for phone number %s not found", phone_number) + return # Enrich span span_attribute(SpanAttributes.CALL_ID, str(call.call_id)) @@ -540,13 +556,17 @@ async def _communicationservices_validate_call_id( # Enrich span span_attribute(SpanAttributes.CALL_ID, str(call_id)) - # Validate call - call = await _db.call_get(call_id) - if not call: - raise HTTPException( - detail=f"Call {call_id} 
not found", - status_code=HTTPStatus.NOT_FOUND, + async with get_scheduler() as scheduler: + # Validate call + call = await _db.call_get( + call_id=call_id, + scheduler=scheduler, ) + if not call: + raise HTTPException( + detail=f"Call {call_id} not found", + status_code=HTTPStatus.NOT_FOUND, + ) # Validate secret if call.callback_secret != secret: @@ -865,20 +885,22 @@ async def post_event( Queue message is the UUID of a call. The event will load asynchroniously the `on_end_call` workflow. """ - # Validate call - call = await _db.call_get(UUID(post.content)) - if not call: - logger.warning("Call %s not found", post.content) - return - - # Enrich span - span_attribute(SpanAttributes.CALL_ID, str(call.call_id)) - span_attribute(SpanAttributes.CALL_PHONE_NUMBER, call.initiate.phone_number) + async with get_scheduler() as scheduler: + # Validate call + call = await _db.call_get( + call_id=UUID(post.content), + scheduler=scheduler, + ) + if not call: + logger.warning("Call %s not found", post.content) + return - logger.debug("Post event received") + # Enrich span + span_attribute(SpanAttributes.CALL_ID, str(call.call_id)) + span_attribute(SpanAttributes.CALL_PHONE_NUMBER, call.initiate.phone_number) - async with get_scheduler() as scheduler: # Execute business logic + logger.debug("Post event received") await on_end_call( call=call, scheduler=scheduler, @@ -909,21 +931,26 @@ async def _communicationservices_urls( Returnes a tuple of the callback URL, the WebSocket URL, and the call object. 
""" - # Get call - call = await _db.call_search_one(phone_number) - - # Create new call if initiate is different - if not call or (initiate and call.initiate != initiate): - call = await _db.call_create( - CallStateModel( - initiate=initiate - or CallInitiateModel( - **CONFIG.conversation.initiate.model_dump(), - phone_number=phone_number, - ) - ) + async with get_scheduler() as scheduler: + # Get call + call = await _db.call_search_one( + phone_number=phone_number, + scheduler=scheduler, ) + # Create new call if initiate is different + if not call or (initiate and call.initiate != initiate): + call = await _db.call_create( + call=CallStateModel( + initiate=initiate + or CallInitiateModel( + **CONFIG.conversation.initiate.model_dump(), + phone_number=phone_number, + ) + ), + scheduler=scheduler, + ) + # Format URLs wss_url = _COMMUNICATIONSERVICES_WSS_TPL.format( callback_secret=call.callback_secret, @@ -957,16 +984,22 @@ async def twilio_sms_post( # Enrich span span_attribute(SpanAttributes.CALL_PHONE_NUMBER, From) - # Get call - call = await _db.call_search_one(From) + async with get_scheduler() as scheduler: + # Get call + call = await _db.call_search_one( + phone_number=From, + scheduler=scheduler, + ) + + # Call not found + if not call: + logger.warning("Call for phone number %s not found", From) - if not call: - logger.warning("Call for phone number %s not found", From) - else: - # Enrich span - span_attribute(SpanAttributes.CALL_ID, str(call.call_id)) + # Call found + else: + # Enrich span + span_attribute(SpanAttributes.CALL_ID, str(call.call_id)) - async with get_scheduler() as scheduler: # Execute business logic event_status = await on_sms_received( call=call, @@ -974,12 +1007,12 @@ async def twilio_sms_post( scheduler=scheduler, ) - # Return error for unsuccessful event - if not event_status: - raise HTTPException( - detail="SMS event failed", - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - ) + # Return error for unsuccessful event + if not 
event_status: + raise HTTPException( + detail="SMS event failed", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) # Default response return Response( diff --git a/app/persistence/cosmos_db.py b/app/persistence/cosmos_db.py index cd4f548..e16c05c 100644 --- a/app/persistence/cosmos_db.py +++ b/app/persistence/cosmos_db.py @@ -81,7 +81,11 @@ async def _item_exists(self, test_id: str, partition_key: str) -> bool: exist = True return exist - async def call_get(self, call_id: UUID) -> CallStateModel | None: + async def call_get( + self, + call_id: UUID, + scheduler: Scheduler, + ) -> CallStateModel | None: logger.debug("Loading call %s", call_id) # Try cache @@ -114,7 +118,7 @@ async def call_get(self, call_id: UUID) -> CallStateModel | None: if call: await self._cache.set( key=cache_key, - ttl_sec=await callback_timeout_hour(), + ttl_sec=await callback_timeout_hour(scheduler), value=call.model_dump_json(), ) @@ -185,7 +189,7 @@ async def _exec() -> None: cache_key_id = self._cache_key_call_id(call.call_id) await self._cache.set( key=cache_key_id, - ttl_sec=await callback_timeout_hour(), + ttl_sec=await callback_timeout_hour(scheduler), value=call.model_dump_json(), ) @@ -193,7 +197,11 @@ async def _exec() -> None: await scheduler.spawn(_exec()) # TODO: Catch errors - async def call_create(self, call: CallStateModel) -> CallStateModel: + async def call_create( + self, + call: CallStateModel, + scheduler: Scheduler, + ) -> CallStateModel: logger.debug("Creating new call %s", call.call_id) # Serialize @@ -213,7 +221,7 @@ async def call_create(self, call: CallStateModel) -> CallStateModel: cache_key = self._cache_key_call_id(call.call_id) await self._cache.set( key=cache_key, - ttl_sec=await callback_timeout_hour(), + ttl_sec=await callback_timeout_hour(scheduler), value=call.model_dump_json(), ) @@ -225,7 +233,11 @@ async def call_create(self, call: CallStateModel) -> CallStateModel: return call - async def call_search_one(self, phone_number: str) -> CallStateModel | 
None: + async def call_search_one( + self, + phone_number: str, + scheduler: Scheduler, + ) -> CallStateModel | None: logger.debug("Loading last call for %s", phone_number) # Try cache @@ -244,7 +256,7 @@ async def call_search_one(self, phone_number: str) -> CallStateModel | None: async with self._use_client() as db: items = db.query_items( max_item_count=1, - query=f"SELECT * FROM c WHERE (STRINGEQUALS(c.initiate.phone_number, @phone_number, true) OR STRINGEQUALS(c.claim.policyholder_phone, @phone_number, true)) AND c.created_at >= DATETIMEADD('hh', -{await callback_timeout_hour()}, GETCURRENTDATETIME()) ORDER BY c.created_at DESC", + query=f"SELECT * FROM c WHERE (STRINGEQUALS(c.initiate.phone_number, @phone_number, true) OR STRINGEQUALS(c.claim.policyholder_phone, @phone_number, true)) AND c.created_at >= DATETIMEADD('hh', -{await callback_timeout_hour(scheduler)}, GETCURRENTDATETIME()) ORDER BY c.created_at DESC", parameters=[ { "name": "@phone_number", @@ -264,7 +276,7 @@ async def call_search_one(self, phone_number: str) -> CallStateModel | None: if call: await self._cache.set( key=cache_key, - ttl_sec=await callback_timeout_hour(), + ttl_sec=await callback_timeout_hour(scheduler), value=call.model_dump_json(), ) diff --git a/app/persistence/istore.py b/app/persistence/istore.py index e144309..960f4ae 100644 --- a/app/persistence/istore.py +++ b/app/persistence/istore.py @@ -23,7 +23,11 @@ async def readiness(self) -> ReadinessEnum: @abstractmethod @tracer.start_as_current_span("store_call_get") - async def call_get(self, call_id: UUID) -> CallStateModel | None: + async def call_get( + self, + call_id: UUID, + scheduler: Scheduler, + ) -> CallStateModel | None: pass @abstractmethod @@ -37,12 +41,20 @@ def call_transac( @abstractmethod @tracer.start_as_current_span("store_call_create") - async def call_create(self, call: CallStateModel) -> CallStateModel: + async def call_create( + self, + call: CallStateModel, + scheduler: Scheduler, + ) -> CallStateModel: 
pass @abstractmethod @tracer.start_as_current_span("store_call_search_one") - async def call_search_one(self, phone_number: str) -> CallStateModel | None: + async def call_search_one( + self, + phone_number: str, + scheduler: Scheduler, + ) -> CallStateModel | None: pass @abstractmethod diff --git a/tests/conftest.py b/tests/conftest.py index c3820af..839f4b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -211,7 +211,10 @@ def generate(self, prompt: str) -> tuple[str, float]: # Try live res = super().generate(prompt) # Update cache - self._cache.set(cache_key, res) + self._cache.set( + key=cache_key, + value=res, + ) # Return return res @@ -225,7 +228,10 @@ async def a_generate(self, prompt: str) -> tuple[str, float]: # Try live res = await super().a_generate(prompt) # Update cache - self._cache.set(cache_key, res) + self._cache.set( + key=cache_key, + value=res, + ) # Return return res diff --git a/tests/store.py b/tests/store.py index db99191..6146cef 100644 --- a/tests/store.py +++ b/tests/store.py @@ -22,40 +22,67 @@ async def test_acid(call: CallStateModel) -> None: """ db = CONFIG.database.instance() - # Check not exists - assume(not await db.call_get(call.call_id)) - assume(await db.call_search_one(call.initiate.phone_number) != call) - assume( - call - not in ( - ( - await db.call_search_all( - phone_number=call.initiate.phone_number, count=1 - ) - )[0] - or [] + async with Scheduler() as scheduler: + # Check not exists + assume( + not await db.call_get( + call_id=call.call_id, + scheduler=scheduler, + ) + ) + assume( + await db.call_search_one( + phone_number=call.initiate.phone_number, + scheduler=scheduler, + ) + != call + ) + assume( + call + not in ( + ( + await db.call_search_all( + phone_number=call.initiate.phone_number, count=1 + ) + )[0] + or [] + ) + ) + + # Insert test call + await db.call_create( + call=call, + scheduler=scheduler, + ) + + # Check point read + assume( + await db.call_get( + call_id=call.call_id, + 
scheduler=scheduler, + ) + == call ) - ) - - # Insert test call - await db.call_create(call) - - # Check point read - assume(await db.call_get(call.call_id) == call) - # Check search one - assume(await db.call_search_one(call.initiate.phone_number) == call) - # Check search all - assume( - call - in ( - ( - await db.call_search_all( - phone_number=call.initiate.phone_number, count=1 - ) - )[0] - or [] + # Check search one + assume( + await db.call_search_one( + phone_number=call.initiate.phone_number, + scheduler=scheduler, + ) + == call + ) + # Check search all + assume( + call + in ( + ( + await db.call_search_all( + phone_number=call.initiate.phone_number, count=1 + ) + )[0] + or [] + ) ) - ) @pytest.mark.asyncio(loop_scope="session") @@ -69,13 +96,21 @@ async def test_transaction( """ db = CONFIG.database.instance() - # Check not exists - assume(not await db.call_get(call.call_id)) + async with Scheduler() as scheduler: + # Check not exists + assume( + not await db.call_get( + call_id=call.call_id, + scheduler=scheduler, + ) + ) - # Insert call - await db.call_create(call) + # Insert call + await db.call_create( + call=call, + scheduler=scheduler, + ) - async with Scheduler() as scheduler: # Check first change async with db.call_transac( call=call, @@ -101,6 +136,9 @@ async def test_transaction( # Check first string change assume(call.voice_id == random_text) - # Check point read - new_call = await db.call_get(call.call_id) - assume(new_call and new_call.voice_id == random_text and new_call.in_progress) + # Check point read + new_call = await db.call_get( + call_id=call.call_id, + scheduler=scheduler, + ) + assume(new_call and new_call.voice_id == random_text and new_call.in_progress) From 84f6f2ca94eaf0f0cfbd8e97661858a6fb6ce3b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 20:58:21 +0100 Subject: [PATCH 06/17] chore: Delete dead logs --- app/helpers/call_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git 
a/app/helpers/call_utils.py b/app/helpers/call_utils.py index 3fc162b..c6cd014 100644 --- a/app/helpers/call_utils.py +++ b/app/helpers/call_utils.py @@ -564,10 +564,6 @@ async def use_tts_client( audio_config=AudioOutputConfig(stream=PushAudioOutputStream(TtsCallback(out))), ) - # Connect events - client.synthesis_started.connect(lambda _: logger.debug("TTS started")) - client.synthesis_completed.connect(lambda _: logger.debug("TTS completed")) - # Return yield client From 64b4d2ad659ac383b0480a87c5e0844efc8cbaf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 20:58:42 +0100 Subject: [PATCH 07/17] chore: Harmonize memory cache sizes --- app/helpers/config_models/cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/helpers/config_models/cache.py b/app/helpers/config_models/cache.py index 44d70cf..b8b212d 100644 --- a/app/helpers/config_models/cache.py +++ b/app/helpers/config_models/cache.py @@ -14,7 +14,7 @@ class ModeEnum(str, Enum): class MemoryModel(BaseModel, frozen=True): - max_size: int = Field(default=100, ge=10) + max_size: int = Field(default=128, ge=10) @cache def instance(self) -> ICache: From a484853c456571da1e1120ed32f8fedfb0304319 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 20:58:59 +0100 Subject: [PATCH 08/17] fix: Format exceptions in the console --- app/helpers/logging.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/app/helpers/logging.py b/app/helpers/logging.py index d29b19e..c7d1a82 100644 --- a/app/helpers/logging.py +++ b/app/helpers/logging.py @@ -12,7 +12,6 @@ TimeStamper, UnicodeDecoder, add_log_level, - format_exc_info, ) from structlog.stdlib import PositionalArgumentsFormatter @@ -39,7 +38,6 @@ TimeStamper(fmt="iso", utc=True), # Add exceptions info StackInfoRenderer(), - format_exc_info, # Decode Unicode to str UnicodeDecoder(), # Pretty printing in a terminal session From 
e92b056ed0973329a08e5dae4003d3c5b9e9bad3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 21:00:06 +0100 Subject: [PATCH 09/17] fix: Limit local cache TTL size --- app/persistence/memory.py | 45 +++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/app/persistence/memory.py b/app/persistence/memory.py index aedc8ec..738860f 100644 --- a/app/persistence/memory.py +++ b/app/persistence/memory.py @@ -1,9 +1,9 @@ import hashlib from collections import OrderedDict +from contextlib import suppress from datetime import UTC, datetime, timedelta from app.helpers.config_models.cache import MemoryModel -from app.helpers.logging import logger from app.models.readiness import ReadinessEnum from app.persistence.icache import ICache @@ -19,13 +19,9 @@ class MemoryCache(ICache): _cache: OrderedDict[str, bytes | None] = OrderedDict() _config: MemoryModel - _ttl: dict[str, datetime] = {} + _ttl: OrderedDict[str, datetime] = OrderedDict() def __init__(self, config: MemoryModel): - logger.warning( - "Using memory cache with %s size limit, memory usage can be high, prefer an external cache like Redis", - config.max_size, - ) self._config = config async def readiness(self) -> ReadinessEnum: @@ -41,16 +37,22 @@ async def get(self, key: str) -> bytes | None: If the key does not exist, return `None`. """ sha_key = self._key_to_hash(key) - # Check TTL - if sha_key in self._ttl: - if self._ttl[sha_key] < datetime.now(UTC): - return None + + # Check TTL, delete if expired + ttl = self._ttl.get(sha_key, None) + if ttl and ttl < datetime.now(UTC): + await self.delete(key) + return None + # Get from cache res = self._cache.get(sha_key, None) if not res: return None + # Move to first self._cache.move_to_end(sha_key, last=False) + self._ttl.move_to_end(sha_key, last=False) + return res async def set( @@ -63,14 +65,20 @@ async def set( Set a value in the cache. 
""" sha_key = self._key_to_hash(key) + # Delete the last if full if len(self._cache) >= self._config.max_size: + self._ttl.popitem() self._cache.popitem() - # Add to first + + # Set TTL as first element + self._ttl[sha_key] = datetime.now(UTC) + timedelta(seconds=ttl_sec) + self._ttl.move_to_end(sha_key, last=False) + + # Add cache as first element self._cache[sha_key] = value.encode() if isinstance(value, str) else value self._cache.move_to_end(sha_key, last=False) - # Set the TTL - self._ttl[sha_key] = datetime.now(UTC) + timedelta(seconds=ttl_sec) + return True async def delete(self, key: str) -> bool: @@ -78,12 +86,13 @@ async def delete(self, key: str) -> bool: Delete a value from the cache. """ sha_key = self._key_to_hash(key) - # Delete from cache - if sha_key in self._cache: - self._cache.pop(sha_key) - # Delete from TTL - if sha_key in self._ttl: + + # Delete keys + with suppress(KeyError): self._ttl.pop(sha_key) + with suppress(KeyError): + self._cache.pop(sha_key) + return True @staticmethod From 5e969c7cc5d5c8e2013025f8409c710518a4ea08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 21:00:21 +0100 Subject: [PATCH 10/17] perf: Reuse properly Redis connexions --- app/persistence/redis.py | 101 +++++++++++++++++++++++---------------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/app/persistence/redis.py b/app/persistence/redis.py index b6b7551..feb35b7 100644 --- a/app/persistence/redis.py +++ b/app/persistence/redis.py @@ -1,9 +1,10 @@ import hashlib -from datetime import timedelta +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from uuid import uuid4 from opentelemetry.instrumentation.redis import RedisInstrumentor -from redis.asyncio import Redis +from redis.asyncio import Connection, ConnectionPool, Redis, SSLConnection from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff from redis.exceptions import ( @@ -12,6 +13,7 @@ 
RedisError, ) +from app.helpers.cache import async_lru_cache from app.helpers.config_models.cache import RedisModel from app.helpers.logging import logger from app.models.readiness import ReadinessEnum @@ -20,33 +22,12 @@ # Instrument redis RedisInstrumentor().instrument() -_retry = Retry(backoff=ExponentialBackoff(), retries=3) - class RedisCache(ICache): - _client: Redis _config: RedisModel def __init__(self, config: RedisModel): - logger.info("Using Redis cache %s:%s", config.host, config.port) self._config = config - self._client = Redis( - # Database location - db=config.database, - # Reliability - health_check_interval=10, # Check the health of the connection every 10 secs - retry_on_error=[BusyLoadingError, RedisConnectionError], - retry_on_timeout=True, - retry=_retry, - socket_connect_timeout=5, # Give the system sufficient time to connect even under higher CPU conditions - socket_timeout=1, # Respond quickly or abort, this is a cache - # Deployment - host=config.host, - port=config.port, - ssl=config.ssl, - # Authentication - password=config.password.get_secret_value(), - ) # Redis manage by itself a low level connection pool with asyncio, but be warning to not use a generator while consuming the connection, it will close it async def readiness(self) -> ReadinessEnum: """ @@ -57,16 +38,17 @@ async def readiness(self) -> ReadinessEnum: test_name = str(uuid4()) test_value = "test" try: - # Test the item does not exist - assert await self._client.get(test_name) is None - # Create a new item - await self._client.set(test_name, test_value) - # Test the item is the same - assert (await self._client.get(test_name)).decode() == test_value - # Delete the item - await self._client.delete(test_name) - # Test the item does not exist - assert await self._client.get(test_name) is None + async with self._use_client() as client: + # Test the item does not exist + assert await client.get(test_name) is None + # Create a new item + await client.set(test_name, test_value) + 
# Test the item is the same + assert (await client.get(test_name)).decode() == test_value + # Delete the item + await client.delete(test_name) + # Test the item does not exist + assert await client.get(test_name) is None return ReadinessEnum.OK except AssertionError: logger.exception("Readiness test failed") @@ -87,7 +69,8 @@ async def get(self, key: str) -> bytes | None: sha_key = self._key_to_hash(key) res = None try: - res = await self._client.get(sha_key) + async with self._use_client() as client: + res = await client.get(sha_key) except RedisError: logger.exception("Error getting value") return res @@ -107,11 +90,12 @@ async def set( """ sha_key = self._key_to_hash(key) try: - await self._client.set( - ex=timedelta(seconds=ttl_sec), - name=sha_key, - value=value if value else "", - ) + async with self._use_client() as client: + await client.set( + ex=ttl_sec, + name=sha_key, + value=value if value else "", + ) except RedisError: logger.exception("Error setting value") return False @@ -125,12 +109,49 @@ async def delete(self, key: str) -> bool: """ sha_key = self._key_to_hash(key) try: - await self._client.delete(sha_key) + async with self._use_client() as client: + await client.delete(sha_key) except RedisError: logger.exception("Error deleting value") return False return True + @async_lru_cache() + async def _use_connection_pool(self) -> ConnectionPool: + """ + Generate the Redis connection pool. 
+ """ + logger.info("Using Redis cache %s:%s", self._config.host, self._config.port) + + return ConnectionPool( + # Database location + db=self._config.database, + # Reliability + health_check_interval=10, # Check the health of the connection every 10 secs + retry_on_error=[BusyLoadingError, RedisConnectionError], + retry_on_timeout=True, + retry=Retry(backoff=ExponentialBackoff(), retries=3), + socket_connect_timeout=5, # Give the system sufficient time to connect even under higher CPU conditions + socket_timeout=1, # Respond quickly or abort, this is a cache + # Deployment + connection_class=SSLConnection if self._config.ssl else Connection, + host=self._config.host, + port=self._config.port, + # Authentication + password=self._config.password.get_secret_value(), + ) + + @asynccontextmanager + async def _use_client(self) -> AsyncGenerator[Redis, None]: + """ + Return a Redis connection. + """ + async with Redis( + auto_close_connection_pool=False, + connection_pool=await self._use_connection_pool(), + ) as client: + yield client + @staticmethod def _key_to_hash(key: str) -> bytes: """ From 543a54ad03088f7117d9d6e4cd2072c0825c1260 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 19:25:01 +0100 Subject: [PATCH 11/17] refacto: Simpler call chat task management --- app/helpers/call_llm.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/app/helpers/call_llm.py b/app/helpers/call_llm.py index 7308768..035b922 100644 --- a/app/helpers/call_llm.py +++ b/app/helpers/call_llm.py @@ -3,7 +3,7 @@ from datetime import UTC, datetime, timedelta from functools import wraps -from aiojobs import Job, Scheduler +from aiojobs import Scheduler from azure.cognitiveservices.speech import ( SpeechSynthesizer, ) @@ -132,7 +132,7 @@ def _complete_stt_callback(text: str) -> None: ) as tts_client, ): # Build scheduler - last_response: Job | None = None + last_chat: asyncio.Task | None = None async def 
_timeout_callback() -> None: """ @@ -157,9 +157,9 @@ async def _stop_callback() -> None: """ Triggered when the audio buffer needs to be cleared. """ - # Close previous response now - if last_response: - await last_response.close(timeout=0) + # Cancel previous chat + if last_chat: + last_chat.cancel() # Stop TTS, clear the buffer and send a stop signal tts_client.stop_speaking_async() @@ -186,8 +186,8 @@ async def _commit_answer( Start the chat task and wait for its response if needed. Job is stored in `last_response` shared variable. """ # Start chat task - nonlocal last_response - last_response = await scheduler.spawn( + nonlocal last_chat + last_chat = asyncio.create_task( _continue_chat( call=call, client=automation_client, @@ -201,7 +201,7 @@ async def _commit_answer( # Wait for its response if wait: - await last_response.wait() + await last_chat async def _response_callback(_retry: bool = False) -> None: """ From e4a96c3e26e6ed4260b6a00956c34d6b51ec4ba2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 19:25:23 +0100 Subject: [PATCH 12/17] ux: Reorder voice cut actions to lower latency --- app/helpers/call_llm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/app/helpers/call_llm.py b/app/helpers/call_llm.py index 035b922..677cdfc 100644 --- a/app/helpers/call_llm.py +++ b/app/helpers/call_llm.py @@ -161,14 +161,10 @@ async def _stop_callback() -> None: if last_chat: last_chat.cancel() - # Stop TTS, clear the buffer and send a stop signal + # Stop TTS task tts_client.stop_speaking_async() - # Reset the recognition - stt_buffer.clear() - stt_complete_gate.clear() - - # Clear the audio buffer + # Clear the out buffer while not audio_out.empty(): audio_out.get_nowait() audio_out.task_done() @@ -176,6 +172,10 @@ async def _stop_callback() -> None: # Send a stop signal await audio_out.put(False) + # Reset TTS buffer + stt_buffer.clear() + stt_complete_gate.clear() + async def 
_commit_answer( wait: bool, tool_blacklist: set[str] | None = None, From c3ab4b56b4b163ade9ab1d0624ce070b9972ede8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 19:29:24 +0100 Subject: [PATCH 13/17] dev: Fix STT recognition log --- app/helpers/call_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/helpers/call_llm.py b/app/helpers/call_llm.py index 677cdfc..07c004e 100644 --- a/app/helpers/call_llm.py +++ b/app/helpers/call_llm.py @@ -112,9 +112,9 @@ def _complete_stt_callback(text: str) -> None: stt_buffer.append("") # Store the recognition stt_buffer[-1] = text + logger.debug("Complete recognition: %s", stt_buffer) # Add a new buffer for the next partial recognition stt_buffer.append("") - logger.debug("Complete recognition: %s", stt_buffer) # Open the recognition gate stt_complete_gate.set() From d702a57bcb9d0adcd0da89bc4b07bc19cd00ce8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 21:04:00 +0100 Subject: [PATCH 14/17] perf: Keep 25% of the removed noise Finally, this is not 100% perfect and this "muffles" the noise. It should require more fine-tuning. This reverts commit 37de648f58c5e28149f2588a1a83d7afc0835b19. 
--- app/helpers/call_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/helpers/call_utils.py b/app/helpers/call_utils.py index c6cd014..033d682 100644 --- a/app/helpers/call_utils.py +++ b/app/helpers/call_utils.py @@ -734,6 +734,8 @@ async def _process_one(self, input_pcm: bytes) -> None: clip_noise_stationary=False, # Noise is longer than the signal stationary=True, y_noise=self._bot_voice_buffer, + # Output quality + prop_decrease=0.75, # Reduce noise by 75% ) # Perform VAD test From 5fe3c42374444e4b2b6d3dc8a4bccd28cc08b13e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 21:38:57 +0100 Subject: [PATCH 15/17] chore: Remove dead code --- app/helpers/call_utils.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/app/helpers/call_utils.py b/app/helpers/call_utils.py index 033d682..7c0fece 100644 --- a/app/helpers/call_utils.py +++ b/app/helpers/call_utils.py @@ -211,13 +211,11 @@ async def handle_automation_tts( # noqa: PLR0913 return if store: - await scheduler.spawn( - _store_assistant_message( - call=call, - style=style, - text=text, - scheduler=scheduler, - ) + await _store_assistant_message( + call=call, + style=style, + text=text, + scheduler=scheduler, ) @@ -274,13 +272,11 @@ async def handle_realtime_tts( # noqa: PLR0913 ) if store: - await scheduler.spawn( - _store_assistant_message( - call=call, - style=style, - text=text, - scheduler=scheduler, - ) + await _store_assistant_message( + call=call, + style=style, + text=text, + scheduler=scheduler, ) From 1bd86bd3507d63ef611ea0a6032fbc7c9ce44275 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 21:39:31 +0100 Subject: [PATCH 16/17] perf: Optimize echo reduction latency --- app/helpers/call_utils.py | 53 ++++++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/app/helpers/call_utils.py b/app/helpers/call_utils.py index 
7c0fece..3587522 100644 --- a/app/helpers/call_utils.py +++ b/app/helpers/call_utils.py @@ -637,6 +637,7 @@ class EchoCancellationStream: _reference_queue: asyncio.Queue[bytes] = asyncio.Queue() _sample_rate: int _scheduler: Scheduler + _empty_packet: bytes def __init__( self, @@ -654,6 +655,7 @@ def __init__( self._chunk_size = int(self._sample_rate * self._packet_duration_ms / 1000) self._packet_size = self._chunk_size * 2 # Each sample is 2 bytes (PCM 16-bit) + self._empty_packet: bytes = b"\x00" * self._packet_size def _pcm_to_float(self, pcm: bytes) -> np.ndarray: """ @@ -705,11 +707,13 @@ async def _process_one(self, input_pcm: bytes) -> None: """ Process one audio chunk with echo cancellation. """ - # Use silence as the reference if none is available + # Push raw input if reference is empty if self._reference_queue.empty(): - reference_pcm = b"\x00" * self._packet_size + reference_pcm = self._empty_packet + + # Reference signal is available else: - reference_pcm = await self._reference_queue.get() + reference_pcm = self._reference_queue.get_nowait() self._reference_queue.task_done() # Convert PCM to float for processing @@ -719,6 +723,15 @@ async def _process_one(self, input_pcm: bytes) -> None: # Update the input buffer with the reference signal self._update_input_buffer(reference_signal) + # Reference signal is empty, skip noise reduction + if np.all(reference_signal == 0): + # Perform VAD test + input_speaking = await self._rms_speech_detection(input_signal) + + # Add processed PCM and metadata to the output queue + self._output_queue.put_nowait((input_pcm, input_speaking)) + return + # Apply noise reduction reduced_signal = reduce_noise( # Input signal @@ -741,7 +754,27 @@ async def _process_one(self, input_pcm: bytes) -> None: processed_pcm = self._float_to_pcm(reduced_signal) # Add processed PCM and metadata to the output queue - await self._output_queue.put((processed_pcm, input_speaking)) + self._output_queue.put_nowait((processed_pcm, 
input_speaking)) + + async def _ensure_stream(self, input_pcm: bytes) -> None: + """ + Ensure the audio stream is processed in real-time. + + If the processing is delayed, the original input will be returned. + """ + # Process the audio + try: + await asyncio.wait_for( + self._process_one(input_pcm), + timeout=self._packet_duration_ms + / 1000 + * 4, # Allow temporary medium latency + ) + + # If the processing is delayed, return the original input + except TimeoutError: + logger.warning("Echo processing timeout, returning input") + await self._output_queue.put((input_pcm, False)) async def process_stream(self) -> None: """ @@ -756,14 +789,7 @@ async def process_stream(self) -> None: self._input_queue.task_done() # Queue the processing - await scheduler.spawn( - asyncio.wait_for( - self._process_one(input_pcm), - timeout=self._packet_duration_ms - / 1000 - * 5, # Allow temporary high latency - ) - ) + await scheduler.spawn(self._ensure_stream(input_pcm)) async def push_input(self, audio_data: bytes) -> None: """ @@ -794,7 +820,6 @@ async def pull_audio(self) -> tuple[bytes, bool]: Returns a tuple with the echo-cancelled PCM audio and a boolean flag indicating if the user was speaking. 
""" - # return await self._output_queue.get() try: return await asyncio.wait_for( fut=self._output_queue.get(), @@ -803,4 +828,4 @@ async def pull_audio(self) -> tuple[bytes, bool]: * 1.5, # Allow temporary small latency ) except TimeoutError: - return b"\x00" * self._packet_size, False # Silence PCM chunk and no speech + return self._empty_packet, False From 7be0d540311a0f6bd62213c81415db1c324f826a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Wed, 11 Dec 2024 21:39:55 +0100 Subject: [PATCH 17/17] ux: Optimize echo reduction reliability --- app/helpers/call_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/helpers/call_utils.py b/app/helpers/call_utils.py index 3587522..3e3ebb2 100644 --- a/app/helpers/call_utils.py +++ b/app/helpers/call_utils.py @@ -643,7 +643,7 @@ def __init__( self, sample_rate: int, scheduler: Scheduler, - max_delay_ms: int = 200, + max_delay_ms: int = 300, packet_duration_ms: int = 20, ): self._packet_duration_ms = packet_duration_ms