Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Dec 8, 2024
2 parents 07c44b4 + 277a8fb commit 0170fcf
Show file tree
Hide file tree
Showing 17 changed files with 245 additions and 199 deletions.
114 changes: 59 additions & 55 deletions app/helpers/call_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError

from app.helpers.cache import async_lru_cache
from app.helpers.config import CONFIG
from app.helpers.identity import token
from app.helpers.logging import logger
class ContextEnum(str, Enum):
    """
    Used to track the operation context of a call in Azure Communication Services.
    """

    GOODBYE = "goodbye"
    """Hang up"""
    IVR_LANG_SELECT = "ivr_lang_select"
    """IVR language selection"""
    START_REALTIME = "start_realtime"
    """Start realtime call"""
    TRANSFER_FAILED = "transfer_failed"
    """Transfer failed"""


def tts_sentence_split(
Expand Down Expand Up @@ -141,11 +146,11 @@ async def handle_media(
"""
with _detect_hangup():
assert call.voice_id, "Voice ID is required to control the call"
async with _use_call_client(client, call.voice_id) as call_client:
await call_client.play_media(
operation_context=_context_serializer({context}),
play_source=FileSource(url=sound_url),
)
call_client = await _use_call_client(client, call.voice_id)
await call_client.play_media(
operation_context=_context_serializer({context}),
play_source=FileSource(url=sound_url),
)


async def handle_automation_tts( # noqa: PLR0913
Expand All @@ -170,19 +175,19 @@ async def handle_automation_tts( # noqa: PLR0913
# Play each chunk
jobs: list[Job] = []
chunks = _chunk_for_tts(text)
async with _use_call_client(client, call.voice_id) as call_client:
jobs += [
await scheduler.spawn(
_automation_play_text(
call_client=call_client,
call=call,
context=context,
style=style,
text=chunk,
)
call_client = await _use_call_client(client, call.voice_id)
jobs += [
await scheduler.spawn(
_automation_play_text(
call_client=call_client,
call=call,
context=context,
style=style,
text=chunk,
)
for chunk in chunks
]
)
for chunk in chunks
]

# Wait for all jobs to finish and catch hangup
for job in jobs:
Expand Down Expand Up @@ -379,20 +384,20 @@ async def handle_recognize_ivr(
logger.info("Recognizing IVR: %s", text)
try:
assert call.voice_id, "Voice ID is required to control the call"
async with _use_call_client(client, call.voice_id) as call_client:
await call_client.start_recognizing_media(
choices=choices,
input_type=RecognizeInputType.CHOICES,
interrupt_prompt=True,
operation_context=_context_serializer({context}),
play_prompt=_ssml_from_text(
call=call,
style=MessageStyleEnum.NONE,
text=text,
),
speech_language=call.lang.short_code,
target_participant=PhoneNumberIdentifier(call.initiate.phone_number), # pyright: ignore
)
call_client = await _use_call_client(client, call.voice_id)
await call_client.start_recognizing_media(
choices=choices,
input_type=RecognizeInputType.CHOICES,
interrupt_prompt=True,
operation_context=_context_serializer({context}),
play_prompt=_ssml_from_text(
call=call,
style=MessageStyleEnum.NONE,
text=text,
),
speech_language=call.lang.short_code,
target_participant=PhoneNumberIdentifier(call.initiate.phone_number), # pyright: ignore
)
except ResourceNotFoundError:
logger.debug("Call hung up before recognizing")

Expand All @@ -414,8 +419,8 @@ async def handle_hangup(
_detect_hangup(),
):
assert call.voice_id, "Voice ID is required to control the call"
async with _use_call_client(client, call.voice_id) as call_client:
await call_client.hang_up(is_for_everyone=True)
call_client = await _use_call_client(client, call.voice_id)
await call_client.hang_up(is_for_everyone=True)


async def handle_transfer(
Expand All @@ -432,11 +437,11 @@ async def handle_transfer(
logger.info("Transferring call: %s", target)
with _detect_hangup():
assert call.voice_id, "Voice ID is required to control the call"
async with _use_call_client(client, call.voice_id) as call_client:
await call_client.transfer_call_to_participant(
operation_context=_context_serializer({context}),
target_participant=PhoneNumberIdentifier(target),
)
call_client = await _use_call_client(client, call.voice_id)
await call_client.transfer_call_to_participant(
operation_context=_context_serializer({context}),
target_participant=PhoneNumberIdentifier(target),
)


async def start_audio_streaming(
Expand All @@ -451,13 +456,13 @@ async def start_audio_streaming(
logger.info("Starting audio streaming")
with _detect_hangup():
assert call.voice_id, "Voice ID is required to control the call"
async with _use_call_client(client, call.voice_id) as call_client:
# TODO: Use the public API once the "await" have been fixed
# await call_client.start_media_streaming()
await call_client._call_media_client.start_media_streaming(
call_connection_id=call_client._call_connection_id,
start_media_streaming_request=StartMediaStreamingRequest(),
)
call_client = await _use_call_client(client, call.voice_id)
# TODO: Use the public API once the "await" have been fixed
# await call_client.start_media_streaming()
await call_client._call_media_client.start_media_streaming(
call_connection_id=call_client._call_connection_id,
start_media_streaming_request=StartMediaStreamingRequest(),
)


async def stop_audio_streaming(
Expand All @@ -472,8 +477,8 @@ async def stop_audio_streaming(
logger.info("Stopping audio streaming")
with _detect_hangup():
assert call.voice_id, "Voice ID is required to control the call"
async with _use_call_client(client, call.voice_id) as call_client:
await call_client.stop_media_streaming()
call_client = await _use_call_client(client, call.voice_id)
await call_client.stop_media_streaming()


def _context_serializer(contexts: set[ContextEnum | None] | None) -> str | None:
Expand Down Expand Up @@ -505,15 +510,14 @@ def _detect_hangup() -> Generator[None, None, None]:
raise e


@async_lru_cache()
async def _use_call_client(
    client: CallAutomationClient, voice_id: str
) -> CallConnectionClient:
    """
    Return the call connection client for a given call.

    Results are cached per (client, voice_id) pair by `async_lru_cache`, so
    repeated lookups reuse the same `CallConnectionClient` instance.

    The underlying connection is owned by the automation client, never close
    it from here.
    """
    return client.get_call_connection(call_connection_id=voice_id)


@asynccontextmanager
Expand Down
2 changes: 2 additions & 0 deletions app/helpers/config_models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

class ModeEnum(str, Enum):
    """Available cache backends."""

    MEMORY = "memory"  # in-process memory cache
    REDIS = "redis"  # Redis cache


class MemoryModel(BaseModel, frozen=True):
Expand Down
57 changes: 28 additions & 29 deletions app/helpers/config_models/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
from openai import AsyncAzureOpenAI, AsyncOpenAI
from pydantic import BaseModel, Field, SecretStr, ValidationInfo, field_validator

from app.helpers.cache import async_lru_cache
from app.helpers.identity import token


class ModeEnum(str, Enum):
    """Available LLM platforms."""

    AZURE_OPENAI = "azure_openai"  # use Azure OpenAI
    OPENAI = "openai"  # use OpenAI


class AbstractPlatformModel(BaseModel):
class AbstractPlatformModel(BaseModel, frozen=True):
_client_kwargs: dict[str, Any] = {
# Reliability
"max_retries": 0, # Retries are managed manually
Expand All @@ -33,41 +36,37 @@ async def instance(
pass


class AzureOpenaiPlatformModel(AbstractPlatformModel, frozen=True):
    """
    Azure OpenAI platform configuration.

    Frozen (hashable) so `instance()` can be memoized per configuration.
    """

    api_version: str = "2024-06-01"
    deployment: str

    @async_lru_cache()
    async def instance(self) -> tuple[AsyncAzureOpenAI, AbstractPlatformModel]:
        """
        Build the Azure OpenAI client for this configuration.

        Cached by `async_lru_cache`, so the client is created once per
        configuration and reused on subsequent calls.
        """
        return AsyncAzureOpenAI(
            **self._client_kwargs,
            # Deployment
            api_version=self.api_version,
            azure_deployment=self.deployment,
            azure_endpoint=self.endpoint,
            # Authentication
            azure_ad_token_provider=await token(
                "https://cognitiveservices.azure.com/.default"
            ),
        ), self


class OpenaiPlatformModel(AbstractPlatformModel, frozen=True):
    """
    OpenAI platform configuration.

    Frozen (hashable) so `instance()` can be memoized per configuration.
    """

    api_key: SecretStr

    @async_lru_cache()
    async def instance(self) -> tuple[AsyncOpenAI, AbstractPlatformModel]:
        """
        Build the OpenAI client for this configuration.

        Cached by `async_lru_cache`, so the client is created once per
        configuration and reused on subsequent calls.
        """
        return AsyncOpenAI(
            **self._client_kwargs,
            # API root URL
            base_url=self.endpoint,
            # Authentication
            api_key=self.api_key.get_secret_value(),
        ), self


class SelectedPlatformModel(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions app/helpers/config_models/sms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

class ModeEnum(str, Enum):
    """Available SMS providers."""

    COMMUNICATION_SERVICES = "communication_services"  # Azure Communication Services
    TWILIO = "twilio"  # Twilio


class CommunicationServiceModel(BaseModel, frozen=True):
Expand Down
2 changes: 1 addition & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ async def health_readiness_get() -> JSONResponse:
status_code = HTTPStatus.SERVICE_UNAVAILABLE
break
return JSONResponse(
content=readiness.model_dump_json(),
content=readiness.model_dump(mode="json"),
status_code=status_code,
)

Expand Down
24 changes: 16 additions & 8 deletions app/models/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,26 @@ def _validate_messages(cls, messages: list[MessageModel]) -> list[MessageModel]:
"""
Merge messages with the same persona.
"""
merged: list[MessageModel] = []
for new_message in messages:
if not (
merged
and (last := merged[-1]).persona == new_message.persona
and last.action == new_message.action
):

# Skip if there are no messages
if not messages:
return messages

# Iterate over the messages
merged: list[MessageModel] = [messages[0]]
for new_message in messages[1:]:
# If the last message is not from the same persona or action, keep it as is
last = merged[-1]
if last.persona != new_message.persona or last.action != new_message.action:
merged.append(new_message)
continue

# Merge the content and tool calls
last.content = (last.content + " " + new_message.content).strip()
last.tool_calls = list({*last.tool_calls, *new_message.tool_calls})
# Override the style
last.style = new_message.style
last.tool_calls += new_message.tool_calls

return merged


Expand Down
4 changes: 4 additions & 0 deletions app/models/claim.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@

class ClaimTypeEnum(str, Enum):
    """Value types a claim field can be validated as."""

    DATETIME = "datetime"  # parsed to a Python datetime object
    EMAIL = "email"  # validated as an email address string
    PHONE_NUMBER = "phone_number"  # validated as a phone number string
    TEXT = "text"  # validated as a string


class ClaimFieldModel(BaseModel):
Expand Down
Loading

0 comments on commit 0170fcf

Please sign in to comment.