Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
BenderV committed Aug 24, 2024
1 parent a3478b4 commit effdb0b
Show file tree
Hide file tree
Showing 4 changed files with 345 additions and 115 deletions.
254 changes: 143 additions & 111 deletions autochat/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@
import typing
from enum import Enum

from tenacity import (
retry,
retry_if_not_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from tenacity import (retry, retry_if_not_exception_type, stop_after_attempt,
wait_random_exponential)

from .model import Message
from .utils import csv_dumps, inspect_schema, parse_chat_template

AUTOCHAT_HOST = os.getenv("AUTOCHAT_HOST")
Expand All @@ -19,14 +16,7 @@

class APIProvider(Enum):
OPENAI = "openai"


class FunctionCallParsingError(Exception):
def __init__(self, obj):
self.obj = obj

def __str__(self):
return f"Invalid function_call: {self.obj.function_call}"
ANTHROPIC = "anthropic"


class ContextLengthExceededError(Exception):
Expand All @@ -45,83 +35,10 @@ class InsufficientQuotaError(Exception):
pass


class Message:
def __init__(
self,
role: str,
content: str = None,
name: typing.Optional[str] = None,
function_call: typing.Optional[dict] = None,
id: typing.Optional[int] = None,
) -> None:
self.role = role
self.content = content
self.name = name
self.function_call = function_call
self.id = id

def to_openai_dict(self) -> dict:
res = {
"role": self.role,
"content": self.content,
}
if self.name:
res["name"] = self.name
if self.function_call:
if self.role == "assistant":
res["function_call"] = {
"name": self.function_call["name"],
"arguments": json.dumps(self.function_call["arguments"]),
}
else:
# If user is triggering a function, we add the function call to the content
# since openai doesn't support functions for user messages
res["content"] = (
self.function_call["name"]
+ ":"
+ json.dumps(self.function_call["arguments"])
)
return res

@classmethod
def from_openai_dict(cls, **kwargs):
obj = cls(**kwargs)
if obj.function_call:
# Parse function_call with json.loads
try:
obj.function_call = {
"name": obj.function_call.name,
"arguments": json.loads(obj.function_call.arguments),
}
except json.decoder.JSONDecodeError:
raise FunctionCallParsingError(obj)
return obj

def __repr__(self) -> str:
text = f"Message(role={self.role}, "
if self.content:
text += f'content="{self.content}", '
if self.function_call:
text += f'function_call="{self.function_call}", '
return text[:-2] + ")"

def to_markdown(self) -> str:
text = f"## {self.role}\n"
if self.content is not None:
text += self.content + "\n"
if self.function_call is not None:
# Display function_call so it look like func(arg1="value1", arg2="value2")
text += f"> {self.function_call['name']}({', '.join([f'{k}={v}' for k, v in self.function_call['arguments'].items()])})\n"
if self.content is None and self.function_call is None:
raise ValueError("Message should have content or function_call")
return text


class ChatGPT:
def __init__(
self,
instruction=None,
examples=[],
messages: list[Message],
context=None,
max_interactions: int = 100,
model=AUTOCHAT_MODEL,
Expand All @@ -139,27 +56,16 @@ def __init__(

self.model = model
self.client = None
self.pre_history: list[Message] = []
self.messages: list[Message] = messages
self.history: list[Message] = []
self.instruction: typing.Optional[str] = instruction
self.examples = examples
self.context = context
self.max_interactions = max_interactions
self.functions_schema = []
self.functions = {}

if self.instruction:
self.pre_history.append(Message(role="system", content=self.instruction))

# Simple loop
for example in self.examples:
# Herit name from message role
self.pre_history.append(
Message(
**example,
name="example_" + example["role"],
)
)
# TODO: Remove this
self.pre_history = messages
self.history = []

if self.provider == APIProvider.OPENAI:
from openai import OpenAI
Expand All @@ -177,15 +83,35 @@ def __init__(
max_retries=0, # default is 2
)
self.fetch = self.fetch_openai
elif self.provider == APIProvider.ANTHROPIC:
import anthropic

if self.model is None:
self.model = "claude-3-5-sonnet-20240620"
self.client = anthropic.Anthropic(
default_headers={"anthropic-beta": "prompt-caching-2024-07-31"}
)
self.fetch = self.fetch_anthropic
else:
raise ValueError(f"Invalid provider: {self.provider}")

@classmethod
def from_template(cls, chat_template: str, **kwargs):
instruction, examples = parse_chat_template(chat_template)
messages = parse_chat_template(chat_template)
return cls(
instruction=instruction,
examples=examples,
messages=messages,
**kwargs,
)

@classmethod
def from_instruction_and_examples(
cls, instruction: str, examples: list[dict], **kwargs
):
messages = [Message(role="system", content=instruction)] + [
Message(**example) for example in examples
]
return cls(
messages=messages,
**kwargs,
)

Expand Down Expand Up @@ -290,11 +216,10 @@ def run_conversation(

if content is None:
# If function call returns None, we continue the conversation without adding a message
message = None
continue

# If data is list of dicts, dumps to CSV
if isinstance(content, list):
# message = None
# continue
content = None
elif isinstance(content, list): # If data is list of dicts, dumps to CSV
if not content:
content = "[]"
elif isinstance(content[0], dict):
Expand Down Expand Up @@ -324,6 +249,7 @@ def run_conversation(
name=function_name,
role="function",
content=content,
function_call_id=response.function_call_id,
)
yield message

Expand Down Expand Up @@ -375,3 +301,109 @@ def fetch_openai(self):
function_call=message.function_call,
id=res.id, # We use the response id as the message id
)

@retry(
stop=stop_after_attempt(4),
wait=wait_random_exponential(multiplier=2, max=10),
# If we get a context_length_exceeded error, we stop the conversation
retry=(
retry_if_not_exception_type(ContextLengthExceededError)
& retry_if_not_exception_type(InvalidRequestError)
& retry_if_not_exception_type(InsufficientQuotaError)
),
# After 5 attempts, we throw the error
reraise=True,
)
def fetch_anthropic(self):
# Anthropic fix for empty function call result
# Iterate on messages and check if the last message is a function call, and the following is a user text
# If so, we have to add an empty result of the function call before the user text
for i in range(len(self.history) - 1, 0, -1):
if (
self.history[i - 1].role == "assistant"
and self.history[i - 1].function_call
and self.history[i].role == "user"
and not self.history[i].function_call
):
print(f"Inserting empty function result: {i}")
# Insert an empty function result
self.history.insert(
i,
Message(
role="function",
name=self.history[i - 1].function_call["name"],
content="",
function_call_id=self.history[i - 1].function_call_id,
),
)

messages = self.prepare_messages(
transform_function=lambda m: m.to_anthropic_dict()
)

# Add cache control to the last message
if (
messages
and isinstance(messages[-1]["content"], list)
and len(messages[-1]["content"]) > 1
):
messages[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}

system = messages[0]["content"]
messages = messages[1:]

def merge_messages(messages):
"""
When two messages are in the same role, we merge the following message into the previous.
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "example_19",
"content": ""
}
]
},
{
"role": "user",
"content": "Plot distribution of stations per city"
}
"""
merged_messages = []
for message in messages:
if merged_messages and merged_messages[-1]["role"] == message["role"]:
merged_messages[-1]["content"].append(
{
"type": "text",
"text": message["content"],
}
)
else:
merged_messages.append(message)
return merged_messages

messages = merge_messages(messages)

# Need to map field "parameters" to "input_schema"
tools = [
{
"name": s["name"],
"description": s["description"],
"input_schema": s["parameters"],
}
for s in self.functions_schema
]

res = self.client.messages.create(
model=self.model,
system=system,
messages=messages,
tools=tools,
max_tokens=2000,
)
res_dict = res.to_dict()
return Message.from_anthropic_dict(
role=res_dict["role"],
content=res_dict["content"],
)
Loading

0 comments on commit effdb0b

Please sign in to comment.