Skip to content

Commit

Permalink
refacto: AECStream class
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Dec 14, 2024
1 parent 9908e63 commit d9604bc
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 106 deletions.
84 changes: 21 additions & 63 deletions app/helpers/call_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from app.helpers.logging import logger
from app.helpers.monitoring import (
SpanAttributeEnum,
call_answer_latency,
call_cutoff_latency,
gauge_set,
tracer,
Expand All @@ -56,7 +55,7 @@

# TODO: Refacto, this function is too long
@tracer.start_as_current_span("call_load_llm_chat")
async def load_llm_chat( # noqa: PLR0913, PLR0915
async def load_llm_chat( # noqa: PLR0913
audio_in: asyncio.Queue[bytes],
audio_out: asyncio.Queue[bytes | bool],
audio_sample_rate: int,
Expand All @@ -67,48 +66,7 @@ async def load_llm_chat( # noqa: PLR0913, PLR0915
training_callback: Callable[[CallStateModel], Awaitable[None]],
) -> None:
# Init language recognition
aec = AECStream(
sample_rate=audio_sample_rate,
scheduler=scheduler,
)
audio_reference: asyncio.Queue[bytes] = asyncio.Queue()
answer_start: float | None = None

async def _send_in_to_aec() -> None:
"""
Send input audio to the echo cancellation.
"""
while True:
in_chunck = await audio_in.get()
audio_in.task_done()
await aec.push_input(in_chunck)

async def _send_out_to_aec() -> None:
"""
Forward the TTS to the echo cancellation and output.
"""
while True:
# Consume the audio
out_chunck = await audio_reference.get()
audio_reference.task_done()

# Report the answer latency and reset the timer
nonlocal answer_start
if answer_start:
# Enrich span
gauge_set(
metric=call_answer_latency,
value=time.monotonic() - answer_start,
)
answer_start = None

# Forward the audio
await asyncio.gather(
# First, send the audio to the output
audio_out.put(out_chunck),
# Then, send the audio to the echo cancellation
aec.push_reference(out_chunck),
)
audio_tts: asyncio.Queue[bytes] = asyncio.Queue()

async with (
SttClient(
Expand All @@ -118,8 +76,15 @@ async def _send_out_to_aec() -> None:
) as stt_client,
use_tts_client(
call=call,
out=audio_reference,
out=audio_tts,
) as tts_client,
AECStream(
in_raw_queue=audio_in,
in_reference_queue=audio_tts,
out_queue=audio_out,
sample_rate=audio_sample_rate,
scheduler=scheduler,
) as aec,
):
# Build scheduler
last_chat: asyncio.Task | None = None
Expand Down Expand Up @@ -205,9 +170,9 @@ async def _response_callback(_retry: bool = False) -> None:
If the recognition is empty, retry the recognition once. Otherwise, process the response.
"""
# Report the answer latency
nonlocal answer_start
answer_start = time.monotonic()
aec.answer_start()

# Pull the recognition
stt_text = await stt_client.pull_recognition()

# Ignore empty recognition
Expand Down Expand Up @@ -256,22 +221,15 @@ async def _response_callback(_retry: bool = False) -> None:
wait=False,
)

await asyncio.gather(
# Start the echo cancellation
aec.process_stream(),
# Apply the echo cancellation
_send_in_to_aec(),
_send_out_to_aec(),
# Detect VAD
_process_audio_for_vad(
call=call,
in_callback=aec.pull_audio,
out_callback=stt_client.push_audio,
response_callback=_response_callback,
scheduler=scheduler,
stop_callback=_stop_callback,
timeout_callback=_timeout_callback,
),
# Detect VAD
await _process_audio_for_vad(
call=call,
in_callback=aec.pull_audio,
out_callback=stt_client.push_audio,
response_callback=_response_callback,
scheduler=scheduler,
stop_callback=_stop_callback,
timeout_callback=_timeout_callback,
)


Expand Down
153 changes: 110 additions & 43 deletions app/helpers/call_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from contextlib import asynccontextmanager, contextmanager, suppress
from enum import Enum
from typing import Any

import numpy as np
from aiojobs import Job, Scheduler
Expand Down Expand Up @@ -50,6 +51,7 @@
from app.helpers.monitoring import (
call_aec_droped,
call_aec_missed,
call_answer_latency,
call_stt_complete_latency,
counter_add,
gauge_set,
Expand Down Expand Up @@ -744,28 +746,51 @@ async def pull_recognition(self) -> str:

class AECStream:
"""
Real-time audio stream with echo cancellation.
Real-time audio stream with echo cancellation (AEC).
Input and output formats are in PCM 16-bit, 16 kHz, 1 channel.
"""

_aec_in_queue: asyncio.Queue[bytes] = asyncio.Queue()
_aec_out_queue: asyncio.Queue[tuple[bytes, bool]] = asyncio.Queue()
_aec_reference_queue: asyncio.Queue[bytes] = asyncio.Queue()
_answer_start: float | None = None
_chunk_size: int
_input_queue: asyncio.Queue[bytes] = asyncio.Queue()
_output_queue: asyncio.Queue[tuple[bytes, bool]] = asyncio.Queue()
_empty_packet: bytes
_in_raw_queue: asyncio.Queue[bytes]
_in_reference_queue: asyncio.Queue[bytes] = asyncio.Queue()
_out_queue: asyncio.Queue[bytes]
_packet_duration_ms: int
_packet_size: int
_reference_queue: asyncio.Queue[bytes] = asyncio.Queue()
_run_task: asyncio.Future
_sample_rate: int
_scheduler: Scheduler
_empty_packet: bytes

def __init__(
def __init__( # noqa: PLR0913
self,
in_raw_queue: asyncio.Queue[bytes],
in_reference_queue: asyncio.Queue[bytes],
out_queue: asyncio.Queue[bytes | Any],
sample_rate: int,
scheduler: Scheduler,
max_delay_ms: int = 200,
packet_duration_ms: int = 20,
):
"""
Initialize the audio stream.
Parameters:
- `in_raw_queue`: Queue for the raw audio input (user speaking).
- `in_reference_queue`: Queue for the reference audio input (bot speaking).
- `max_delay_ms`: Maximum delay to consider between the raw and reference audio.
- `out_queue`: Queue for the processed audio output (echo-cancelled user speaking).
- `packet_duration_ms`: Duration of each audio packet in milliseconds.
- `sample_rate`: Audio sample rate in Hz.
- `scheduler`: Scheduler for the async tasks.
"""
self._in_raw_queue = in_raw_queue
self._in_reference_queue = in_reference_queue
self._out_queue = out_queue
self._packet_duration_ms = packet_duration_ms
self._sample_rate = sample_rate
self._scheduler = scheduler
Expand All @@ -777,6 +802,17 @@ def __init__(
self._packet_size = self._chunk_size * 2 # Each sample is 2 bytes (PCM 16-bit)
self._empty_packet: bytes = b"\x00" * self._packet_size

async def __aenter__(self):
self._run_task = asyncio.gather(
self._forward_in(),
self._forward_out(),
self._run(),
)
return self

async def __aexit__(self, *args, **kwargs):
self._run_task.cancel()

def _pcm_to_float(self, pcm: bytes) -> np.ndarray:
"""
Convert PCM 16-bit to float (-1.0 to 1.0).
Expand Down Expand Up @@ -825,16 +861,16 @@ async def _rms_speech_detection(self, voice: np.ndarray) -> bool:

async def _process_one(self, input_pcm: bytes) -> None:
"""
Process one audio chunk with echo cancellation.
Process one audio chunk.
"""
# Push raw input if reference is empty
if self._reference_queue.empty():
if self._aec_reference_queue.empty():
reference_pcm = self._empty_packet

# Reference signal is available
else:
reference_pcm = self._reference_queue.get_nowait()
self._reference_queue.task_done()
reference_pcm = await self._aec_reference_queue.get()
self._aec_reference_queue.task_done()

# Convert PCM to float for processing
input_signal = self._pcm_to_float(input_pcm)
Expand All @@ -849,7 +885,7 @@ async def _process_one(self, input_pcm: bytes) -> None:
input_speaking = await self._rms_speech_detection(input_signal)

# Add processed PCM and metadata to the output queue
self._output_queue.put_nowait((input_pcm, input_speaking))
await self._aec_out_queue.put((input_pcm, input_speaking))
return

# Apply noise reduction
Expand All @@ -874,11 +910,11 @@ async def _process_one(self, input_pcm: bytes) -> None:
processed_pcm = self._float_to_pcm(reduced_signal)

# Add processed PCM and metadata to the output queue
self._output_queue.put_nowait((processed_pcm, input_speaking))
await self._aec_out_queue.put((processed_pcm, input_speaking))

async def _ensure_stream(self, input_pcm: bytes) -> None:
async def _ensure_run_slo(self, input_pcm: bytes) -> None:
"""
Ensure the audio stream is processed in real-time.
Ensure the audio stream is processed within the SLO.
If the processing is delayed, the original input will be returned.
"""
Expand All @@ -898,9 +934,9 @@ async def _ensure_stream(self, input_pcm: bytes) -> None:
metric=call_aec_missed,
value=1,
)
await self._output_queue.put((input_pcm, False))
await self._aec_out_queue.put((input_pcm, False))

async def process_stream(self) -> None:
async def _run(self) -> None:
"""
Process the audio stream in real-time.
"""
Expand All @@ -909,34 +945,11 @@ async def process_stream(self) -> None:
) as scheduler:
while True:
# Fetch input audio
input_pcm = await self._input_queue.get()
self._input_queue.task_done()
input_pcm = await self._aec_in_queue.get()
self._aec_in_queue.task_done()

# Queue the processing
await scheduler.spawn(self._ensure_stream(input_pcm))

async def push_input(self, audio_data: bytes) -> None:
"""
Push PCM input audio into the input queue.
"""
if len(audio_data) != self._packet_size:
raise ValueError(
f"Expected packet size {self._packet_size} bytes, got {len(audio_data)} bytes."
)
await self._input_queue.put(audio_data)

async def push_reference(self, audio_data: bytes) -> None:
"""
Push PCM reference audio into the reference queue.
The reference audio is used for echo cancellation.
"""
# Extract packets and pad them if necessary
buffer_pointer = 0
while buffer_pointer < len(audio_data):
chunk = audio_data[: self._packet_size].ljust(self._packet_size, b"\x00")
await self._reference_queue.put(chunk)
buffer_pointer += self._packet_size
await scheduler.spawn(self._ensure_run_slo(input_pcm))

async def pull_audio(self) -> tuple[bytes, bool]:
"""
Expand All @@ -947,7 +960,7 @@ async def pull_audio(self) -> tuple[bytes, bool]:
# Fetch output audio
try:
return await asyncio.wait_for(
fut=self._output_queue.get(),
fut=self._aec_out_queue.get(),
timeout=self._packet_duration_ms
/ 1000
* 1.5, # Allow temporary small latency
Expand All @@ -962,3 +975,57 @@ async def pull_audio(self) -> tuple[bytes, bool]:
)
# Return empty packet
return self._empty_packet, False

async def _forward_in(self) -> None:
"""
Send input audio to the runner.
"""
while True:
# Consume input
audio_data = await self._in_raw_queue.get()
self._in_raw_queue.task_done()

# Validate packet size
if len(audio_data) != self._packet_size:
raise ValueError(
f"Expected packet size {self._packet_size} bytes, got {len(audio_data)} bytes."
)

# Push audio to the AEC queue
await self._aec_in_queue.put(audio_data)

async def _forward_out(self) -> None:
"""
Forward processed audio to the clean output queue.
"""
while True:
# Consume input
audio_data = await self._in_reference_queue.get()
self._in_reference_queue.task_done()

# Report the answer latency and reset the timer
if self._answer_start:
# Enrich span
gauge_set(
metric=call_answer_latency,
value=time.monotonic() - self._answer_start,
)
self._answer_start = None

# Send to clean output
await self._out_queue.put(audio_data)

# Send a copy as reference, extract packets and pad them if necessary
buffer_pointer = 0
while buffer_pointer < len(audio_data):
chunk = audio_data[: self._packet_size].ljust(
self._packet_size, b"\x00"
)
await self._aec_reference_queue.put(chunk)
buffer_pointer += self._packet_size

def answer_start(self):
"""
Notify the the user ended speaking.
"""
self._answer_start = time.monotonic()

0 comments on commit d9604bc

Please sign in to comment.