Merge pull request #2 from BenderV/callback_image
Callback image
BenderV authored Sep 18, 2024
2 parents bc02151 + f343d07 commit 9a92e86
Showing 6 changed files with 206 additions and 18 deletions.
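For orientation before the file-by-file diff, here is a minimal usage sketch of what this pull request enables: a function registered on an Autochat instance may now return raw image bytes (or a full Message), and run_conversation attaches the decoded image to the resulting function message instead of passing the bytes through as text. The callback name and prompt below are hypothetical, the provider choice is illustrative, and a configured API key for that provider is assumed.

import io

from PIL import Image as PILImage

from autochat import Autochat


def generate_chart() -> bytes:
    """Hypothetical callback: render a placeholder image and return PNG bytes."""
    img = PILImage.new("RGB", (100, 100), color="red")
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


ai = Autochat(provider="openai")
ai.add_function(generate_chart)

for message in ai.run_conversation("Please generate a chart"):
    # Image-bytes results arrive as a function message with content=None and
    # the decoded image attached; other return types keep the text path.
    print(message.role, message.content, message.image)
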
23 changes: 22 additions & 1 deletion autochat/__init__.py
@@ -11,8 +11,10 @@
wait_random_exponential,
)

from autochat.model import Message
from autochat.model import Message, Image
from autochat.utils import csv_dumps, inspect_schema, parse_chat_template
from PIL import Image as PILImage
import io

AUTOCHAT_HOST = os.getenv("AUTOCHAT_HOST")
AUTOCHAT_MODEL = os.getenv("AUTOCHAT_MODEL")
@@ -187,6 +189,8 @@ def run_conversation(
function_name = response.function_call["name"]
function_arguments = response.function_call["arguments"]

image = None
content = None
try:
try:
content = self.functions[function_name](
@@ -206,6 +210,12 @@ def run_conversation(

yield response

if isinstance(content, Message):
                    # Support the case where the function returns a Message instance
message = content
yield message
continue

if content is None:
# If function call returns None, we continue the conversation without adding a message
# message = None
@@ -234,6 +244,16 @@ def run_conversation(
content[:OUTPUT_SIZE_LIMIT]
+ f"\n... ({len(content)} characters)"
)
                # Support bytes: if the function returned raw bytes,
                # try to decode them as an image and attach it to the message
elif isinstance(content, bytes):
# Detect if it's an image
try:
image = PILImage.open(io.BytesIO(content))
content = None
except IOError:
                        # Not an image: raise instead of passing raw bytes along
raise ValueError("Not an image")
else:
raise ValueError(f"Invalid content type: {type(content)}")

@@ -242,6 +262,7 @@
role="function",
content=content,
function_call_id=response.function_call_id,
image=image,
)
yield message

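Condensed for readability, the return-type handling added above behaves roughly like this hypothetical helper (the actual commit inlines the logic in run_conversation and also yields Message results directly):

import io

from PIL import Image as PILImage


def split_function_result(content):
    """Rough sketch: map a callback's return value to (content, image)."""
    image = None
    if isinstance(content, bytes):
        try:
            # Only bytes that decode as an image are accepted
            image = PILImage.open(io.BytesIO(content))
            content = None
        except IOError:
            raise ValueError("Not an image")
    return content, image
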
43 changes: 32 additions & 11 deletions autochat/model.py
@@ -17,16 +17,24 @@ def __str__(self):

class Image:
def __init__(self, image: PILImage.Image):
if not isinstance(image, PILImage.Image):
raise TypeError("image must be an instance of PIL.Image.Image")
self.image = image

def resize(self, size: tuple[int, int]):
self.image = self.image.resize(size)
try:
self.image = self.image.resize(size)
except Exception as e:
raise ValueError(f"Failed to resize image: {e}")

def to_base64(self):
buffered = BytesIO()
self.image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str
try:
buffered = BytesIO()
self.image.save(buffered, format=self.image.format)
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str
except Exception as e:
raise ValueError(f"Failed to convert image to base64: {e}")

@property
def format(self):
@@ -42,15 +50,15 @@ def __init__(
function_call: typing.Optional[dict] = None,
id: typing.Optional[int] = None,
function_call_id: typing.Optional[str] = None,
image: typing.Optional[Image] = None,
image: typing.Optional[PILImage.Image] = None,
) -> None:
self.role = role
self.content = content
self.name = name
self.function_call = function_call
self.id = id
self.function_call_id = function_call_id
self.image = image
self.image = Image(image) if image else None

def to_openai_dict(self) -> dict:
res = {
@@ -96,15 +104,28 @@ def to_anthropic_dict(self) -> dict:
"role": self.role if self.role in ["user", "assistant"] else "user",
"content": [],
}
if self.role == "function":
if self.image:
content = [
{
"type": "image",
"source": {
"type": "base64",
"media_type": self.image.format,
"data": self.image.to_base64(),
},
}
]
else:
content = self.content

if self.role == "function": # result of a function call
res["content"] = [
res["content"].append(
{
"type": "tool_result",
"tool_use_id": self.function_call_id,
"content": self.content,
"content": content,
}
]
)
return res

if self.content:
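In practical terms, the model.py changes mean a Message can now be built directly from a PIL image, which it wraps in the validated Image helper. The file path below is illustrative and assumed to exist:

from PIL import Image as PILImage

from autochat import Message

pil_img = PILImage.open("./image.jpg")  # assumed local file
msg = Message(role="user", content="describe the image", image=pil_img)

# The wrapper type-checks its input and can serialise the pixels on demand.
print(type(msg.image))               # <class 'autochat.model.Image'>
print(len(msg.image.to_base64()))    # base64 payload used for API requests
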
21 changes: 16 additions & 5 deletions examples/demo_image.py
@@ -1,11 +1,22 @@
from PIL import Image as PILImage
from PIL import Image
import sys
import argparse

sys.path.append("..")
from autochat import Autochat, Message, Image
from autochat import Autochat, Message

ai = Autochat(provider="openai")
parser = argparse.ArgumentParser(description="Describe an image using AI")
parser.add_argument(
"--provider",
type=str,
default="anthropic",
help="AI provider (e.g., 'anthropic', 'openai')",
)
args = parser.parse_args()

image = Image(PILImage.open("./image.jpg"))
response = ai.ask(Message(role="user", content="describe the image", image=image))
ai = Autochat(provider=args.provider)

image = Image.open("./image.jpg")
message = Message(role="user", content="describe the image", image=image)
response = ai.ask(message)
print(response)
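With this change the demo selects its provider from the command line (default anthropic), e.g. python examples/demo_image.py --provider openai, assuming an image.jpg is present in the working directory and the matching API key is set in the environment.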
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name="autochat",
version="0.3.3",
version="0.3.4",
packages=find_packages(),
install_requires=["tenacity==8.3.0", "pillow==10.4.0"],
extras_require={
108 changes: 108 additions & 0 deletions tests/test_conversation_flow.py
@@ -0,0 +1,108 @@
import pytest
from unittest.mock import patch, MagicMock
from autochat import Autochat, Message, APIProvider, Image
from PIL import Image as PILImage
import io


@pytest.fixture
def autochat():
return Autochat(provider=APIProvider.OPENAI, model="gpt-4-turbo")


def test_conversation_flow(autochat):
# Mock the OpenAI API call
with patch("openai.OpenAI") as mock_openai:
mock_client = MagicMock()
mock_openai.return_value = mock_client

# Mock the chat completion response
mock_response = MagicMock()
mock_response.choices[0].message.role = "assistant"
mock_response.choices[0].message.content = None
mock_response.choices[0].message.function_call = {
"name": "test_function",
"arguments": '{"arg1": "value1"}',
}
mock_response.id = "response_id"
mock_client.chat.completions.create.return_value = mock_response

# Define a test function
def test_function(arg1: str) -> str:
"""test"""
return "Function result"

autochat.add_function(test_function)

# Start the conversation
conversation = autochat.run_conversation("Hello")

# Step 1: User message
user_message = next(conversation)
assert user_message.role == "user"
assert user_message.content == "Hello"

# # Step 2: Assistant function call
assistant_message = next(conversation)
assert assistant_message.role == "assistant"
assert assistant_message.function_call["name"] == "test_function"

# # Step 3: Function result
# function_result = next(conversation)
# assert function_result.role == "function"
# assert function_result.content == "Function result"

# # Mock the second API call (assistant's response to function result)
# mock_response.choices[0].message.content = "Final response"
# mock_response.choices[0].message.function_call = None

# # Step 4: Assistant final message
# final_message = next(conversation)
# assert final_message.role == "assistant"
# assert final_message.content == "Final response"


# def test_conversation_flow_with_image(autochat):
# with patch("openai.OpenAI") as mock_openai:
# mock_client = MagicMock()
# mock_openai.return_value = mock_client

# mock_response = MagicMock()
# mock_response.choices[0].message.role = "assistant"
# mock_response.choices[0].message.content = None
# mock_response.choices[0].message.function_call = {
# "name": "image_function",
# "arguments": "{}",
# }
# mock_response.id = "response_id"
# mock_client.chat.completions.create.return_value = mock_response

# def image_function():
# img = PILImage.new("RGB", (100, 100), color="red")
# img_byte_arr = io.BytesIO()
# img.save(img_byte_arr, format="PNG")
# return img_byte_arr.getvalue()

# autochat.add_function(image_function)

# conversation = autochat.run_conversation("Generate an image")

# user_message = next(conversation)
# assert user_message.role == "user"
# assert user_message.content == "Generate an image"

# assistant_message = next(conversation)
# assert assistant_message.role == "assistant"
# assert assistant_message.function_call["name"] == "image_function"

# function_result = next(conversation)
# assert function_result.role == "function"
# assert isinstance(function_result.image, Image)
# assert function_result.content is None

# mock_response.choices[0].message.content = "Image generated successfully"
# mock_response.choices[0].message.function_call = None

# final_message = next(conversation)
# assert final_message.role == "assistant"
# assert final_message.content == "Image generated successfully"
27 changes: 27 additions & 0 deletions tests/test_snap.py
@@ -0,0 +1,27 @@
from unittest.mock import patch
import pytest
from autochat import Autochat, APIProvider


class TestAutochat:
@pytest.mark.snapshot
def test_fetch_openai(self, snapshot):
with patch("openai.OpenAI") as mock_openai:
# Setup the mock response
mock_response = (
mock_openai.return_value.chat.completions.create.return_value
)
mock_response.choices[0].message.role = "assistant"
mock_response.choices[0].message.content = "Mocked response content"
mock_response.choices[0].message.function_call = None
mock_response.id = "mocked_response_id"

# Create an instance of Autochat
autochat = Autochat(provider=APIProvider.OPENAI)

# Call the method
result = autochat.ask("Hello, how are you?")

print(result)
# Assert that the result matches the snapshot
snapshot.assert_match(result.to_openai_dict())
