diff --git a/autochat/__init__.py b/autochat/__init__.py index 22ad69f..4d53166 100644 --- a/autochat/__init__.py +++ b/autochat/__init__.py @@ -11,8 +11,10 @@ wait_random_exponential, ) -from autochat.model import Message +from autochat.model import Message, Image from autochat.utils import csv_dumps, inspect_schema, parse_chat_template +from PIL import Image as PILImage +import io AUTOCHAT_HOST = os.getenv("AUTOCHAT_HOST") AUTOCHAT_MODEL = os.getenv("AUTOCHAT_MODEL") @@ -187,6 +189,8 @@ def run_conversation( function_name = response.function_call["name"] function_arguments = response.function_call["arguments"] + image = None + content = None try: try: content = self.functions[function_name]( @@ -206,6 +210,12 @@ def run_conversation( yield response + if isinstance(content, Message): + # We support if the function returns a Message class + message = content + yield message + continue + if content is None: # If function call returns None, we continue the conversation without adding a message # message = None @@ -234,6 +244,16 @@ def run_conversation( content[:OUTPUT_SIZE_LIMIT] + f"\n... ({len(content)} characters)" ) + # Support bytes + # If it's an image; resize it + elif isinstance(content, bytes): + # Detect if it's an image + try: + image = PILImage.open(io.BytesIO(content)) + content = None + except IOError: + # If it's not an image, return the original content + raise ValueError("Not an image") else: raise ValueError(f"Invalid content type: {type(content)}") @@ -242,6 +262,7 @@ def run_conversation( role="function", content=content, function_call_id=response.function_call_id, + image=image, ) yield message diff --git a/autochat/model.py b/autochat/model.py index e90b398..2c44b10 100644 --- a/autochat/model.py +++ b/autochat/model.py @@ -17,16 +17,24 @@ def __str__(self): class Image: def __init__(self, image: PILImage.Image): + if not isinstance(image, PILImage.Image): + raise TypeError("image must be an instance of PIL.Image.Image") self.image = image def resize(self, size: tuple[int, int]): - self.image = self.image.resize(size) + try: + self.image = self.image.resize(size) + except Exception as e: + raise ValueError(f"Failed to resize image: {e}") def to_base64(self): - buffered = BytesIO() - self.image.save(buffered, format="PNG") - img_str = base64.b64encode(buffered.getvalue()).decode() - return img_str + try: + buffered = BytesIO() + self.image.save(buffered, format=self.image.format) + img_str = base64.b64encode(buffered.getvalue()).decode() + return img_str + except Exception as e: + raise ValueError(f"Failed to convert image to base64: {e}") @property def format(self): @@ -42,7 +50,7 @@ def __init__( function_call: typing.Optional[dict] = None, id: typing.Optional[int] = None, function_call_id: typing.Optional[str] = None, - image: typing.Optional[Image] = None, + image: typing.Optional[PILImage.Image] = None, ) -> None: self.role = role self.content = content @@ -50,7 +58,7 @@ def __init__( self.function_call = function_call self.id = id self.function_call_id = function_call_id - self.image = image + self.image = Image(image) if image else None def to_openai_dict(self) -> dict: res = { @@ -96,15 +104,28 @@ def to_anthropic_dict(self) -> dict: "role": self.role if self.role in ["user", "assistant"] else "user", "content": [], } + if self.role == "function": + if self.image: + content = [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": self.image.format, + "data": self.image.to_base64(), + }, + } + ] + else: + content = self.content - if self.role == "function": # result of a function call - res["content"] = [ + res["content"].append( { "type": "tool_result", "tool_use_id": self.function_call_id, - "content": self.content, + "content": content, } - ] + ) return res if self.content: diff --git a/examples/demo_image.py b/examples/demo_image.py index 86047cd..c336848 100644 --- a/examples/demo_image.py +++ b/examples/demo_image.py @@ -1,11 +1,22 @@ -from PIL import Image as PILImage +from PIL import Image import sys +import argparse sys.path.append("..") -from autochat import Autochat, Message, Image +from autochat import Autochat, Message -ai = Autochat(provider="openai") +parser = argparse.ArgumentParser(description="Describe an image using AI") +parser.add_argument( + "--provider", + type=str, + default="anthropic", + help="AI provider (e.g., 'anthropic', 'openai')", +) +args = parser.parse_args() -image = Image(PILImage.open("./image.jpg")) -response = ai.ask(Message(role="user", content="describe the image", image=image)) +ai = Autochat(provider=args.provider) + +image = Image.open("./image.jpg") +message = Message(role="user", content="describe the image", image=image) +response = ai.ask(message) print(response) diff --git a/setup.py b/setup.py index 207a8d2..b1134aa 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="autochat", - version="0.3.3", + version="0.3.4", packages=find_packages(), install_requires=["tenacity==8.3.0", "pillow==10.4.0"], extras_require={ diff --git a/tests/test_conversation_flow.py b/tests/test_conversation_flow.py new file mode 100644 index 0000000..4a52d2c --- /dev/null +++ b/tests/test_conversation_flow.py @@ -0,0 +1,108 @@ +import pytest +from unittest.mock import patch, MagicMock +from autochat import Autochat, Message, APIProvider, Image +from PIL import Image as PILImage +import io + + +@pytest.fixture +def autochat(): + return Autochat(provider=APIProvider.OPENAI, model="gpt-4-turbo") + + +def test_conversation_flow(autochat): + # Mock the OpenAI API call + with patch("openai.OpenAI") as mock_openai: + mock_client = MagicMock() + mock_openai.return_value = mock_client + + # Mock the chat completion response + mock_response = MagicMock() + mock_response.choices[0].message.role = "assistant" + mock_response.choices[0].message.content = None + mock_response.choices[0].message.function_call = { + "name": "test_function", + "arguments": '{"arg1": "value1"}', + } + mock_response.id = "response_id" + mock_client.chat.completions.create.return_value = mock_response + + # Define a test function + def test_function(arg1: str) -> str: + """test""" + return "Function result" + + autochat.add_function(test_function) + + # Start the conversation + conversation = autochat.run_conversation("Hello") + + # Step 1: User message + user_message = next(conversation) + assert user_message.role == "user" + assert user_message.content == "Hello" + + # # Step 2: Assistant function call + assistant_message = next(conversation) + assert assistant_message.role == "assistant" + assert assistant_message.function_call["name"] == "test_function" + + # # Step 3: Function result + # function_result = next(conversation) + # assert function_result.role == "function" + # assert function_result.content == "Function result" + + # # Mock the second API call (assistant's response to function result) + # mock_response.choices[0].message.content = "Final response" + # mock_response.choices[0].message.function_call = None + + # # Step 4: Assistant final message + # final_message = next(conversation) + # assert final_message.role == "assistant" + # assert final_message.content == "Final response" + + +# def test_conversation_flow_with_image(autochat): +# with patch("openai.OpenAI") as mock_openai: +# mock_client = MagicMock() +# mock_openai.return_value = mock_client + +# mock_response = MagicMock() +# mock_response.choices[0].message.role = "assistant" +# mock_response.choices[0].message.content = None +# mock_response.choices[0].message.function_call = { +# "name": "image_function", +# "arguments": "{}", +# } +# mock_response.id = "response_id" +# mock_client.chat.completions.create.return_value = mock_response + +# def image_function(): +# img = PILImage.new("RGB", (100, 100), color="red") +# img_byte_arr = io.BytesIO() +# img.save(img_byte_arr, format="PNG") +# return img_byte_arr.getvalue() + +# autochat.add_function(image_function) + +# conversation = autochat.run_conversation("Generate an image") + +# user_message = next(conversation) +# assert user_message.role == "user" +# assert user_message.content == "Generate an image" + +# assistant_message = next(conversation) +# assert assistant_message.role == "assistant" +# assert assistant_message.function_call["name"] == "image_function" + +# function_result = next(conversation) +# assert function_result.role == "function" +# assert isinstance(function_result.image, Image) +# assert function_result.content is None + +# mock_response.choices[0].message.content = "Image generated successfully" +# mock_response.choices[0].message.function_call = None + +# final_message = next(conversation) +# assert final_message.role == "assistant" +# assert final_message.content == "Image generated successfully" diff --git a/tests/test_snap.py b/tests/test_snap.py new file mode 100644 index 0000000..ae19ecd --- /dev/null +++ b/tests/test_snap.py @@ -0,0 +1,27 @@ +from unittest.mock import patch +import pytest +from autochat import Autochat, APIProvider + + +class TestAutochat: + @pytest.mark.snapshot + def test_fetch_openai(self, snapshot): + with patch("openai.OpenAI") as mock_openai: + # Setup the mock response + mock_response = ( + mock_openai.return_value.chat.completions.create.return_value + ) + mock_response.choices[0].message.role = "assistant" + mock_response.choices[0].message.content = "Mocked response content" + mock_response.choices[0].message.function_call = None + mock_response.id = "mocked_response_id" + + # Create an instance of Autochat + autochat = Autochat(provider=APIProvider.OPENAI) + + # Call the method + result = autochat.ask("Hello, how are you?") + + print(result) + # Assert that the result matches the snapshot + snapshot.assert_match(result.to_openai_dict())