diff --git a/README.md b/README.md index 9472ba8..352273b 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ class EchoApplication(ChatCompletion): # Generate response with a single choice with response.create_single_choice() as choice: # Fill the content of the response with the last user's content - choice.append_content(last_user_message.content or "") + choice.append_content(last_user_message.text()) # DIALApp extends FastAPI to provide a user-friendly interface for routing requests to your applications diff --git a/aidial_sdk/chat_completion/__init__.py b/aidial_sdk/chat_completion/__init__.py index 17c756d..db42add 100644 --- a/aidial_sdk/chat_completion/__init__.py +++ b/aidial_sdk/chat_completion/__init__.py @@ -9,6 +9,9 @@ FunctionCall, FunctionChoice, Message, + MessageContentImagePart, + MessageContentPart, + MessageContentTextPart, Request, ResponseFormat, Role, diff --git a/aidial_sdk/chat_completion/request.py b/aidial_sdk/chat_completion/request.py index 847cfe4..a7c161f 100644 --- a/aidial_sdk/chat_completion/request.py +++ b/aidial_sdk/chat_completion/request.py @@ -1,8 +1,11 @@ from enum import Enum from typing import Any, Dict, List, Literal, Mapping, Optional, Union +from typing_extensions import assert_never + from aidial_sdk.chat_completion.enums import Status from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin +from aidial_sdk.exceptions import InvalidRequestError from aidial_sdk.pydantic_v1 import ( ConstrainedFloat, ConstrainedInt, @@ -58,15 +61,51 @@ class Role(str, Enum): TOOL = "tool" +class ImageURL(ExtraForbidModel): + url: StrictStr + detail: Optional[Literal["auto", "low", "high"]] = None + + +class MessageContentImagePart(ExtraForbidModel): + type: Literal["image_url"] + image_url: ImageURL + + +class MessageContentTextPart(ExtraForbidModel): + type: Literal["text"] + text: StrictStr + + +MessageContentPart = Union[MessageContentTextPart, MessageContentImagePart] + + class Message(ExtraForbidModel): role: Role - content: Optional[StrictStr] = None + content: Optional[Union[StrictStr, List[MessageContentPart]]] = None custom_content: Optional[CustomContent] = None name: Optional[StrictStr] = None tool_calls: Optional[List[ToolCall]] = None tool_call_id: Optional[StrictStr] = None function_call: Optional[FunctionCall] = None + def text(self) -> str: + """ + Returns content of the message only if it's present as a string. + Otherwise, throws an invalid request exception. + """ + + def _error_message(actual: str) -> str: + return f"Unable to retrieve text content of the message: the actual content is {actual}." + + if self.content is None: + raise InvalidRequestError(_error_message("null or missing")) + elif isinstance(self.content, str): + return self.content + elif isinstance(self.content, list): + raise InvalidRequestError(_error_message("a list of content parts")) + else: + assert_never(self.content) + class Addon(ExtraForbidModel): name: Optional[StrictStr] = None diff --git a/examples/echo/app.py b/examples/echo/app.py index c0f8478..d1ba905 100644 --- a/examples/echo/app.py +++ b/examples/echo/app.py @@ -15,17 +15,15 @@ async def chat_completion( self, request: Request, response: Response ) -> None: # Get last message (the newest) from the history - last_user_message = request.messages[-1] + last_message = request.messages[-1] # Generate response with a single choice with response.create_single_choice() as choice: # Fill the content of the response with the last user's content - choice.append_content(last_user_message.content or "") + choice.append_content(last_message.text()) - if last_user_message.custom_content is not None: - for attachment in ( - last_user_message.custom_content.attachments or [] - ): + if last_message.custom_content is not None: + for attachment in last_message.custom_content.attachments or []: # Add the same attachment to the response choice.add_attachment(**attachment.dict()) diff --git a/examples/langchain_rag/app.py b/examples/langchain_rag/app.py index ee2a979..3666eb1 100644 --- a/examples/langchain_rag/app.py +++ b/examples/langchain_rag/app.py @@ -61,7 +61,7 @@ async def chat_completion( with response.create_single_choice() as choice: message = request.messages[-1] - user_query = message.content or "" + user_query = message.text() file_url = get_last_attachment_url(request.messages) file_abs_url = urljoin(f"{DIAL_URL}/v1/", file_url) diff --git a/examples/render_text/app/main.py b/examples/render_text/app/main.py index 780a7cf..16386e6 100644 --- a/examples/render_text/app/main.py +++ b/examples/render_text/app/main.py @@ -23,7 +23,7 @@ async def chat_completion(self, request: Request, response: Response): # Create a single choice with response.create_single_choice() as choice: # Get the last message content - content = request.messages[-1].content or "" + content = request.messages[-1].text() # The image may be returned either as base64 string or as URL # The content specifies the mode of return: 'base64' or 'url' diff --git a/tests/applications/broken_immediately.py b/tests/applications/broken_immediately.py index 12f384f..2ea42fa 100644 --- a/tests/applications/broken_immediately.py +++ b/tests/applications/broken_immediately.py @@ -27,4 +27,4 @@ class BrokenApplication(ChatCompletion): async def chat_completion( self, request: Request, response: Response ) -> None: - raise_exception(request.messages[0].content or "") + raise_exception(request.messages[0].text()) diff --git a/tests/applications/broken_in_runtime.py b/tests/applications/broken_in_runtime.py index 999cb3b..291217e 100644 --- a/tests/applications/broken_in_runtime.py +++ b/tests/applications/broken_in_runtime.py @@ -17,4 +17,4 @@ async def chat_completion( choice.append_content("Test content") await response.aflush() - raise_exception(request.messages[0].content or "") + raise_exception(request.messages[0].text()) diff --git a/tests/applications/echo.py b/tests/applications/echo.py index 44bed79..51ad96e 100644 --- a/tests/applications/echo.py +++ b/tests/applications/echo.py @@ -27,7 +27,7 @@ async def chat_completion( response.set_response_id("test_id") response.set_created(0) - content = request.messages[-1].content or "" + content = request.messages[-1].text() with response.create_single_choice() as choice: choice.append_content(content) diff --git a/tests/test_errors.py b/tests/test_errors.py index a0caa94..37468cb 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -51,6 +51,28 @@ } }, ), + ( + None, + 400, + { + "error": { + "message": "Unable to retrieve text content of the message: the actual content is null or missing.", + "type": "invalid_request_error", + "code": "400", + } + }, + ), + ( + [{"type": "text", "text": "hello"}], + 400, + { + "error": { + "message": "Unable to retrieve text content of the message: the actual content is a list of content parts.", + "type": "invalid_request_error", + "code": "400", + } + }, + ), ] @@ -72,10 +94,8 @@ def test_error(type, response_status_code, response_content): headers={"Api-Key": "TEST_API_KEY"}, ) - assert ( - response.status_code == response_status_code - and response.json() == response_content - ) + assert response.status_code == response_status_code + assert response.json() == response_content @pytest.mark.parametrize( @@ -96,10 +116,8 @@ def test_streaming_error(type, response_status_code, response_content): headers={"Api-Key": "TEST_API_KEY"}, ) - assert ( - response.status_code == response_status_code - and response.json() == response_content - ) + assert response.status_code == response_status_code + assert response.json() == response_content @pytest.mark.parametrize( @@ -184,4 +202,5 @@ def test_no_api_key(): }, ) - assert response.status_code == 400 and response.json() == API_KEY_IS_MISSING + assert response.status_code == 400 + assert response.json() == API_KEY_IS_MISSING diff --git a/tests/utils/tokenization.py b/tests/utils/tokenization.py index 1428356..a37dfeb 100644 --- a/tests/utils/tokenization.py +++ b/tests/utils/tokenization.py @@ -26,7 +26,7 @@ def word_count_string(string: str) -> int: def word_count_message(message: Message) -> int: - return word_count_string(message.content or "") + return word_count_string(message.text()) def word_count_request(request: ChatCompletionRequest) -> int: