-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from BenderV/callback_image
Callback image
- Loading branch information
Showing
6 changed files
with
206 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,22 @@ | ||
from PIL import Image as PILImage | ||
from PIL import Image | ||
import sys | ||
import argparse | ||
|
||
sys.path.append("..") | ||
from autochat import Autochat, Message, Image | ||
from autochat import Autochat, Message | ||
|
||
ai = Autochat(provider="openai") | ||
parser = argparse.ArgumentParser(description="Describe an image using AI") | ||
parser.add_argument( | ||
"--provider", | ||
type=str, | ||
default="anthropic", | ||
help="AI provider (e.g., 'anthropic', 'openai')", | ||
) | ||
args = parser.parse_args() | ||
|
||
image = Image(PILImage.open("./image.jpg")) | ||
response = ai.ask(Message(role="user", content="describe the image", image=image)) | ||
ai = Autochat(provider=args.provider) | ||
|
||
image = Image.open("./image.jpg") | ||
message = Message(role="user", content="describe the image", image=image) | ||
response = ai.ask(message) | ||
print(response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import pytest | ||
from unittest.mock import patch, MagicMock | ||
from autochat import Autochat, Message, APIProvider, Image | ||
from PIL import Image as PILImage | ||
import io | ||
|
||
|
||
@pytest.fixture | ||
def autochat(): | ||
return Autochat(provider=APIProvider.OPENAI, model="gpt-4-turbo") | ||
|
||
|
||
def test_conversation_flow(autochat): | ||
# Mock the OpenAI API call | ||
with patch("openai.OpenAI") as mock_openai: | ||
mock_client = MagicMock() | ||
mock_openai.return_value = mock_client | ||
|
||
# Mock the chat completion response | ||
mock_response = MagicMock() | ||
mock_response.choices[0].message.role = "assistant" | ||
mock_response.choices[0].message.content = None | ||
mock_response.choices[0].message.function_call = { | ||
"name": "test_function", | ||
"arguments": '{"arg1": "value1"}', | ||
} | ||
mock_response.id = "response_id" | ||
mock_client.chat.completions.create.return_value = mock_response | ||
|
||
# Define a test function | ||
def test_function(arg1: str) -> str: | ||
"""test""" | ||
return "Function result" | ||
|
||
autochat.add_function(test_function) | ||
|
||
# Start the conversation | ||
conversation = autochat.run_conversation("Hello") | ||
|
||
# Step 1: User message | ||
user_message = next(conversation) | ||
assert user_message.role == "user" | ||
assert user_message.content == "Hello" | ||
|
||
# # Step 2: Assistant function call | ||
assistant_message = next(conversation) | ||
assert assistant_message.role == "assistant" | ||
assert assistant_message.function_call["name"] == "test_function" | ||
|
||
# # Step 3: Function result | ||
# function_result = next(conversation) | ||
# assert function_result.role == "function" | ||
# assert function_result.content == "Function result" | ||
|
||
# # Mock the second API call (assistant's response to function result) | ||
# mock_response.choices[0].message.content = "Final response" | ||
# mock_response.choices[0].message.function_call = None | ||
|
||
# # Step 4: Assistant final message | ||
# final_message = next(conversation) | ||
# assert final_message.role == "assistant" | ||
# assert final_message.content == "Final response" | ||
|
||
|
||
# def test_conversation_flow_with_image(autochat): | ||
# with patch("openai.OpenAI") as mock_openai: | ||
# mock_client = MagicMock() | ||
# mock_openai.return_value = mock_client | ||
|
||
# mock_response = MagicMock() | ||
# mock_response.choices[0].message.role = "assistant" | ||
# mock_response.choices[0].message.content = None | ||
# mock_response.choices[0].message.function_call = { | ||
# "name": "image_function", | ||
# "arguments": "{}", | ||
# } | ||
# mock_response.id = "response_id" | ||
# mock_client.chat.completions.create.return_value = mock_response | ||
|
||
# def image_function(): | ||
# img = PILImage.new("RGB", (100, 100), color="red") | ||
# img_byte_arr = io.BytesIO() | ||
# img.save(img_byte_arr, format="PNG") | ||
# return img_byte_arr.getvalue() | ||
|
||
# autochat.add_function(image_function) | ||
|
||
# conversation = autochat.run_conversation("Generate an image") | ||
|
||
# user_message = next(conversation) | ||
# assert user_message.role == "user" | ||
# assert user_message.content == "Generate an image" | ||
|
||
# assistant_message = next(conversation) | ||
# assert assistant_message.role == "assistant" | ||
# assert assistant_message.function_call["name"] == "image_function" | ||
|
||
# function_result = next(conversation) | ||
# assert function_result.role == "function" | ||
# assert isinstance(function_result.image, Image) | ||
# assert function_result.content is None | ||
|
||
# mock_response.choices[0].message.content = "Image generated successfully" | ||
# mock_response.choices[0].message.function_call = None | ||
|
||
# final_message = next(conversation) | ||
# assert final_message.role == "assistant" | ||
# assert final_message.content == "Image generated successfully" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from unittest.mock import patch | ||
import pytest | ||
from autochat import Autochat, APIProvider | ||
|
||
|
||
class TestAutochat: | ||
@pytest.mark.snapshot | ||
def test_fetch_openai(self, snapshot): | ||
with patch("openai.OpenAI") as mock_openai: | ||
# Setup the mock response | ||
mock_response = ( | ||
mock_openai.return_value.chat.completions.create.return_value | ||
) | ||
mock_response.choices[0].message.role = "assistant" | ||
mock_response.choices[0].message.content = "Mocked response content" | ||
mock_response.choices[0].message.function_call = None | ||
mock_response.id = "mocked_response_id" | ||
|
||
# Create an instance of Autochat | ||
autochat = Autochat(provider=APIProvider.OPENAI) | ||
|
||
# Call the method | ||
result = autochat.ask("Hello, how are you?") | ||
|
||
print(result) | ||
# Assert that the result matches the snapshot | ||
snapshot.assert_match(result.to_openai_dict()) |