Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Dec 9, 2024
1 parent d1edc75 commit 5de9a25
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 140 deletions.
18 changes: 9 additions & 9 deletions app/helpers/call_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from app.helpers.features import recognition_retry_max, recording_enabled
from app.helpers.llm_worker import completion_sync
from app.helpers.logging import logger
from app.helpers.monitoring import CallAttributes, span_attribute, tracer
from app.helpers.monitoring import SpanAttributes, span_attribute, tracer
from app.models.call import CallStateModel
from app.models.message import (
ActionEnum as MessageActionEnum,
Expand Down Expand Up @@ -236,7 +236,7 @@ async def on_automation_recognize_error(
logger.warning("Unknown context %s, no action taken", contexts)

# Enrich span
span_attribute(CallAttributes.CALL_CHANNEL, "ivr")
span_attribute(SpanAttributes.CALL_CHANNEL, "ivr")

# Retry IVR recognition
logger.info(
Expand Down Expand Up @@ -356,7 +356,7 @@ async def on_play_started(
logger.debug("Play started")

# Enrich span
span_attribute(CallAttributes.CALL_CHANNEL, "voice")
span_attribute(SpanAttributes.CALL_CHANNEL, "voice")

# Update last interaction
async with _db.call_transac(
Expand All @@ -382,7 +382,7 @@ async def on_automation_play_completed(
logger.debug("Play completed")

# Enrich span
span_attribute(CallAttributes.CALL_CHANNEL, "voice")
span_attribute(SpanAttributes.CALL_CHANNEL, "voice")

# Update last interaction
async with _db.call_transac(
Expand Down Expand Up @@ -431,7 +431,7 @@ async def on_play_error(error_code: int) -> None:
logger.debug("Play failed")

# Enrich span
span_attribute(CallAttributes.CALL_CHANNEL, "voice")
span_attribute(SpanAttributes.CALL_CHANNEL, "voice")

# Suppress known errors
# See: https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/communication-services/how-tos/call-automation/play-action.md
Expand Down Expand Up @@ -470,8 +470,8 @@ async def on_ivr_recognized(
logger.info("IVR recognized: %s", label)

# Enrich span
span_attribute(CallAttributes.CALL_CHANNEL, "ivr")
span_attribute(CallAttributes.CALL_MESSAGE, label)
span_attribute(SpanAttributes.CALL_CHANNEL, "ivr")
span_attribute(SpanAttributes.CALL_MESSAGE, label)

# Parse language from label
try:
Expand Down Expand Up @@ -554,8 +554,8 @@ async def on_sms_received(
logger.info("SMS received from %s: %s", call.initiate.phone_number, message)

# Enrich span
span_attribute(CallAttributes.CALL_CHANNEL, "sms")
span_attribute(CallAttributes.CALL_MESSAGE, message)
span_attribute(SpanAttributes.CALL_CHANNEL, "sms")
span_attribute(SpanAttributes.CALL_MESSAGE, message)

# Add the SMS to the call history
async with _db.call_transac(
Expand Down
19 changes: 13 additions & 6 deletions app/helpers/call_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
completion_stream,
)
from app.helpers.logging import logger
from app.helpers.monitoring import CallAttributes, span_attribute, tracer
from app.helpers.monitoring import SpanAttributes, span_attribute, tracer
from app.models.call import CallStateModel
from app.models.message import (
ActionEnum as MessageAction,
Expand Down Expand Up @@ -208,8 +208,8 @@ async def _out_answer( # noqa: PLR0915, PLR0913
Returns the updated call model.
"""
# Add span attributes
span_attribute(CallAttributes.CALL_CHANNEL, "voice")
span_attribute(CallAttributes.CALL_MESSAGE, call.messages[-1].content)
span_attribute(SpanAttributes.CALL_CHANNEL, "voice")
span_attribute(SpanAttributes.CALL_MESSAGE, call.messages[-1].content)

# Reset recognition retry counter
async with _db.call_transac(
Expand Down Expand Up @@ -528,9 +528,16 @@ async def _content_callback(buffer: str) -> None:
return True, True, call # Error, retry

# Execute tools
tool_tasks = [tool_call.execute_function(plugins) for tool_call in tool_calls]
await asyncio.gather(*tool_tasks)
call = plugins.call # Update call model if object reference changed
async with _db.call_transac(
call=call,
scheduler=scheduler,
):
await asyncio.gather(
*[plugins.execute_tool(tool_call) for tool_call in tool_calls]
)

# Update call model if object reference changed
call = plugins.call

# Store message
async with _db.call_transac(
Expand Down
87 changes: 87 additions & 0 deletions app/helpers/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

import asyncio
import inspect
import json
from collections.abc import Awaitable, Callable
from functools import cache
from inspect import getmembers, isfunction
from textwrap import dedent
from typing import Annotated, Any, ForwardRef, TypeVar
Expand All @@ -16,14 +18,17 @@
)
from azure.communication.callautomation.aio import CallAutomationClient
from jinja2 import Environment
from json_repair import repair_json
from openai.types.chat import ChatCompletionToolParam
from openai.types.shared_params.function_definition import FunctionDefinition
from pydantic import BaseModel, TypeAdapter
from pydantic._internal._typing_extra import eval_type_lenient
from pydantic.json_schema import JsonSchemaValue

from app.helpers.logging import logger
from app.helpers.monitoring import SpanAttributes, span_attribute, tracer
from app.models.call import CallStateModel
from app.models.message import ToolModel

T = TypeVar("T")
_jinja = Environment(
Expand Down Expand Up @@ -75,6 +80,88 @@ async def to_openai(self) -> list[ChatCompletionToolParam]:
]
)

@tracer.start_as_current_span("plugin_execute_tool")
async def execute_tool(self, tool: ToolModel) -> None:
functions = self._available_functions()
json_str = tool.function_arguments
name = tool.function_name

# Confirm the function name exists, this is a security measure to prevent arbitrary code execution, plus, Pydantic validator is not used on purpose to comply with older tools plugins
if name not in functions:
res = f"Invalid function names {name}, available are {functions}."
logger.warning(res)
# Update tool
tool.content = res
# Enrich span
span_attribute(SpanAttributes.TOOL_RESULT, tool.content)
return

# Try to fix JSON args to catch LLM hallucinations
# See: https://community.openai.com/t/gpt-4-1106-preview-messes-up-function-call-parameters-encoding/478500
args: dict[str, Any] | Any = repair_json(
json_str=json_str,
return_objects=True,
) # pyright: ignore

# Confirm the args are a dictionary
if not isinstance(args, dict):
logger.warning(
"Error decoding JSON args for function %s: %s...%s",
name,
json_str[:20],
json_str[-20:],
)
# Update tool
tool.content = (
f"Bad arguments, available are {functions}. Please try again."
)
# Enrich span
span_attribute(SpanAttributes.TOOL_RESULT, tool.content)
return

# Enrich span
span_attribute(SpanAttributes.TOOL_ARGS, json.dumps(args))
span_attribute(SpanAttributes.TOOL_NAME, name)

# Execute the function
try:
res = await getattr(self, name)(**args)
res_log = f"{res[:20]}...{res[-20:]}"
logger.info("Executed function %s (%s): %s", name, args, res_log)

# Catch wrong arguments
except TypeError as e:
logger.warning(
"Wrong arguments for function %s: %s. Error: %s",
name,
args,
e,
)
res = "Wrong arguments, please fix them and try again."
res_log = res

# Catch execution errors
except Exception as e:
logger.exception(
"Error executing function %s with args %s",
tool.function_name,
args,
)
res = f"Error: {e}."
res_log = res

# Update tool
tool.content = res
# Enrich span
span_attribute(SpanAttributes.TOOL_RESULT, tool.content)

@cache
def _available_functions(self) -> list[str]:
"""
List all available functions of the plugin, including the inherited ones.
"""
return [name for name, _ in getmembers(self.__class__, isfunction)]


async def _function_schema(
f: Callable[..., Any], **kwargs: Any
Expand Down
14 changes: 12 additions & 2 deletions app/helpers/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,24 @@ def span_attribute(key: str, value: AttributeValue) -> None:
span.set_attribute(key, value)


class CallAttributes:
class SpanAttributes:
"""
OpenTelemetry attributes for a call.
OpenTelemetry attributes.
These attributes are used to track a call in the logs and metrics.
"""

CALL_CHANNEL = "call.channel"
"""Message channel (e.g. sms, ivr, ...)."""
CALL_ID = "call.id"
"""Technical call identifier."""
CALL_MESSAGE = "call.message"
"""Message content as a string."""
CALL_PHONE_NUMBER = "call.phone_number"
"""Phone number of the caller."""
TOOL_ARGS = "tool.args"
"""Tool arguments being used."""
TOOL_NAME = "tool.name"
"""Tool name being used."""
TOOL_RESULT = "tool.result"
"""Tool result."""
30 changes: 15 additions & 15 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from app.helpers.config import CONFIG
from app.helpers.http import azure_transport
from app.helpers.logging import logger
from app.helpers.monitoring import CallAttributes, span_attribute, tracer
from app.helpers.monitoring import SpanAttributes, span_attribute, tracer
from app.helpers.pydantic_types.phone_numbers import PhoneNumber
from app.helpers.resources import resources_dir
from app.models.call import CallGetModel, CallInitiateModel, CallStateModel
Expand Down Expand Up @@ -386,8 +386,8 @@ async def call_post(request: Request) -> CallGetModel:
)

# Enrich span
span_attribute(CallAttributes.CALL_ID, str(call.call_id))
span_attribute(CallAttributes.CALL_PHONE_NUMBER, call.initiate.phone_number)
span_attribute(SpanAttributes.CALL_ID, str(call.call_id))
span_attribute(SpanAttributes.CALL_PHONE_NUMBER, call.initiate.phone_number)

# Init SDK
automation_client = await _use_automation_client()
Expand Down Expand Up @@ -441,8 +441,8 @@ async def call_event(
callback_url, wss_url, _call = await _communicationservices_urls(phone_number)

# Enrich span
span_attribute(CallAttributes.CALL_ID, str(_call.call_id))
span_attribute(CallAttributes.CALL_PHONE_NUMBER, _call.initiate.phone_number)
span_attribute(SpanAttributes.CALL_ID, str(_call.call_id))
span_attribute(SpanAttributes.CALL_PHONE_NUMBER, _call.initiate.phone_number)

# Execute business logic
await on_new_call(
Expand Down Expand Up @@ -479,7 +479,7 @@ async def sms_event(
phone_number: str = event.data["from"]

# Enrich span
span_attribute(CallAttributes.CALL_PHONE_NUMBER, phone_number)
span_attribute(SpanAttributes.CALL_PHONE_NUMBER, phone_number)

# Get call
call = await _db.call_search_one(phone_number)
Expand All @@ -488,7 +488,7 @@ async def sms_event(
return

# Enrich span
span_attribute(CallAttributes.CALL_ID, str(call.call_id))
span_attribute(SpanAttributes.CALL_ID, str(call.call_id))

# Execute business logic
async with Scheduler() as scheduler:
Expand Down Expand Up @@ -536,7 +536,7 @@ async def _communicationservices_validate_call_id(
secret: str,
) -> CallStateModel:
# Enrich span
span_attribute(CallAttributes.CALL_ID, str(call_id))
span_attribute(SpanAttributes.CALL_ID, str(call_id))

# Validate call
call = await _db.call_get(call_id)
Expand All @@ -554,7 +554,7 @@ async def _communicationservices_validate_call_id(
)

# Enrich span
span_attribute(CallAttributes.CALL_PHONE_NUMBER, call.initiate.phone_number)
span_attribute(SpanAttributes.CALL_PHONE_NUMBER, call.initiate.phone_number)

return call

Expand Down Expand Up @@ -850,8 +850,8 @@ async def training_event(
call = CallStateModel.model_validate_json(training.content)

# Enrich span
span_attribute(CallAttributes.CALL_ID, str(call.call_id))
span_attribute(CallAttributes.CALL_PHONE_NUMBER, call.initiate.phone_number)
span_attribute(SpanAttributes.CALL_ID, str(call.call_id))
span_attribute(SpanAttributes.CALL_PHONE_NUMBER, call.initiate.phone_number)

logger.debug("Training event received")

Expand All @@ -875,8 +875,8 @@ async def post_event(
return

# Enrich span
span_attribute(CallAttributes.CALL_ID, str(call.call_id))
span_attribute(CallAttributes.CALL_PHONE_NUMBER, call.initiate.phone_number)
span_attribute(SpanAttributes.CALL_ID, str(call.call_id))
span_attribute(SpanAttributes.CALL_PHONE_NUMBER, call.initiate.phone_number)

logger.debug("Post event received")

Expand Down Expand Up @@ -957,7 +957,7 @@ async def twilio_sms_post(
Returns a 200 OK if the SMS is properly formatted. Otherwise, returns a 400 Bad Request.
"""
# Enrich span
span_attribute(CallAttributes.CALL_PHONE_NUMBER, From)
span_attribute(SpanAttributes.CALL_PHONE_NUMBER, From)

# Get call
call = await _db.call_search_one(From)
Expand All @@ -966,7 +966,7 @@ async def twilio_sms_post(
logger.warning("Call for phone number %s not found", From)
else:
# Enrich span
span_attribute(CallAttributes.CALL_ID, str(call.call_id))
span_attribute(SpanAttributes.CALL_ID, str(call.call_id))

async with Scheduler() as scheduler:
# Execute business logic
Expand Down
Loading

0 comments on commit 5de9a25

Please sign in to comment.