breaking: Use AI Factory SDK and LLM tracing

clemlesne committed Dec 17, 2024
1 parent 7698308 commit 310dfb8
Showing 13 changed files with 329 additions and 520 deletions.
README.md (31 changes: 0 additions & 31 deletions)

````diff
@@ -484,37 +484,6 @@ Conversation options are represented as features. They can be configured from Ap
 | `vad_silence_timeout_ms` | Silence to trigger voice activity detection in milliseconds. | `int` | 500 |
 | `vad_threshold` | The threshold for voice activity detection. Between 0.1 and 1. | `float` | 0.5 |
 
-### Use an OpenAI compatible model for the LLM
-
-To use a model compatible with the OpenAI completion API, you need to create an account and get the following information:
-
-- API key
-- Context window size
-- Endpoint URL
-- Model name
-- Streaming capability
-
-Then, add the following in the `config.yaml` file:
-
-```yaml
-# config.yaml
-llm:
-  fast:
-    mode: openai
-    openai:
-      context: 128000
-      endpoint: https://api.openai.com
-      model: gpt-4o-mini
-      streaming: true
-  slow:
-    mode: openai
-    openai:
-      context: 128000
-      endpoint: https://api.openai.com
-      model: gpt-4o
-      streaming: true
-```
-
 ### Use Twilio for SMS
 
 To use Twilio for SMS, you need to create an account and get the following information:
````
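Since the commit deletes this README example without adding a replacement, here is a minimal sketch of what `config.yaml` might look like after the migration, assuming the fields of the new `DeploymentModel` in `app/helpers/config_models/llm.py` below; the YAML shape and the endpoint value are assumptions, not taken from the diff:

```yaml
# config.yaml (hypothetical shape after this commit)
llm:
  fast:
    context: 128000
    endpoint: https://<your-resource>.openai.azure.com # Assumed endpoint format
    model: gpt-4o-mini
    streaming: true
  slow:
    context: 128000
    endpoint: https://<your-resource>.openai.azure.com # Assumed endpoint format
    model: gpt-4o
    streaming: true
```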
app/helpers/call_llm.py (15 changes: 5 additions & 10 deletions)

```diff
@@ -9,7 +9,6 @@
     SpeechSynthesizer,
 )
 from azure.communication.callautomation.aio import CallAutomationClient
-from openai import APIError
 
 from app.helpers.call_utils import (
     AECStream,
@@ -427,7 +426,7 @@ def _clear_tasks() -> None:
 
 # TODO: Refacto, this function is too long
 @tracer.start_as_current_span("call_generate_chat_completion")
-async def _generate_chat_completion(  # noqa: PLR0913, PLR0911, PLR0912, PLR0915
+async def _generate_chat_completion(  # noqa: PLR0913, PLR0912, PLR0915
     call: CallStateModel,
     client: CallAutomationClient,
     post_callback: Callable[[CallStateModel], Awaitable[None]],
@@ -495,7 +494,7 @@ async def _content_callback(buffer: str) -> None:
     # Execute LLM inference
     maximum_tokens_reached = False
     content_buffer_pointer = 0
-    tool_calls_buffer: dict[int, MessageToolModel] = {}
+    tool_calls_buffer: dict[str, MessageToolModel] = {}
     try:
         async for delta in completion_stream(
             max_tokens=160,  # Lowest possible value for 90% of the cases, if not sufficient, retry will be triggered, 100 tokens ~= 75 words, 20 words ~= 1 sentence, 6 sentences ~= 160 tokens
@@ -505,10 +504,10 @@
         ):
             if not delta.content:
                 for piece in delta.tool_calls or []:
-                    tool_calls_buffer[piece.index] = tool_calls_buffer.get(
-                        piece.index, MessageToolModel()
+                    tool_calls_buffer[piece.id] = tool_calls_buffer.get(
+                        piece.id, MessageToolModel()
                     )
-                    tool_calls_buffer[piece.index] += piece
+                    tool_calls_buffer[piece.id] += piece
             else:
                 # Store whole content
                 content_full += delta.content
@@ -522,10 +521,6 @@
     except MaximumTokensReachedError:
         logger.warning("Maximum tokens reached for this completion, retry asked")
         maximum_tokens_reached = True
-    # Retry on API error
-    except APIError as e:
-        logger.warning("OpenAI API call error: %s", e)
-        return True, True, call  # Error, retry
     # Last user message is trash, remove it
     except SafetyCheckError as e:
         logger.warning("Safety Check error: %s", e)
```
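A note on the `piece.index` to `piece.id` change above: with the Azure AI Inference SDK, streamed tool-call fragments are correlated by their `id` string rather than by an integer index, so the accumulation buffer is now keyed by `id`. A minimal sketch of that aggregation pattern, using a hypothetical fragment type in place of the repo's `MessageToolModel`:

```python
from dataclasses import dataclass


@dataclass
class ToolCallDelta:
    """Hypothetical stand-in for one streamed tool-call fragment."""

    id: str
    name: str = ""
    arguments: str = ""  # JSON arguments arrive as string fragments


def merge_tool_calls(deltas: list[ToolCallDelta]) -> dict[str, ToolCallDelta]:
    """Accumulate streamed fragments, keyed by tool-call id."""
    buffer: dict[str, ToolCallDelta] = {}
    for piece in deltas:
        entry = buffer.setdefault(piece.id, ToolCallDelta(id=piece.id))
        entry.name = entry.name or piece.name
        entry.arguments += piece.arguments  # concatenate fragments in order
    return buffer
```

The real code folds each fragment into a `MessageToolModel` with `+=`; the sketch concatenates the fields by hand.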
app/helpers/config_models/llm.py (119 changes: 21 additions & 98 deletions)

```diff
@@ -1,117 +1,40 @@
-from abc import abstractmethod
-from enum import Enum
-from typing import Any
-
-from openai import AsyncAzureOpenAI, AsyncOpenAI
-from pydantic import BaseModel, Field, SecretStr, ValidationInfo, field_validator
+from azure.ai.inference.aio import ChatCompletionsClient
+from pydantic import BaseModel
 
 from app.helpers.cache import async_lru_cache
-from app.helpers.identity import token
-
-
-class ModeEnum(str, Enum):
-    AZURE_OPENAI = "azure_openai"
-    """Use Azure OpenAI."""
-    OPENAI = "openai"
-    """Use OpenAI."""
+from app.helpers.http import azure_transport
+from app.helpers.identity import credential
 
 
-class AbstractPlatformModel(BaseModel, frozen=True):
-    _client_kwargs: dict[str, Any] = {
-        # Reliability
-        "max_retries": 0,  # Retries are managed manually
-        "timeout": 60,
-    }
+class DeploymentModel(BaseModel, frozen=True):
+    api_version: str = "2024-10-21"  # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#api-specs
     context: int
     endpoint: str
     model: str
     seed: int = 42  # Reproducible results
     streaming: bool
     temperature: float = 0.0  # Most focused and deterministic
 
-    @abstractmethod
-    async def instance(
-        self,
-    ) -> tuple[AsyncAzureOpenAI | AsyncOpenAI, "AbstractPlatformModel"]:
-        pass
-
-
-class AzureOpenaiPlatformModel(AbstractPlatformModel, frozen=True):
-    api_version: str = "2024-06-01"
-    deployment: str
-
     @async_lru_cache()
-    async def instance(self) -> tuple[AsyncAzureOpenAI, AbstractPlatformModel]:
-        return AsyncAzureOpenAI(
-            **self._client_kwargs,
+    async def instance(self) -> tuple[ChatCompletionsClient, "DeploymentModel"]:
+        return ChatCompletionsClient(
+            # Reliability
+            seed=self.seed,
+            temperature=self.temperature,
             # Deployment
             api_version=self.api_version,
-            azure_deployment=self.deployment,
-            azure_endpoint=self.endpoint,
+            endpoint=self.endpoint,
+            model=self.model,
+            # Performance
+            transport=await azure_transport(),
             # Authentication
-            azure_ad_token_provider=await token(
-                "https://cognitiveservices.azure.com/.default"
-            ),
+            credential_scopes=["https://cognitiveservices.azure.com/.default"],
+            credential=await credential(),
         ), self
 
 
-class OpenaiPlatformModel(AbstractPlatformModel, frozen=True):
-    api_key: SecretStr
-
-    @async_lru_cache()
-    async def instance(self) -> tuple[AsyncOpenAI, AbstractPlatformModel]:
-        return AsyncOpenAI(
-            **self._client_kwargs,
-            # API root URL
-            base_url=self.endpoint,
-            # Authentication
-            api_key=self.api_key.get_secret_value(),
-        ), self
-
-
-class SelectedPlatformModel(BaseModel):
-    azure_openai: AzureOpenaiPlatformModel | None = None
-    mode: ModeEnum
-    openai: OpenaiPlatformModel | None = None
-
-    @field_validator("azure_openai")
-    @classmethod
-    def _validate_azure_openai(
-        cls,
-        azure_openai: AzureOpenaiPlatformModel | None,
-        info: ValidationInfo,
-    ) -> AzureOpenaiPlatformModel | None:
-        if not azure_openai and info.data.get("mode", None) == ModeEnum.AZURE_OPENAI:
-            raise ValueError("Azure OpenAI config required")
-        return azure_openai
-
-    @field_validator("openai")
-    @classmethod
-    def _validate_openai(
-        cls,
-        openai: OpenaiPlatformModel | None,
-        info: ValidationInfo,
-    ) -> OpenaiPlatformModel | None:
-        if not openai and info.data.get("mode", None) == ModeEnum.OPENAI:
-            raise ValueError("OpenAI config required")
-        return openai
-
-    def selected(self) -> AzureOpenaiPlatformModel | OpenaiPlatformModel:
-        platform = (
-            self.azure_openai if self.mode == ModeEnum.AZURE_OPENAI else self.openai
-        )
-        assert platform
-        return platform
-
-
 class LlmModel(BaseModel):
-    fast: SelectedPlatformModel = Field(
-        serialization_alias="backup",  # Backwards compatibility with v6
-    )
-    slow: SelectedPlatformModel = Field(
-        serialization_alias="primary",  # Backwards compatibility with v6
-    )
+    fast: DeploymentModel
+    slow: DeploymentModel
 
-    def selected(self, is_fast: bool) -> AzureOpenaiPlatformModel | OpenaiPlatformModel:
-        platform = self.fast if is_fast else self.slow
-        return platform.selected()
+    def selected(self, is_fast: bool) -> DeploymentModel:
+        return self.fast if is_fast else self.slow
```
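For reference, a self-contained sketch of how the new async `ChatCompletionsClient` is typically used, assuming the public `azure-ai-inference` and `azure-identity` packages; the endpoint and model name are placeholders, and `DefaultAzureCredential` stands in for the repo's own `credential()` helper:

```python
import asyncio

from azure.ai.inference.aio import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.identity.aio import DefaultAzureCredential


async def main() -> None:
    client = ChatCompletionsClient(
        credential=DefaultAzureCredential(),
        credential_scopes=["https://cognitiveservices.azure.com/.default"],
        endpoint="https://<your-resource>.openai.azure.com",  # placeholder
    )
    async with client:
        response = await client.complete(
            messages=[
                SystemMessage(content="You are a concise assistant."),
                UserMessage(content="Say hello."),
            ],
            model="gpt-4o-mini",  # placeholder model/deployment name
        )
        print(response.choices[0].message.content)


asyncio.run(main())
```

Passing `seed`, `temperature`, and `model` to the constructor, as the new `DeploymentModel.instance()` does, sets client-level defaults that later `complete()` calls can inherit.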
app/helpers/config_models/prompts.py (30 changes: 9 additions & 21 deletions)

```diff
@@ -6,8 +6,8 @@
 from logging import Logger
 from textwrap import dedent
 
+from azure.ai.inference.models import SystemMessage
 from azure.core.exceptions import HttpResponseError
-from openai.types.chat import ChatCompletionSystemMessageParam
 from pydantic import BaseModel, TypeAdapter
 
 from app.models.call import CallStateModel
@@ -329,7 +329,7 @@ def default_system(self, call: CallStateModel) -> str:
 
     def chat_system(
         self, call: CallStateModel, trainings: list[TrainingModel]
-    ) -> list[ChatCompletionSystemMessageParam]:
+    ) -> list[SystemMessage]:
         from app.models.message import (
             ActionEnum as MessageActionEnum,
             StyleEnum as MessageStyleEnum,
@@ -352,9 +352,7 @@
             call=call,
         )
 
-    def sms_summary_system(
-        self, call: CallStateModel
-    ) -> list[ChatCompletionSystemMessageParam]:
+    def sms_summary_system(self, call: CallStateModel) -> list[SystemMessage]:
         return self._messages(
             self._format(
                 self.sms_summary_system_tpl,
@@ -373,9 +371,7 @@
             call=call,
         )
 
-    def synthesis_system(
-        self, call: CallStateModel
-    ) -> list[ChatCompletionSystemMessageParam]:
+    def synthesis_system(self, call: CallStateModel) -> list[SystemMessage]:
         return self._messages(
             self._format(
                 self.synthesis_system_tpl,
@@ -392,9 +388,7 @@
             call=call,
         )
 
-    def citations_system(
-        self, call: CallStateModel, text: str
-    ) -> list[ChatCompletionSystemMessageParam]:
+    def citations_system(self, call: CallStateModel, text: str) -> list[SystemMessage]:
         """
         Return the formatted prompt. Prompt is used to add citations to the text, without cluttering the content itself.
 
@@ -412,9 +406,7 @@
             call=call,
         )
 
-    def next_system(
-        self, call: CallStateModel
-    ) -> list[ChatCompletionSystemMessageParam]:
+    def next_system(self, call: CallStateModel) -> list[SystemMessage]:
         return self._messages(
             self._format(
                 self.next_system_tpl,
@@ -461,17 +453,13 @@ def _format(
         # self.logger.debug("Formatted prompt: %s", formatted_prompt)
         return formatted_prompt
 
-    def _messages(
-        self, system: str, call: CallStateModel
-    ) -> list[ChatCompletionSystemMessageParam]:
+    def _messages(self, system: str, call: CallStateModel) -> list[SystemMessage]:
         messages = [
-            ChatCompletionSystemMessageParam(
+            SystemMessage(
                 content=self.default_system(call),
-                role="system",
             ),
-            ChatCompletionSystemMessageParam(
+            SystemMessage(
                 content=system,
-                role="system",
             ),
         ]
         # self.logger.debug("Messages: %s", messages)
```
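The deleted `role="system"` arguments are consistent with the new model type: `ChatCompletionSystemMessageParam` is a plain dict shape that needs an explicit `role`, while the SDK's `SystemMessage` class sets it itself. A quick sketch, assuming `azure-ai-inference` is installed:

```python
from azure.ai.inference.models import SystemMessage

message = SystemMessage(content="You are a call-center assistant.")
print(message.role)  # "system" is set by the class; no role argument is needed
```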
app/helpers/llm_utils.py (10 changes: 4 additions & 6 deletions)

```diff
@@ -14,14 +14,13 @@
 from typing import Annotated, Any, ForwardRef, TypeVar
 
 from aiojobs import Scheduler
+from azure.ai.inference.models import ChatCompletionsToolDefinition, FunctionDefinition
 from azure.cognitiveservices.speech import (
     SpeechSynthesizer,
 )
 from azure.communication.callautomation.aio import CallAutomationClient
 from jinja2 import Environment
 from json_repair import repair_json
-from openai.types.chat import ChatCompletionToolParam
-from openai.types.shared_params.function_definition import FunctionDefinition
 from pydantic import BaseModel, TypeAdapter
 from pydantic._internal._typing_extra import eval_type_lenient
 from pydantic.json_schema import JsonSchemaValue
@@ -77,7 +76,7 @@ def __init__(  # noqa: PLR0913
     async def to_openai(
         self,
         blacklist: frozenset[str],
-    ) -> list[ChatCompletionToolParam]:
+    ) -> list[ChatCompletionsToolDefinition]:
         """
         Get the OpenAI SDK schema for all functions of the plugin, excluding the ones in the blacklist.
         """
@@ -257,7 +256,7 @@ async def wrapper(
 async def _function_schema(
     f: Callable[..., Any],
     **kwargs: Any,
-) -> ChatCompletionToolParam:
+) -> ChatCompletionsToolDefinition:
     """
     Take a function and return a JSON schema for it as defined by the OpenAI API.
@@ -303,8 +302,7 @@ async def _function_schema(
         )
     ).model_dump()
 
-    return ChatCompletionToolParam(
-        type="function",
+    return ChatCompletionsToolDefinition(
         function=FunctionDefinition(
             description=description,
             name=name,
```
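Likewise, the dropped `type="function"` argument reflects the new type: `ChatCompletionsToolDefinition` models function tools only, so the discriminator is implicit. A minimal sketch of building one tool definition by hand; the weather function is a made-up example, not from the repo:

```python
from azure.ai.inference.models import (
    ChatCompletionsToolDefinition,
    FunctionDefinition,
)

weather_tool = ChatCompletionsToolDefinition(
    function=FunctionDefinition(
        name="get_weather",
        description="Get the current weather for a city.",
        parameters={
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City name"},
            },
            "required": ["city"],
        },
    )
)
```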
(Diffs for the remaining 8 changed files are not shown.)
