Skip to content

Commit

Permalink
feat: Ability for customer to send SMS during a call
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Jan 12, 2024
1 parent 0211632 commit a4a3481
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 19 deletions.
112 changes: 94 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from models.claim import ClaimModel
from openai import AsyncAzureOpenAI
from os import environ
from os import environ, urandom
from uuid import UUID, uuid4
import asyncio
import json
Expand All @@ -58,7 +58,9 @@
azure_endpoint=CONFIG.openai.endpoint,
azure_deployment=CONFIG.openai.gpt_deployment,
)
eventgrid_subscription_name = f"tmp-{uuid4()}"
event_subscription_prefix = urandom(4).hex()
event_call_subscription_name = f"{event_subscription_prefix}-call"
event_sms_subscription_name = f"{event_subscription_prefix}-sms"
eventgrid_mgmt_client = EventGridManagementClient(
credential=DefaultAzureCredential(),
subscription_id=CONFIG.eventgrid.subscription_id,
Expand Down Expand Up @@ -91,10 +93,14 @@ class Context(str, Enum):
@asynccontextmanager
async def lifespan(_: FastAPI):
init_db()
task = asyncio.create_task(eventgrid_register()) # Background task
call_task = asyncio.create_task(eventgrid_call_register()) # Background task
sms_task = asyncio.create_task(eventgrid_sms_register()) # Background task
yield
task.cancel()
eventgrid_unregister() # Foreground task
call_task.cancel()
sms_task.cancel()
eventgrid_unregister(event_call_subscription_name) # Foreground task
eventgrid_unregister(event_sms_subscription_name) # Foreground task
db.close()


api = FastAPI(
Expand All @@ -120,15 +126,15 @@ async def lifespan(_: FastAPI):
)


async def eventgrid_register() -> None:
async def eventgrid_call_register() -> None:
def callback(future: ARMPolling):
_logger.info(f"Event Grid subscription created (status {future.status()})")
_logger.info(f"Event call subscription updated (status {future.status()})")

_logger.info(f"Creating Event Grid subscription {eventgrid_subscription_name}")
_logger.info(f"Creating event call subscription {event_call_subscription_name}")
eventgrid_mgmt_client.system_topic_event_subscriptions.begin_create_or_update(
resource_group_name=CONFIG.eventgrid.resource_group,
system_topic_name=CONFIG.eventgrid.system_topic,
event_subscription_name=eventgrid_subscription_name,
event_subscription_name=event_call_subscription_name,
event_subscription_info={
"properties": {
"eventDeliverySchema": "EventGridSchema",
Expand Down Expand Up @@ -159,12 +165,55 @@ def callback(future: ARMPolling):
).add_done_callback(callback)


def eventgrid_unregister() -> None:
async def eventgrid_sms_register() -> None:
def callback(future: ARMPolling):
_logger.info(f"Event SMS subscription updated (status {future.status()})")

_logger.info(f"Creating event SMS subscription {event_sms_subscription_name}")
eventgrid_mgmt_client.system_topic_event_subscriptions.begin_create_or_update(
resource_group_name=CONFIG.eventgrid.resource_group,
system_topic_name=CONFIG.eventgrid.system_topic,
event_subscription_name=event_sms_subscription_name,
event_subscription_info={
"properties": {
"eventDeliverySchema": "EventGridSchema",
"retryPolicy": {
"maxDeliveryAttempts": 8,
"eventTimeToLiveInMinutes": 30, # SMS are not real time
},
"destination": {
"endpointType": "WebHook",
"properties": {
"endpointUrl": CALL_INBOUND_URL,
"maxEventsPerBatch": 1,
},
},
"filter": {
"enableAdvancedFilteringOnArrays": True,
"includedEventTypes": ["Microsoft.Communication.SmsReceived"],
"advancedFilters": [
{
"key": "data.To",
"operatorType": "StringBeginsWith",
"values": [
CONFIG.communication_service.phone_number[
1:
], # Remove + sign from phone number
],
}
],
},
},
},
).add_done_callback(callback)


def eventgrid_unregister(subscription_name: str) -> None:
_logger.info(
f"Deleting Event Grid subscription {eventgrid_subscription_name} (do not wait for completion)"
f"Deleting Event Grid subscription {subscription_name} (do not wait for completion)"
)
eventgrid_mgmt_client.system_topic_event_subscriptions.begin_delete(
event_subscription_name=eventgrid_subscription_name,
event_subscription_name=subscription_name,
resource_group_name=CONFIG.eventgrid.resource_group,
system_topic_name=CONFIG.eventgrid.system_topic,
)
Expand Down Expand Up @@ -210,7 +259,7 @@ async def call_inbound_post(request: Request):

if event_type == SystemEventNames.EventGridSubscriptionValidationEventName:
validation_code = event.data["validationCode"]
_logger.info(f"Validating Event Grid subscription ({validation_code})")
_logger.info(f"Validating event subscription ({validation_code})")
return JSONResponse(
content={"validationResponse": event.data["validationCode"]},
status_code=status.HTTP_200_OK,
Expand All @@ -221,8 +270,8 @@ async def call_inbound_post(request: Request):
phone_number = event.data["from"]["phoneNumber"]["value"]
else:
phone_number = event.data["from"]["rawId"]
_logger.debug(f"Incoming call: {phone_number}")

_logger.debug(f"Incoming call handler caller ID: {phone_number}")
call_context = event.data["incomingCallContext"]
answer_call_result = call_automation_client.answer_call(
callback_url=callback_url(phone_number),
Expand All @@ -233,6 +282,30 @@ async def call_inbound_post(request: Request):
f"Answered call with {phone_number} ({answer_call_result.call_connection_id})"
)

elif event_type == SystemEventNames.AcsSmsReceivedEventName:
message = event.data["Message"]
phone_number = event.data["To"]
_logger.debug(f"Incoming SMS: {phone_number}")

call = get_last_call_by_phone_number(phone_number)
if not call or not call.last_connection_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Call {phone_number} not found, cannot handle SMS",
)

# TODO: Test the client is still connected, do not know the error code for this
client = call_automation_client.get_call_connection(
call_connection_id=call.last_connection_id
)

call.messages.append(
CallMessageModel(
content=f"Cutomer send a SMS: {message}", persona=CallPersona.HUMAN
)
)
await call_intelligence(call, client)


@api.post(
"/call/event/{call_id}",
Expand All @@ -258,6 +331,9 @@ async def call_event_post(request: Request, call_id: UUID) -> None:
)
event_type = event.type

# Update call last connection id, used to get the call from a SMS
call.last_connection_id = connection_id

_logger.debug(f"Call event received {event_type} for call {call}")
_logger.debug(event.data)

Expand Down Expand Up @@ -288,7 +364,7 @@ async def call_event_post(request: Request, call_id: UUID) -> None:
client=client,
file="acknowledge.mp3",
)
await intelligence(call, client)
await call_intelligence(call, client)

elif event_type == "Microsoft.Communication.CallDisconnected": # Call hung up
_logger.info(f"Call disconnected ({call.id})")
Expand All @@ -311,7 +387,7 @@ async def call_event_post(request: Request, call_id: UUID) -> None:
call.messages.append(
CallMessageModel(content=speech_text, persona=CallPersona.HUMAN)
)
await intelligence(call, client)
await call_intelligence(call, client)

elif (
event_type == "Microsoft.Communication.RecognizeFailed"
Expand Down Expand Up @@ -404,7 +480,7 @@ async def call_event_post(request: Request, call_id: UUID) -> None:
save_call(call)


async def intelligence(call: CallModel, client: CallConnectionClient) -> None:
async def call_intelligence(call: CallModel, client: CallConnectionClient) -> None:
chat_res = await gpt_chat(call)
_logger.info(f"Chat ({call.id}): {chat_res}")

Expand Down Expand Up @@ -435,7 +511,7 @@ async def intelligence(call: CallModel, client: CallConnectionClient) -> None:
store=False,
text=chat_res.content,
)
await intelligence(call, client)
await call_intelligence(call, client)

else:
await handle_recognize_text(
Expand Down
3 changes: 2 additions & 1 deletion models/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from models.claim import ClaimModel
from models.reminder import ReminderModel
from pydantic import BaseModel, Field
from typing import List
from typing import List, Optional
from uuid import UUID, uuid4


Expand Down Expand Up @@ -31,6 +31,7 @@ class CallModel(BaseModel):
claim: ClaimModel = Field(default_factory=ClaimModel)
created_at: datetime = Field(default_factory=datetime.utcnow)
id: UUID = Field(default_factory=uuid4)
last_connection_id: Optional[str] = None
messages: List[MessageModel] = []
phone_number: str
recognition_retry: int = Field(default=0)
Expand Down

0 comments on commit a4a3481

Please sign in to comment.