From 42286af8c1d564aa2de8f4bd2302c51c61a289ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Sat, 14 Dec 2024 16:09:31 +0100 Subject: [PATCH] perf: Remove defer feature update This is totally an anti pattern. It overloads in few seconds the scheduler queue and slow down the whole request by 5-10 seconds. --- app/helpers/call_events.py | 8 +-- app/helpers/call_llm.py | 17 +++--- app/helpers/call_utils.py | 5 +- app/helpers/features.py | 59 ++++++------------- app/helpers/llm_tools.py | 5 +- app/helpers/llm_worker.py | 8 +-- app/main.py | 99 ++++++++++++-------------------- app/persistence/cosmos_db.py | 11 ++-- app/persistence/istore.py | 3 - tests/store.py | 108 ++++++++++++----------------------- 10 files changed, 112 insertions(+), 211 deletions(-) diff --git a/app/helpers/call_events.py b/app/helpers/call_events.py index d082c207..c154039c 100644 --- a/app/helpers/call_events.py +++ b/app/helpers/call_events.py @@ -125,7 +125,6 @@ async def on_call_connected( _handle_recording( call=call, client=client, - scheduler=scheduler, server_call_id=server_call_id, ), # Second, start recording the call ) @@ -235,7 +234,7 @@ async def on_automation_recognize_error( logger.info( "Timeout, retrying language selection (%s/%s)", call.recognition_retry, - await recognition_retry_max(scheduler), + await recognition_retry_max(), ) await _handle_ivr_language( call=call, @@ -321,7 +320,7 @@ async def _pre_recognize_error( Returns True if the call should continue, False if it should end. """ # Voice retries are exhausted, end call - if call.recognition_retry >= await recognition_retry_max(scheduler): + if call.recognition_retry >= await recognition_retry_max(): logger.info("Timeout, ending call") return False @@ -793,7 +792,6 @@ async def _handle_ivr_language( async def _handle_recording( call: CallStateModel, client: CallAutomationClient, - scheduler: Scheduler, server_call_id: str, ) -> None: """ @@ -801,7 +799,7 @@ async def _handle_recording( Feature activation is checked before starting the recording. """ - if not await recording_enabled(scheduler): + if not await recording_enabled(): return assert CONFIG.communication_services.recording_container_url diff --git a/app/helpers/call_llm.py b/app/helpers/call_llm.py index 36e7c585..794fdb8a 100644 --- a/app/helpers/call_llm.py +++ b/app/helpers/call_llm.py @@ -227,7 +227,6 @@ async def _response_callback(_retry: bool = False) -> None: in_callback=aec.pull_audio, out_callback=stt_client.push_audio, response_callback=_response_callback, - scheduler=scheduler, stop_callback=_stop_callback, timeout_callback=_timeout_callback, ) @@ -307,10 +306,10 @@ def _loading_task() -> asyncio.Task: # Timeouts soft_timeout_triggered = False soft_timeout_task = asyncio.create_task( - asyncio.sleep(await answer_soft_timeout_sec(scheduler)) + asyncio.sleep(await answer_soft_timeout_sec()) ) hard_timeout_task = asyncio.create_task( - asyncio.sleep(await answer_hard_timeout_sec(scheduler)) + asyncio.sleep(await answer_hard_timeout_sec()) ) def _clear_tasks() -> None: @@ -340,7 +339,7 @@ def _clear_tasks() -> None: if hard_timeout_task.done(): logger.warning( "Hard timeout of %ss reached", - await answer_hard_timeout_sec(scheduler), + await answer_hard_timeout_sec(), ) # Clean up _clear_tasks() @@ -352,7 +351,7 @@ def _clear_tasks() -> None: if soft_timeout_task.done() and not soft_timeout_triggered: logger.warning( "Soft timeout of %ss reached", - await answer_soft_timeout_sec(scheduler), + await answer_soft_timeout_sec(), ) soft_timeout_triggered = True # Never store the error message in the call history, it has caused hallucinations in the LLM @@ -501,7 +500,6 @@ async def _content_callback(buffer: str) -> None: async for delta in completion_stream( max_tokens=160, # Lowest possible value for 90% of the cases, if not sufficient, retry will be triggered, 100 tokens ~= 75 words, 20 words ~= 1 sentence, 6 sentences ~= 160 tokens messages=call.messages, - scheduler=scheduler, system=system, tools=tools, ): @@ -619,7 +617,6 @@ async def _process_audio_for_vad( # noqa: PLR0913 in_callback: Callable[[], Awaitable[tuple[bytes, bool]]], out_callback: Callable[[bytes], None], response_callback: Callable[[], Awaitable[None]], - scheduler: Scheduler, stop_callback: Callable[[], Awaitable[None]], timeout_callback: Callable[[], Awaitable[None]], ) -> None: @@ -643,7 +640,7 @@ async def _wait_for_silence() -> None: """ # Wait before flushing nonlocal stop_task - timeout_ms = await vad_silence_timeout_ms(scheduler) + timeout_ms = await vad_silence_timeout_ms() await asyncio.sleep(timeout_ms / 1000) # Cancel the clear TTS task @@ -656,7 +653,7 @@ async def _wait_for_silence() -> None: await response_callback() # Wait for silence and trigger timeout - timeout_sec = await phone_silence_timeout_sec(scheduler) + timeout_sec = await phone_silence_timeout_sec() while True: # Stop this time if the call played a message timeout_start = datetime.now(UTC) @@ -685,7 +682,7 @@ async def _wait_for_stop() -> None: """ Stop the TTS if user speaks for too long. """ - timeout_ms = await vad_cutoff_timeout_ms(scheduler) + timeout_ms = await vad_cutoff_timeout_ms() # Wait before clearing the TTS queue await asyncio.sleep(timeout_ms / 1000) diff --git a/app/helpers/call_utils.py b/app/helpers/call_utils.py index f9bcd0fe..1eed49f3 100644 --- a/app/helpers/call_utils.py +++ b/app/helpers/call_utils.py @@ -728,8 +728,7 @@ async def pull_recognition(self) -> str: try: await asyncio.wait_for( self._stt_complete_gate.wait(), - timeout=await recognition_stt_complete_timeout_ms(self._scheduler) - / 1000, + timeout=await recognition_stt_complete_timeout_ms() / 1000, ) except TimeoutError: logger.debug("Complete recognition timeout, using partial recognition") @@ -856,7 +855,7 @@ async def _rms_speech_detection(self, voice: np.ndarray) -> bool: # Calculate Root Mean Square (RMS) rms = np.sqrt(np.mean(voice**2)) # Get VAD threshold, divide by 10 to more usability from user side, as RMS is in range 0-1 and a detection of 0.1 is a good maximum threshold - threshold = await vad_threshold(self._scheduler) / 10 + threshold = await vad_threshold() / 10 return rms >= threshold async def _process_one(self, input_pcm: bytes) -> None: diff --git a/app/helpers/features.py b/app/helpers/features.py index 84658683..12339645 100644 --- a/app/helpers/features.py +++ b/app/helpers/features.py @@ -1,6 +1,5 @@ from typing import TypeVar, cast -from aiojobs import Scheduler from azure.appconfiguration.aio import AzureAppConfigurationClient from azure.core.exceptions import ResourceNotFoundError @@ -16,55 +15,51 @@ T = TypeVar("T", bool, int, float, str) -async def answer_hard_timeout_sec(scheduler: Scheduler) -> int: +async def answer_hard_timeout_sec() -> int: """ The hard timeout for the bot answer in secs. """ return await _default( default=60, key="answer_hard_timeout_sec", - scheduler=scheduler, type_res=int, ) -async def answer_soft_timeout_sec(scheduler: Scheduler) -> int: +async def answer_soft_timeout_sec() -> int: """ The soft timeout for the bot answer in secs. """ return await _default( default=30, key="answer_soft_timeout_sec", - scheduler=scheduler, type_res=int, ) -async def callback_timeout_hour(scheduler: Scheduler) -> int: +async def callback_timeout_hour() -> int: """ The timeout for a callback in hours. Set 0 to disable. """ return await _default( default=24, key="callback_timeout_hour", - scheduler=scheduler, type_res=int, ) -async def phone_silence_timeout_sec(scheduler: Scheduler) -> int: +async def phone_silence_timeout_sec() -> int: """ Amount of silence in secs to trigger a warning message from the assistant. """ return await _default( default=20, key="phone_silence_timeout_sec", - scheduler=scheduler, type_res=int, ) -async def vad_threshold(scheduler: Scheduler) -> float: +async def vad_threshold() -> float: """ The threshold for voice activity detection. Between 0.1 and 1. """ @@ -73,60 +68,55 @@ async def vad_threshold(scheduler: Scheduler) -> float: key="vad_threshold", max_incl=1, min_incl=0.1, - scheduler=scheduler, type_res=float, ) -async def vad_silence_timeout_ms(scheduler: Scheduler) -> int: +async def vad_silence_timeout_ms() -> int: """ Silence to trigger voice activity detection in milliseconds. """ return await _default( default=500, key="vad_silence_timeout_ms", - scheduler=scheduler, type_res=int, ) -async def vad_cutoff_timeout_ms(scheduler: Scheduler) -> int: +async def vad_cutoff_timeout_ms() -> int: """ The cutoff timeout for voice activity detection in milliseconds. """ return await _default( default=250, key="vad_cutoff_timeout_ms", - scheduler=scheduler, type_res=int, ) -async def recording_enabled(scheduler: Scheduler) -> bool: +async def recording_enabled() -> bool: """ Whether call recording is enabled. """ return await _default( default=False, key="recording_enabled", - scheduler=scheduler, type_res=bool, ) -async def slow_llm_for_chat(scheduler: Scheduler) -> bool: +async def slow_llm_for_chat() -> bool: """ Whether to use the slow LLM for chat. """ return await _default( default=True, key="slow_llm_for_chat", - scheduler=scheduler, type_res=bool, ) -async def recognition_retry_max(scheduler: Scheduler) -> int: +async def recognition_retry_max() -> int: """ The maximum number of retries for voice recognition. Minimum of 1. """ @@ -134,27 +124,24 @@ async def recognition_retry_max(scheduler: Scheduler) -> int: default=3, key="recognition_retry_max", min_incl=1, - scheduler=scheduler, type_res=int, ) -async def recognition_stt_complete_timeout_ms(scheduler: Scheduler) -> int: +async def recognition_stt_complete_timeout_ms() -> int: """ The timeout for STT completion in milliseconds. """ return await _default( default=100, key="recognition_stt_complete_timeout_ms", - scheduler=scheduler, type_res=int, ) -async def _default( # noqa: PLR0913 +async def _default( default: T, key: str, - scheduler: Scheduler, type_res: type[T], max_incl: T | None = None, min_incl: T | None = None, @@ -165,7 +152,6 @@ async def _default( # noqa: PLR0913 # Get the setting res = await _get( key=key, - scheduler=scheduler, type_res=type_res, ) if res: @@ -207,11 +193,7 @@ def _validate( return res -async def _get( - key: str, - scheduler: Scheduler, - type_res: type[T], -) -> T | None: +async def _get(key: str, type_res: type[T]) -> T | None: """ Get a setting from the App Configuration service. """ @@ -224,15 +206,6 @@ async def _get( value=cached.decode(), ) - # Defer the update - await scheduler.spawn(_refresh(cache_key, key)) - return - - -async def _refresh( - cache_key: str, - key: str, -) -> T | None: # Try live try: async with await _use_client() as client: @@ -253,6 +226,12 @@ async def _refresh( value=res, ) + # Return value + return _parse( + type_res=type_res, + value=res, + ) + @async_lru_cache() async def _use_client() -> AzureAppConfigurationClient: diff --git a/app/helpers/llm_tools.py b/app/helpers/llm_tools.py index 79d4f758..8a1dd0a8 100644 --- a/app/helpers/llm_tools.py +++ b/app/helpers/llm_tools.py @@ -88,7 +88,7 @@ async def new_claim( # Store the last message and use it at first message of the new claim self.call = await _db.call_create( - call=CallStateModel( + CallStateModel( initiate=self.call.initiate.model_copy(), voice_id=self.call.voice_id, messages=[ @@ -101,8 +101,7 @@ async def new_claim( # Reinsert the last message, using more will add the user message asking to create the new claim and the assistant can loop on it sometimes self.call.messages[-1], ], - ), - scheduler=self.scheduler, + ) ) return "Claim, reminders and messages reset" diff --git a/app/helpers/llm_worker.py b/app/helpers/llm_worker.py index 44339817..3bfe22e9 100644 --- a/app/helpers/llm_worker.py +++ b/app/helpers/llm_worker.py @@ -5,7 +5,6 @@ from typing import TypeVar import tiktoken -from aiojobs import Scheduler from json_repair import repair_json from openai import ( APIConnectionError, @@ -89,7 +88,6 @@ class MaximumTokensReachedError(Exception): async def completion_stream( max_tokens: int, messages: list[MessageModel], - scheduler: Scheduler, system: list[ChatCompletionSystemMessageParam], tools: list[ChatCompletionToolParam] | None = None, ) -> AsyncGenerator[ChoiceDelta, None]: @@ -112,9 +110,7 @@ async def completion_stream( async for attempt in retryed: with attempt: async for chunck in _completion_stream_worker( - is_fast=not await slow_llm_for_chat( - scheduler - ), # Let configuration decide + is_fast=not await slow_llm_for_chat(), # Let configuration decide max_tokens=max_tokens, messages=messages, system=system, @@ -134,7 +130,7 @@ async def completion_stream( async for attempt in retryed: with attempt: async for chunck in _completion_stream_worker( - is_fast=await slow_llm_for_chat(scheduler), # Let configuration decide + is_fast=await slow_llm_for_chat(), # Let configuration decide max_tokens=max_tokens, messages=messages, system=system, diff --git a/app/main.py b/app/main.py index fcaa8e67..513fb50b 100644 --- a/app/main.py +++ b/app/main.py @@ -292,11 +292,7 @@ async def report_single_get(call_id: UUID) -> HTMLResponse: Returns a single call with a web interface. """ - async with get_scheduler() as scheduler: - call = await _db.call_get( - call_id=call_id, - scheduler=scheduler, - ) + call = await _db.call_get(call_id) if not call: return HTMLResponse( content=f"Call {call_id} not found", @@ -356,29 +352,24 @@ async def call_get(call_id_or_phone_number: str) -> CallGetModel: Returns a single call object `CallGetModel`, in JSON format. """ - async with get_scheduler() as scheduler: - # First, try to get by call ID - with suppress(ValueError): - call_id = UUID(call_id_or_phone_number) - call = await _db.call_get( - call_id=call_id, - scheduler=scheduler, - ) - if call: - return TypeAdapter(CallGetModel).dump_python(call) - - # Second, try to get by phone number - phone_number = PhoneNumber(call_id_or_phone_number) - call = await _db.call_search_one( - callback_timeout=False, - phone_number=phone_number, - scheduler=scheduler, + # First, try to get by call ID + with suppress(ValueError): + call_id = UUID(call_id_or_phone_number) + call = await _db.call_get(call_id) + if call: + return TypeAdapter(CallGetModel).dump_python(call) + + # Second, try to get by phone number + phone_number = PhoneNumber(call_id_or_phone_number) + call = await _db.call_search_one( + callback_timeout=False, + phone_number=phone_number, + ) + if not call: + raise HTTPException( + detail=f"Call {call_id_or_phone_number} not found", + status_code=HTTPStatus.NOT_FOUND, ) - if not call: - raise HTTPException( - detail=f"Call {call_id_or_phone_number} not found", - status_code=HTTPStatus.NOT_FOUND, - ) return TypeAdapter(CallGetModel).dump_python(call) @@ -508,7 +499,6 @@ async def sms_event( call = await _db.call_search_one( callback_timeout=False, phone_number=phone_number, - scheduler=scheduler, ) if not call: logger.warning("Call for phone number %s not found", phone_number) @@ -565,17 +555,13 @@ async def _communicationservices_validate_call_id( # Enrich span SpanAttributeEnum.CALL_ID.attribute(str(call_id)) - async with get_scheduler() as scheduler: - # Validate call - call = await _db.call_get( - call_id=call_id, - scheduler=scheduler, + # Validate call + call = await _db.call_get(call_id) + if not call: + raise HTTPException( + detail=f"Call {call_id} not found", + status_code=HTTPStatus.NOT_FOUND, ) - if not call: - raise HTTPException( - detail=f"Call {call_id} not found", - status_code=HTTPStatus.NOT_FOUND, - ) # Validate secret if call.callback_secret != secret: @@ -927,10 +913,7 @@ async def post_event( """ async with get_scheduler() as scheduler: # Validate call - call = await _db.call_get( - call_id=UUID(post.content), - scheduler=scheduler, - ) + call = await _db.call_get(UUID(post.content)) if not call: logger.warning("Call %s not found", post.content) return @@ -971,25 +954,20 @@ async def _communicationservices_urls( Returnes a tuple of the callback URL, the WebSocket URL, and the call object. """ - async with get_scheduler() as scheduler: - # Get call - call = await _db.call_search_one( - phone_number=phone_number, - scheduler=scheduler, - ) - - # Create new call if initiate is different - if not call or (initiate and call.initiate != initiate): - call = await _db.call_create( - call=CallStateModel( - initiate=initiate - or CallInitiateModel( - **CONFIG.conversation.initiate.model_dump(), - phone_number=phone_number, - ) - ), - scheduler=scheduler, + # Get call + call = await _db.call_search_one(phone_number) + + # Create new call if initiate is different + if not call or (initiate and call.initiate != initiate): + call = await _db.call_create( + CallStateModel( + initiate=initiate + or CallInitiateModel( + **CONFIG.conversation.initiate.model_dump(), + phone_number=phone_number, + ) ) + ) # Format URLs wss_url = _COMMUNICATIONSERVICES_WSS_TPL.format( @@ -1029,7 +1007,6 @@ async def twilio_sms_post( call = await _db.call_search_one( callback_timeout=False, phone_number=From, - scheduler=scheduler, ) # Call not found diff --git a/app/persistence/cosmos_db.py b/app/persistence/cosmos_db.py index 55e805d4..7b1de221 100644 --- a/app/persistence/cosmos_db.py +++ b/app/persistence/cosmos_db.py @@ -84,7 +84,6 @@ async def _item_exists(self, test_id: str, partition_key: str) -> bool: async def call_get( self, call_id: UUID, - scheduler: Scheduler, ) -> CallStateModel | None: logger.debug("Loading call %s", call_id) @@ -118,7 +117,7 @@ async def call_get( if call: await self._cache.set( key=cache_key, - ttl_sec=max(await callback_timeout_hour(scheduler), 1) + ttl_sec=max(await callback_timeout_hour(), 1) * 60 * 60, # Ensure at least 1 hour value=call.model_dump_json(), @@ -191,7 +190,7 @@ async def _exec() -> None: cache_key_id = self._cache_key_call_id(call.call_id) await self._cache.set( key=cache_key_id, - ttl_sec=max(await callback_timeout_hour(scheduler), 1) + ttl_sec=max(await callback_timeout_hour(), 1) * 60 * 60, # Ensure at least 1 hour value=call.model_dump_json(), @@ -204,7 +203,6 @@ async def _exec() -> None: async def call_create( self, call: CallStateModel, - scheduler: Scheduler, ) -> CallStateModel: logger.debug("Creating new call %s", call.call_id) @@ -225,7 +223,7 @@ async def call_create( cache_key = self._cache_key_call_id(call.call_id) await self._cache.set( key=cache_key, - ttl_sec=max(await callback_timeout_hour(scheduler), 1) + ttl_sec=max(await callback_timeout_hour(), 1) * 60 * 60, # Ensure at least 1 hour value=call.model_dump_json(), @@ -242,12 +240,11 @@ async def call_create( async def call_search_one( self, phone_number: str, - scheduler: Scheduler, callback_timeout: bool = True, ) -> CallStateModel | None: logger.debug("Loading last call for %s", phone_number) - timeout = await callback_timeout_hour(scheduler) + timeout = await callback_timeout_hour() if timeout < 1 and callback_timeout: logger.debug("Callback timeout if off, skipping search") return None diff --git a/app/persistence/istore.py b/app/persistence/istore.py index 3cfe7927..becd3fa7 100644 --- a/app/persistence/istore.py +++ b/app/persistence/istore.py @@ -26,7 +26,6 @@ async def readiness(self) -> ReadinessEnum: async def call_get( self, call_id: UUID, - scheduler: Scheduler, ) -> CallStateModel | None: pass @@ -44,7 +43,6 @@ def call_transac( async def call_create( self, call: CallStateModel, - scheduler: Scheduler, ) -> CallStateModel: pass @@ -53,7 +51,6 @@ async def call_create( async def call_search_one( self, phone_number: str, - scheduler: Scheduler, callback_timeout: bool = True, ) -> CallStateModel | None: pass diff --git a/tests/store.py b/tests/store.py index 6146cefb..c74d844f 100644 --- a/tests/store.py +++ b/tests/store.py @@ -22,67 +22,40 @@ async def test_acid(call: CallStateModel) -> None: """ db = CONFIG.database.instance() - async with Scheduler() as scheduler: - # Check not exists - assume( - not await db.call_get( - call_id=call.call_id, - scheduler=scheduler, - ) - ) - assume( - await db.call_search_one( - phone_number=call.initiate.phone_number, - scheduler=scheduler, - ) - != call - ) - assume( - call - not in ( - ( - await db.call_search_all( - phone_number=call.initiate.phone_number, count=1 - ) - )[0] - or [] - ) - ) - - # Insert test call - await db.call_create( - call=call, - scheduler=scheduler, - ) - - # Check point read - assume( - await db.call_get( - call_id=call.call_id, - scheduler=scheduler, - ) - == call - ) - # Check search one - assume( - await db.call_search_one( - phone_number=call.initiate.phone_number, - scheduler=scheduler, - ) - == call + # Check not exists + assume(not await db.call_get(call.call_id)) + assume(await db.call_search_one(call.initiate.phone_number) != call) + assume( + call + not in ( + ( + await db.call_search_all( + phone_number=call.initiate.phone_number, count=1 + ) + )[0] + or [] ) - # Check search all - assume( - call - in ( - ( - await db.call_search_all( - phone_number=call.initiate.phone_number, count=1 - ) - )[0] - or [] - ) + ) + + # Insert test call + await db.call_create(call) + + # Check point read + assume(await db.call_get(call.call_id) == call) + # Check search one + assume(await db.call_search_one(call.initiate.phone_number) == call) + # Check search all + assume( + call + in ( + ( + await db.call_search_all( + phone_number=call.initiate.phone_number, count=1 + ) + )[0] + or [] ) + ) @pytest.mark.asyncio(loop_scope="session") @@ -98,18 +71,10 @@ async def test_transaction( async with Scheduler() as scheduler: # Check not exists - assume( - not await db.call_get( - call_id=call.call_id, - scheduler=scheduler, - ) - ) + assume(not await db.call_get(call.call_id)) # Insert call - await db.call_create( - call=call, - scheduler=scheduler, - ) + await db.call_create(call) # Check first change async with db.call_transac( @@ -137,8 +102,5 @@ async def test_transaction( assume(call.voice_id == random_text) # Check point read - new_call = await db.call_get( - call_id=call.call_id, - scheduler=scheduler, - ) + new_call = await db.call_get(call.call_id) assume(new_call and new_call.voice_id == random_text and new_call.in_progress)