Skip to content

Commit

Permalink
add image support
Browse files Browse the repository at this point in the history
  • Loading branch information
BenderV committed Sep 8, 2024
1 parent dcfc5c0 commit 6f651ce
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 21 deletions.
93 changes: 73 additions & 20 deletions autochat/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import json
import typing
import base64
from io import BytesIO
from typing import Literal

from PIL import Image as PILImage


class FunctionCallParsingError(Exception):
def __init__(self, obj):
Expand All @@ -11,6 +15,24 @@ def __str__(self):
return f"Invalid function_call: {self.obj.function_call}"


class Image:
def __init__(self, image: PILImage.Image):
self.image = image

def resize(self, size: tuple[int, int]):
self.image = self.image.resize(size)

def to_base64(self):
buffered = BytesIO()
self.image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str

@property
def format(self):
return "image/" + self.image.format.lower()


class Message:
def __init__(
self,
Expand All @@ -20,19 +42,37 @@ def __init__(
function_call: typing.Optional[dict] = None,
id: typing.Optional[int] = None,
function_call_id: typing.Optional[str] = None,
image: typing.Optional[Image] = None,
) -> None:
self.role = role
self.content = content
self.name = name
self.function_call = function_call
self.id = id
self.function_call_id = function_call_id
self.image = image

def to_openai_dict(self) -> dict:
res = {
"role": self.role,
"content": self.content,
"content": [],
}
if self.content:
res["content"].append(
{
"type": "text",
"text": self.content,
}
)
if self.image:
res["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:{self.image.format};base64,{self.image.to_base64()}"
},
}
)
if self.name:
res["name"] = self.name
if self.function_call:
Expand All @@ -43,7 +83,7 @@ def to_openai_dict(self) -> dict:
}
else:
# If user is triggering a function, we add the function call to the content
# since openai doesn't support functions for user messages
# since openai doesn't support functions from user messages
res["content"] = (
self.function_call["name"]
+ ":"
Expand All @@ -52,16 +92,29 @@ def to_openai_dict(self) -> dict:
return res

def to_anthropic_dict(self) -> dict:
res = {"role": self.role if self.role in ["user", "assistant"] else "user"}
res = {
"role": self.role if self.role in ["user", "assistant"] else "user",
"content": [],
}

if self.role == "function": # result of a function call
res["content"] = [
{
"type": "tool_result",
"tool_use_id": self.function_call_id,
"content": self.content,
}
]
return res

if self.content:
res["content"].append(
{
"type": "text",
"text": self.content,
}
)
if self.function_call:
res["content"] = []
if self.content:
res["content"].append(
{
"type": "text",
"text": self.content,
}
)
res["content"].append(
{
"type": "tool_use",
Expand All @@ -70,17 +123,17 @@ def to_anthropic_dict(self) -> dict:
"input": self.function_call["arguments"],
}
)
elif self.role == "function": # result of a function call
res["role"] = "user"
res["content"] = [
if self.image:
res["content"].append(
{
"type": "tool_result",
"tool_use_id": self.function_call_id,
"content": self.content,
"type": "image",
"source": {
"type": "base64",
"media_type": self.image.format,
"data": self.image.to_base64(),
},
}
]
else:
res["content"] = self.content
)

return res

Expand Down
11 changes: 11 additions & 0 deletions examples/demo_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from PIL import Image as PILImage
import sys

sys.path.append("..")
from autochat import Autochat, Message, Image

ai = Autochat(provider="openai")

image = Image(PILImage.open("./image.jpg"))
response = ai.ask(Message(role="user", content="describe the image", image=image))
print(response)
Binary file added examples/image.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="autochat",
version="0.3.1",
version="0.3.2",
packages=find_packages(),
install_requires=["tenacity==8.3.0"],
extras_require={
Expand Down

0 comments on commit 6f651ce

Please sign in to comment.