Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callback image #2

Merged
merged 2 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion autochat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
wait_random_exponential,
)

from autochat.model import Message
from autochat.model import Message, Image
from autochat.utils import csv_dumps, inspect_schema, parse_chat_template
from PIL import Image as PILImage
import io

AUTOCHAT_HOST = os.getenv("AUTOCHAT_HOST")
AUTOCHAT_MODEL = os.getenv("AUTOCHAT_MODEL")
Expand Down Expand Up @@ -187,6 +189,8 @@ def run_conversation(
function_name = response.function_call["name"]
function_arguments = response.function_call["arguments"]

image = None
content = None
try:
try:
content = self.functions[function_name](
Expand All @@ -206,6 +210,12 @@ def run_conversation(

yield response

if isinstance(content, Message):
# We support if the function returns a Message class
message = content
yield message
continue

if content is None:
# If function call returns None, we continue the conversation without adding a message
# message = None
Expand Down Expand Up @@ -234,6 +244,16 @@ def run_conversation(
content[:OUTPUT_SIZE_LIMIT]
+ f"\n... ({len(content)} characters)"
)
# Support bytes
# If it's an image; resize it
elif isinstance(content, bytes):
# Detect if it's an image
try:
image = PILImage.open(io.BytesIO(content))
content = None
except IOError:
# If it's not an image, return the original content
raise ValueError("Not an image")
else:
raise ValueError(f"Invalid content type: {type(content)}")

Expand All @@ -242,6 +262,7 @@ def run_conversation(
role="function",
content=content,
function_call_id=response.function_call_id,
image=image,
)
yield message

Expand Down
43 changes: 32 additions & 11 deletions autochat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,24 @@ def __str__(self):

class Image:
def __init__(self, image: PILImage.Image):
if not isinstance(image, PILImage.Image):
raise TypeError("image must be an instance of PIL.Image.Image")
self.image = image

def resize(self, size: tuple[int, int]):
self.image = self.image.resize(size)
try:
self.image = self.image.resize(size)
except Exception as e:
raise ValueError(f"Failed to resize image: {e}")

def to_base64(self):
buffered = BytesIO()
self.image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str
try:
buffered = BytesIO()
self.image.save(buffered, format=self.image.format)
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str
except Exception as e:
raise ValueError(f"Failed to convert image to base64: {e}")

@property
def format(self):
Expand All @@ -42,15 +50,15 @@ def __init__(
function_call: typing.Optional[dict] = None,
id: typing.Optional[int] = None,
function_call_id: typing.Optional[str] = None,
image: typing.Optional[Image] = None,
image: typing.Optional[PILImage.Image] = None,
) -> None:
self.role = role
self.content = content
self.name = name
self.function_call = function_call
self.id = id
self.function_call_id = function_call_id
self.image = image
self.image = Image(image) if image else None

def to_openai_dict(self) -> dict:
res = {
Expand Down Expand Up @@ -96,15 +104,28 @@ def to_anthropic_dict(self) -> dict:
"role": self.role if self.role in ["user", "assistant"] else "user",
"content": [],
}
if self.role == "function":
if self.image:
content = [
{
"type": "image",
"source": {
"type": "base64",
"media_type": self.image.format,
"data": self.image.to_base64(),
},
}
]
else:
content = self.content

if self.role == "function": # result of a function call
res["content"] = [
res["content"].append(
{
"type": "tool_result",
"tool_use_id": self.function_call_id,
"content": self.content,
"content": content,
}
]
)
return res

if self.content:
Expand Down
21 changes: 16 additions & 5 deletions examples/demo_image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from PIL import Image as PILImage
from PIL import Image
import sys
import argparse

sys.path.append("..")
from autochat import Autochat, Message, Image
from autochat import Autochat, Message

ai = Autochat(provider="openai")
parser = argparse.ArgumentParser(description="Describe an image using AI")
parser.add_argument(
"--provider",
type=str,
default="anthropic",
help="AI provider (e.g., 'anthropic', 'openai')",
)
args = parser.parse_args()

image = Image(PILImage.open("./image.jpg"))
response = ai.ask(Message(role="user", content="describe the image", image=image))
ai = Autochat(provider=args.provider)

image = Image.open("./image.jpg")
message = Message(role="user", content="describe the image", image=image)
response = ai.ask(message)
print(response)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="autochat",
version="0.3.3",
version="0.3.4",
packages=find_packages(),
install_requires=["tenacity==8.3.0", "pillow==10.4.0"],
extras_require={
Expand Down
108 changes: 108 additions & 0 deletions tests/test_conversation_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import pytest
from unittest.mock import patch, MagicMock
from autochat import Autochat, Message, APIProvider, Image
from PIL import Image as PILImage
import io


@pytest.fixture
def autochat():
    """Provide an Autochat client configured for the OpenAI provider."""
    client = Autochat(provider=APIProvider.OPENAI, model="gpt-4-turbo")
    return client


def test_conversation_flow(autochat):
    """Drive run_conversation against a mocked OpenAI client.

    Checks that the generator first yields the user's opening message,
    then the assistant message carrying the function call produced by
    the mocked completion.
    """
    with patch("openai.OpenAI") as openai_cls:
        fake_client = MagicMock()
        openai_cls.return_value = fake_client

        # Fabricate a completion whose assistant turn requests a
        # function call instead of returning text content.
        completion = MagicMock()
        completion.choices[0].message.role = "assistant"
        completion.choices[0].message.content = None
        completion.choices[0].message.function_call = {
            "name": "test_function",
            "arguments": '{"arg1": "value1"}',
        }
        completion.id = "response_id"
        fake_client.chat.completions.create.return_value = completion

        # Register a callable for the assistant's function call to hit.
        def test_function(arg1: str) -> str:
            """test"""
            return "Function result"

        autochat.add_function(test_function)

        conversation = autochat.run_conversation("Hello")

        # First yield: the user's opening message, echoed back.
        opening = next(conversation)
        assert opening.role == "user"
        assert opening.content == "Hello"

        # Second yield: the assistant's function-call message.
        reply = next(conversation)
        assert reply.role == "assistant"
        assert reply.function_call["name"] == "test_function"

        # NOTE(review): later steps (function result + final assistant
        # answer) are deliberately left unexercised, as in the original:
        # function_result = next(conversation)
        # assert function_result.role == "function"
        # assert function_result.content == "Function result"
        # completion.choices[0].message.content = "Final response"
        # completion.choices[0].message.function_call = None
        # final_message = next(conversation)
        # assert final_message.role == "assistant"
        # assert final_message.content == "Final response"


# def test_conversation_flow_with_image(autochat):
# with patch("openai.OpenAI") as mock_openai:
# mock_client = MagicMock()
# mock_openai.return_value = mock_client

# mock_response = MagicMock()
# mock_response.choices[0].message.role = "assistant"
# mock_response.choices[0].message.content = None
# mock_response.choices[0].message.function_call = {
# "name": "image_function",
# "arguments": "{}",
# }
# mock_response.id = "response_id"
# mock_client.chat.completions.create.return_value = mock_response

# def image_function():
# img = PILImage.new("RGB", (100, 100), color="red")
# img_byte_arr = io.BytesIO()
# img.save(img_byte_arr, format="PNG")
# return img_byte_arr.getvalue()

# autochat.add_function(image_function)

# conversation = autochat.run_conversation("Generate an image")

# user_message = next(conversation)
# assert user_message.role == "user"
# assert user_message.content == "Generate an image"

# assistant_message = next(conversation)
# assert assistant_message.role == "assistant"
# assert assistant_message.function_call["name"] == "image_function"

# function_result = next(conversation)
# assert function_result.role == "function"
# assert isinstance(function_result.image, Image)
# assert function_result.content is None

# mock_response.choices[0].message.content = "Image generated successfully"
# mock_response.choices[0].message.function_call = None

# final_message = next(conversation)
# assert final_message.role == "assistant"
# assert final_message.content == "Image generated successfully"
27 changes: 27 additions & 0 deletions tests/test_snap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from unittest.mock import patch
import pytest
from autochat import Autochat, APIProvider


class TestAutochat:
    """Snapshot tests for Autochat provider round-trips."""

    @pytest.mark.snapshot
    def test_fetch_openai(self, snapshot):
        """ask() against a mocked OpenAI client matches the stored snapshot."""
        with patch("openai.OpenAI") as mock_openai:
            # Configure the canned completion returned by the mocked client.
            mock_response = (
                mock_openai.return_value.chat.completions.create.return_value
            )
            mock_response.choices[0].message.role = "assistant"
            mock_response.choices[0].message.content = "Mocked response content"
            mock_response.choices[0].message.function_call = None
            mock_response.id = "mocked_response_id"

            # Instantiate inside the patch so Autochat picks up the mock.
            autochat = Autochat(provider=APIProvider.OPENAI)

            result = autochat.ask("Hello, how are you?")

            # Fixed: removed leftover debug print(result); tests should not
            # write to stdout on the success path.
            snapshot.assert_match(result.to_openai_dict())