diff --git a/main.py b/main.py index d1c92e4c..9274bab0 100644 --- a/main.py +++ b/main.py @@ -36,7 +36,7 @@ ) from models.claim import ClaimModel from openai import AsyncAzureOpenAI -from os import environ +from os import environ, urandom from uuid import UUID, uuid4 import asyncio import json @@ -58,7 +58,9 @@ azure_endpoint=CONFIG.openai.endpoint, azure_deployment=CONFIG.openai.gpt_deployment, ) -eventgrid_subscription_name = f"tmp-{uuid4()}" +event_subscription_prefix = urandom(4).hex() +event_call_subscription_name = f"{event_subscription_prefix}-call" +event_sms_subscription_name = f"{event_subscription_prefix}-sms" eventgrid_mgmt_client = EventGridManagementClient( credential=DefaultAzureCredential(), subscription_id=CONFIG.eventgrid.subscription_id, @@ -91,10 +93,14 @@ class Context(str, Enum): @asynccontextmanager async def lifespan(_: FastAPI): init_db() - task = asyncio.create_task(eventgrid_register()) # Background task + call_task = asyncio.create_task(eventgrid_call_register()) # Background task + sms_task = asyncio.create_task(eventgrid_sms_register()) # Background task yield - task.cancel() - eventgrid_unregister() # Foreground task + call_task.cancel() + sms_task.cancel() + eventgrid_unregister(event_call_subscription_name) # Foreground task + eventgrid_unregister(event_sms_subscription_name) # Foreground task + db.close() api = FastAPI( @@ -120,15 +126,15 @@ async def lifespan(_: FastAPI): ) -async def eventgrid_register() -> None: +async def eventgrid_call_register() -> None: def callback(future: ARMPolling): - _logger.info(f"Event Grid subscription created (status {future.status()})") + _logger.info(f"Event call subscription updated (status {future.status()})") - _logger.info(f"Creating Event Grid subscription {eventgrid_subscription_name}") + _logger.info(f"Creating event call subscription {event_call_subscription_name}") eventgrid_mgmt_client.system_topic_event_subscriptions.begin_create_or_update( resource_group_name=CONFIG.eventgrid.resource_group, system_topic_name=CONFIG.eventgrid.system_topic, - event_subscription_name=eventgrid_subscription_name, + event_subscription_name=event_call_subscription_name, event_subscription_info={ "properties": { "eventDeliverySchema": "EventGridSchema", @@ -159,12 +165,55 @@ def callback(future: ARMPolling): ).add_done_callback(callback) -def eventgrid_unregister() -> None: +async def eventgrid_sms_register() -> None: + def callback(future: ARMPolling): + _logger.info(f"Event SMS subscription updated (status {future.status()})") + + _logger.info(f"Creating event SMS subscription {event_sms_subscription_name}") + eventgrid_mgmt_client.system_topic_event_subscriptions.begin_create_or_update( + resource_group_name=CONFIG.eventgrid.resource_group, + system_topic_name=CONFIG.eventgrid.system_topic, + event_subscription_name=event_sms_subscription_name, + event_subscription_info={ + "properties": { + "eventDeliverySchema": "EventGridSchema", + "retryPolicy": { + "maxDeliveryAttempts": 8, + "eventTimeToLiveInMinutes": 30, # SMS are not real time + }, + "destination": { + "endpointType": "WebHook", + "properties": { + "endpointUrl": CALL_INBOUND_URL, + "maxEventsPerBatch": 1, + }, + }, + "filter": { + "enableAdvancedFilteringOnArrays": True, + "includedEventTypes": ["Microsoft.Communication.SmsReceived"], + "advancedFilters": [ + { + "key": "data.To", + "operatorType": "StringBeginsWith", + "values": [ + CONFIG.communication_service.phone_number[ + 1: + ], # Remove + sign from phone number + ], + } + ], + }, + }, + }, + ).add_done_callback(callback) + + +def eventgrid_unregister(subscription_name: str) -> None: _logger.info( - f"Deleting Event Grid subscription {eventgrid_subscription_name} (do not wait for completion)" + f"Deleting Event Grid subscription {subscription_name} (do not wait for completion)" ) eventgrid_mgmt_client.system_topic_event_subscriptions.begin_delete( - event_subscription_name=eventgrid_subscription_name, + event_subscription_name=subscription_name, resource_group_name=CONFIG.eventgrid.resource_group, system_topic_name=CONFIG.eventgrid.system_topic, ) @@ -210,7 +259,7 @@ async def call_inbound_post(request: Request): if event_type == SystemEventNames.EventGridSubscriptionValidationEventName: validation_code = event.data["validationCode"] - _logger.info(f"Validating Event Grid subscription ({validation_code})") + _logger.info(f"Validating event subscription ({validation_code})") return JSONResponse( content={"validationResponse": event.data["validationCode"]}, status_code=status.HTTP_200_OK, @@ -221,8 +270,8 @@ async def call_inbound_post(request: Request): phone_number = event.data["from"]["phoneNumber"]["value"] else: phone_number = event.data["from"]["rawId"] + _logger.debug(f"Incoming call: {phone_number}") - _logger.debug(f"Incoming call handler caller ID: {phone_number}") call_context = event.data["incomingCallContext"] answer_call_result = call_automation_client.answer_call( callback_url=callback_url(phone_number), @@ -233,6 +282,30 @@ async def call_inbound_post(request: Request): f"Answered call with {phone_number} ({answer_call_result.call_connection_id})" ) + elif event_type == SystemEventNames.AcsSmsReceivedEventName: + message = event.data["Message"] + phone_number = event.data["To"] + _logger.debug(f"Incoming SMS: {phone_number}") + + call = get_last_call_by_phone_number(phone_number) + if not call or not call.last_connection_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Call {phone_number} not found, cannot handle SMS", + ) + + # TODO: Test the client is still connected, do not know the error code for this + client = call_automation_client.get_call_connection( + call_connection_id=call.last_connection_id + ) + + call.messages.append( + CallMessageModel( + content=f"Cutomer send a SMS: {message}", persona=CallPersona.HUMAN + ) + ) + await call_intelligence(call, client) + @api.post( "/call/event/{call_id}", @@ -258,6 +331,9 @@ async def call_event_post(request: Request, call_id: UUID) -> None: ) event_type = event.type + # Update call last connection id, used to get the call from a SMS + call.last_connection_id = connection_id + _logger.debug(f"Call event received {event_type} for call {call}") _logger.debug(event.data) @@ -288,7 +364,7 @@ async def call_event_post(request: Request, call_id: UUID) -> None: client=client, file="acknowledge.mp3", ) - await intelligence(call, client) + await call_intelligence(call, client) elif event_type == "Microsoft.Communication.CallDisconnected": # Call hung up _logger.info(f"Call disconnected ({call.id})") @@ -311,7 +387,7 @@ async def call_event_post(request: Request, call_id: UUID) -> None: call.messages.append( CallMessageModel(content=speech_text, persona=CallPersona.HUMAN) ) - await intelligence(call, client) + await call_intelligence(call, client) elif ( event_type == "Microsoft.Communication.RecognizeFailed" @@ -404,7 +480,7 @@ async def call_event_post(request: Request, call_id: UUID) -> None: save_call(call) -async def intelligence(call: CallModel, client: CallConnectionClient) -> None: +async def call_intelligence(call: CallModel, client: CallConnectionClient) -> None: chat_res = await gpt_chat(call) _logger.info(f"Chat ({call.id}): {chat_res}") @@ -435,7 +511,7 @@ async def intelligence(call: CallModel, client: CallConnectionClient) -> None: store=False, text=chat_res.content, ) - await intelligence(call, client) + await call_intelligence(call, client) else: await handle_recognize_text( diff --git a/models/call.py b/models/call.py index 05ce95c2..52dcc662 100644 --- a/models/call.py +++ b/models/call.py @@ -3,7 +3,7 @@ from models.claim import ClaimModel from models.reminder import ReminderModel from pydantic import BaseModel, Field -from typing import List +from typing import List, Optional from uuid import UUID, uuid4 @@ -31,6 +31,7 @@ class CallModel(BaseModel): claim: ClaimModel = Field(default_factory=ClaimModel) created_at: datetime = Field(default_factory=datetime.utcnow) id: UUID = Field(default_factory=uuid4) + last_connection_id: Optional[str] = None messages: List[MessageModel] = [] phone_number: str recognition_retry: int = Field(default=0)