diff --git a/.gitignore b/.gitignore
index 7a9fc77..5da5d44 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ autochat/__pycache__/
 build/
 dist/
 .DS_Store
+draft
diff --git a/README.md b/README.md
index fc30fe4..ada5f27 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
 # AutoChat
 
-AutoChat is an assistant interface to OpenAI and alternative providers, to simplify the process of creating interactive agents.
+AutoChat is an assistant library that supports OpenAI and Anthropic, simplifying the process of creating interactive agents.
 
-- **ChatGPT Class**: Conversation wrapper to store instruction, context and messages histories.
+- **Autochat Class**: Conversation wrapper to store the instruction, context, and message history.
 - **Message Class**: Message wrapper to handle format/parsing automatically.
 - **Function Calls**: Capability to handle function calls within the conversation, allowing complex interactions and responses.
 - **Template System**: A straightforward text-based template system for defining the behavior of the chatbot, making it easy to customize its responses and actions.
@@ -20,14 +20,45 @@ Please note that this package requires Python 3.6 or later.
 ## Simple Example
 
 ```python
-> from autochat import ChatGPT
-> chat = ChatGPT(instruction="You are a parot")
+> from autochat import Autochat
+> chat = Autochat(instruction="You are a parrot")
 > chat.ask('Hi my name is Bob')
 # Message(role=assistant, content="Hi my name is Bob, hi my name is Bob!")
 > chat.ask('Can you tell me my name?')
 # Message(role=assistant, content="Your name is Bob, your name is Bob!")
 ```
 
+## Function Calls Handling
+
+The library supports function calls, handling the back-and-forth between the system and the assistant.
+
+```python
+from autochat import Autochat, Message
+import json
+
+def label_item(category: str, from_response: Message):
+    # TODO: Implement function
+    raise NotImplementedError()
+
+with open("./examples/function_label.json") as f:
+    FUNCTION_LABEL_ITEM = json.load(f)
+
+classifierGPT = Autochat.from_template("./examples/classify_template.txt")
+classifierGPT.add_function(label_item, FUNCTION_LABEL_ITEM)
+
+text = "The new iPhone is out"
+for message in classifierGPT.run_conversation(text):
+    print(message.to_markdown())
+
+# > ## assistant
+# > It's about "Technology" since it's about a new iPhone.
+# > LABEL_ITEM(category="Technology")
+# > ## function
+# > NotImplementedError()
+# > ## assistant
+# > Seems like you didn't implement the function yet.
+```
+
 ## Template System
 
 We provide a simple template system for defining the behavior of the chatbot, using markdown-like syntax.
@@ -52,52 +83,31 @@ Your name is Bob, your name is Bob!
 
 You can then load the template file using the `from_template` method:
 
 ```python
-parrotGPT = ChatGPT.from_template("./parrot_template.txt")
+parrotGPT = Autochat.from_template("./parrot_template.txt")
 ```
 
-The template system also supports function calls. Check out the [examples/classify.py](examples/classify.py) for a complete example.
-
-## Function Calls Handling
+The template system also supports function calls. Check out the [examples/demo_label.py](examples/demo_label.py) for a complete example.
 
-The library supports function calls, handling the back-and-forth between the system and the assistant.
-
-```python
-from autochat import ChatGPT, Message
-import json
+## Use a different API provider (only Anthropic and OpenAI are supported for now)
 
-def label_item(category: str, from_response: Message):
-    # TODO: Implement function
-    raise NotImplementedError()
+The default provider is openai.
 
-with open("./examples/function_label.json") as f:
-    FUNCTION_LABEL_ITEM = json.load(f)
+Anthropic:
 
-classifierGPT = ChatGPT.from_template("./examples/classify_template.txt")
-classifierGPT.add_function(label_item, FUNCTION_LABEL_ITEM)
-
-text = "The new iPhone is out"
-for message in classifierGPT.run_conversation(text):
-    print(message.to_markdown())
-
-# > ## assistant
-# > It's about \"Technology\" since it's about a new iPhone.
-# > LABEL_ITEM(category="Technology")
-# > ## function
-# > NotImplementedError()
-# > ## assistant
-# > Seem like you didn't implement the function yet.
+```python
+chat = Autochat(provider="anthropic")
 ```
 
 ## Environment Variables
 
-The `AUTOCHAT_DEFAULT_MODEL` environment variable specifies the model to use. If not set, it defaults to "gpt-4-turbo".
+The `AUTOCHAT_MODEL` environment variable specifies the model to use. If not set, it defaults to "gpt-4-turbo".
 
 ```bash
 export AUTOCHAT_MODEL="gpt-4-turbo"
 export OPENAI_API_KEY=
 ```
 
-Use `AUTOCHAT_HOST` to use alternative provider that are openai compatible (openpipe, llama_cpp, ...)
+Set `AUTOCHAT_HOST` to use an alternative provider (openai, anthropic, openpipe, llama_cpp, ...)
 
 ## Support
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/autochat/__init__.py b/autochat/__init__.py
index f352f5c..e9940d3 100644
--- a/autochat/__init__.py
+++ b/autochat/__init__.py
@@ -1 +1,416 @@
-from .chatgpt import *
+import json
+import os
+import typing
+from enum import Enum
+
+from tenacity import (
+    retry,
+    retry_if_not_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+
+from autochat.model import Message
+from autochat.utils import csv_dumps, inspect_schema, parse_chat_template
+
+AUTOCHAT_HOST = os.getenv("AUTOCHAT_HOST")
+AUTOCHAT_MODEL = os.getenv("AUTOCHAT_MODEL")
+OUTPUT_SIZE_LIMIT = int(os.getenv("AUTOCHAT_OUTPUT_SIZE_LIMIT", 4000))
+
+
+class APIProvider(Enum):
+    OPENAI = "openai"
+    ANTHROPIC = "anthropic"
+
+
+class ContextLengthExceededError(Exception):
+    pass
+
+
+class InvalidRequestError(Exception):
+    pass
+
+
+class StopLoopException(Exception):
+    pass
+
+
+class InsufficientQuotaError(Exception):
+    pass
+
+
+class Autochat:
+    def __init__(
+        self,
+        instruction: typing.Optional[str] = None,
+        examples: typing.Optional[list[Message]] = None,
+        messages: typing.Optional[list[Message]] = None,
+        context: typing.Optional[str] = None,
+        max_interactions: int = 100,
+        model=AUTOCHAT_MODEL,
+        provider=APIProvider.OPENAI,
+    ) -> None:
+        if isinstance(provider, APIProvider):
+            self.provider = provider
+        elif isinstance(provider, str):
+            try:
+                self.provider = APIProvider(provider)
+            except ValueError:
+                raise ValueError(f"Provider {provider} is not a valid provider")
+        else:
+            raise ValueError(f"Invalid provider: {provider}")
+
+        self.model = model
+        self.client = None
+        self.instruction = instruction
+        # None defaults avoid sharing mutable default lists between instances
+        self.examples = examples if examples is not None else []
+        self.context = context
+        self.messages: list[Message] = messages if messages is not None else []
+        self.max_interactions = max_interactions
+        self.functions_schema = []
+        self.functions = {}
+
+        if self.provider == APIProvider.OPENAI:
+            from openai import OpenAI
+
+            if self.model is None:
+                # Default to gpt-4-turbo
+                self.model = "gpt-4-turbo"
+            self.client = OpenAI(
+                base_url=(
+                    f"{AUTOCHAT_HOST}/v1"
+                    if AUTOCHAT_HOST
+                    else "https://api.openai.com/v1"
"https://api.openai.com/v1" + ), + # We override because we have our own retry logic + max_retries=0, # default is 2 + ) + self.fetch = self.fetch_openai + elif self.provider == APIProvider.ANTHROPIC: + import anthropic + + if self.model is None: + self.model = "claude-3-5-sonnet-20240620" + self.client = anthropic.Anthropic( + default_headers={"anthropic-beta": "prompt-caching-2024-07-31"} + ) + self.fetch = self.fetch_anthropic + else: + raise ValueError(f"Invalid provider: {self.provider}") + + @classmethod + def from_template(cls, chat_template: str, **kwargs): + instruction, examples = parse_chat_template(chat_template) + return cls( + instruction=instruction, + examples=examples, + **kwargs, + ) + + @property + def last_message(self): + if not self.messages: + return None + return self.messages[-1].content + + def load_messages(self, messages: list[Message]): + # Check order of messages (based on createdAt) + # Oldest first (createdAt ASC) + # messages = sorted(messages, key=lambda x: x.createdAt) + self.messages = messages # [message for message in messages] + + def add_function( + self, + function: typing.Callable, + function_schema: typing.Optional[dict] = None, + ): + if function_schema is None: + # We try to infer the function schema from the function + function_schema = inspect_schema(function) + + self.functions_schema.append(function_schema) + self.functions[function_schema["name"]] = function + + def prepare_messages( + self, + transform_function: typing.Callable, + transform_list_function: typing.Callable, + ) -> list[dict]: + """Prepare messages for API requests using a transformation function.""" + first_message = self.messages[0] + if self.context: + # Add context to the first message + first_message.content = self.context + "\n" + first_message.content + messages = self.examples + [first_message] + self.messages[1:] + transform_list_function(messages) + return [transform_function(m) for m in messages] + + def ask( + self, + message: typing.Union[Message, str, None] = None, + ) -> Message: + if message: + if isinstance(message, str): + # If message is instance of string, then convert to Message + message = Message( + role="user", + content=message, + ) + self.messages.append(message) # Add the question to the history + + response = self.fetch() + self.messages.append(response) + return response + + def run_conversation( + self, question: str = None + ) -> typing.Generator[Message, None, None]: + if question: + message = Message( + role="user", + content=question, + ) + yield message + else: + message = None + + for _ in range(self.max_interactions): + # TODO: Check if the user has stopped the query + response = self.ask(message) + + if not response.function_call: + # We stop the conversation if the response is not a function call + yield response + return + + function_name = response.function_call["name"] + function_arguments = response.function_call["arguments"] + + try: + try: + content = self.functions[function_name]( + **function_arguments, + from_response=response, + ) + except TypeError: + # If the function does not accept 'from_response', call it without that argument + content = self.functions[function_name](**function_arguments) + except Exception as e: + if isinstance(e, StopLoopException): + yield response + return + # If function call failed, return the error message + # Flatten the error message + content = e.__repr__() + + yield response + + if content is None: + # If function call returns None, we continue the conversation without adding a message + # 
+                # message = None
+                # continue
+                content = None
+            elif isinstance(content, list):  # If content is a list of dicts, dump it to CSV
+                if not content:
+                    content = "[]"
+                elif isinstance(content[0], dict):
+                    try:
+                        content = csv_dumps(content, OUTPUT_SIZE_LIMIT)
+                    except Exception as e:
+                        print(e)
+                else:
+                    content = "\n".join(content)
+            elif isinstance(content, dict):
+                content = json.dumps(content)
+                if len(content) > OUTPUT_SIZE_LIMIT:
+                    content = (
+                        content[:OUTPUT_SIZE_LIMIT]
+                        + f"\n... ({len(content)} characters)"
+                    )
+            elif isinstance(content, str):
+                if len(content) > OUTPUT_SIZE_LIMIT:
+                    content = (
+                        content[:OUTPUT_SIZE_LIMIT]
+                        + f"\n... ({len(content)} characters)"
+                    )
+            else:
+                raise ValueError(f"Invalid content type: {type(content)}")
+
+            message = Message(
+                name=function_name,
+                role="function",
+                content=content,
+                function_call_id=response.function_call_id,
+            )
+            yield message
+
+    @retry(
+        stop=stop_after_attempt(4),
+        wait=wait_random_exponential(multiplier=2, max=10),
+        # If we get a context_length_exceeded error, we stop the conversation
+        retry=(
+            retry_if_not_exception_type(ContextLengthExceededError)
+            & retry_if_not_exception_type(InvalidRequestError)
+            & retry_if_not_exception_type(InsufficientQuotaError)
+        ),
+        # After 4 attempts, we re-raise the error
+        reraise=True,
+    )
+    def fetch_openai(self):
+        import openai
+
+        messages = self.prepare_messages(transform_function=Message.to_openai_dict)
+
+        try:
+            if self.functions_schema:
+                res = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=messages,
+                    functions=self.functions_schema,
+                )
+            else:
+                res = self.client.chat.completions.create(
+                    model=self.model, messages=messages
+                )
+        except openai.BadRequestError as e:
+            if e.code == "context_length_exceeded":
+                raise ContextLengthExceededError(e)
+            if e.code == "invalid_request_error":
+                raise InvalidRequestError(e)
+            raise
+        except openai.RateLimitError as e:
+            if e.code == "insufficient_quota":
+                raise InsufficientQuotaError(e)
+            raise
+        except openai.APIError as e:
+            raise e
+
+        message = res.choices[0].message
+        return Message.from_openai_dict(
+            role=message.role,
+            content=message.content,
+            function_call=message.function_call,
+            id=res.id,  # We use the response id as the message id
+        )
+
+    @retry(
+        stop=stop_after_attempt(4),
+        wait=wait_random_exponential(multiplier=2, max=10),
+        # If we get a context_length_exceeded error, we stop the conversation
+        retry=(
+            retry_if_not_exception_type(ContextLengthExceededError)
+            & retry_if_not_exception_type(InvalidRequestError)
+            & retry_if_not_exception_type(InsufficientQuotaError)
+        ),
+        # After 4 attempts, we re-raise the error
+        reraise=True,
+    )
+    def fetch_anthropic(self):
+        def add_empty_function_result(messages):
+            # Anthropic fix for empty function call results: iterate over the messages
+            # and check whether each assistant function call is followed by its result.
+            # If not, insert an empty function result before the following message.
+            for i in range(len(messages) - 1, 0, -1):
+                if (
+                    messages[i - 1].role == "assistant"
+                    and messages[i - 1].function_call
+                    and not messages[i].role == "function"
+                ):
+                    # Insert an empty function result
+                    messages.insert(
+                        i,
+                        Message(
+                            role="function",
+                            name=messages[i - 1].function_call["name"],
+                            content="",
+                            function_call_id=messages[i - 1].function_call_id,
+                        ),
+                    )
+
+        messages = self.prepare_messages(
+            transform_function=lambda m: m.to_anthropic_dict(),
+            transform_list_function=add_empty_function_result,
+        )
+
+        # Add cache control to the last message
+        if (
+            messages
+            and isinstance(messages[-1]["content"], list)
+            and len(messages[-1]["content"]) > 1
+        ):
+            messages[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
+
+        # Extract the system message (prepended by prepare_messages when set)
+        if messages and messages[0]["role"] == "system":
+            system = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system = None
+
+        def merge_messages(messages):
+            """
+            When two consecutive messages have the same role, we merge the following message into the previous.
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "example_19",
+                        "content": ""
+                    }
+                ]
+            },
+            {
+                "role": "user",
+                "content": "Plot distribution of stations per city"
+            }
+            """
+            merged_messages = []
+            for message in messages:
+                if merged_messages and merged_messages[-1]["role"] == message["role"]:
+                    merged_messages[-1]["content"].append(
+                        {
+                            "type": "text",
+                            "text": message["content"],
+                        }
+                    )
+                else:
+                    merged_messages.append(message)
+            return merged_messages
+
+        messages = merge_messages(messages)
+
+        # Need to map field "parameters" to "input_schema"
+        tools = [
+            {
+                "name": s["name"],
+                "description": s["description"],
+                "input_schema": s["parameters"],
+            }
+            for s in self.functions_schema
+        ]
+        # Add a default description if the function's description is empty
+        for tool in tools:
+            if not tool["description"]:
+                tool["description"] = "No description provided"
+
+        kwargs = {}
+        if system is not None:
+            kwargs["system"] = system
+
+        res = self.client.messages.create(
+            model=self.model,
+            messages=messages,
+            tools=tools,
+            max_tokens=2000,
+            **kwargs,
+        )
+        res_dict = res.to_dict()
+        return Message.from_anthropic_dict(
+            role=res_dict["role"],
+            content=res_dict["content"],
+        )
diff --git a/autochat/chatgpt.py b/autochat/chatgpt.py
deleted file mode 100644
index d6acbd1..0000000
--- a/autochat/chatgpt.py
+++ /dev/null
@@ -1,409 +0,0 @@
-import json
-import os
-import typing
-from enum import Enum
-
-from tenacity import (retry, retry_if_not_exception_type, stop_after_attempt,
-                      wait_random_exponential)
-
-from .model import Message
-from .utils import csv_dumps, inspect_schema, parse_chat_template
-
-AUTOCHAT_HOST = os.getenv("AUTOCHAT_HOST")
-AUTOCHAT_MODEL = os.getenv("AUTOCHAT_MODEL")
-OUTPUT_SIZE_LIMIT = int(os.getenv("AUTOCHAT_OUTPUT_SIZE_LIMIT", 4000))
-
-
-class APIProvider(Enum):
-    OPENAI = "openai"
-    ANTHROPIC = "anthropic"
-
-
-class ContextLengthExceededError(Exception):
-    pass
-
-
-class InvalidRequestError(Exception):
-    pass
-
-
-class StopLoopException(Exception):
-    pass
-
-
-class InsufficientQuotaError(Exception):
-    pass
-
-
-class ChatGPT:
-    def __init__(
-        self,
-        messages: list[Message],
-        context=None,
-        max_interactions: int = 100,
-        model=AUTOCHAT_MODEL,
-        provider=APIProvider.OPENAI,
-    ) -> None:
-        if isinstance(provider, APIProvider):
-            self.provider = provider
-        elif isinstance(provider, str):
-            try:
-                self.provider = APIProvider(provider)
-            except ValueError:
-                raise ValueError(f"Provider {provider} is not a valid provider")
-        else:
-            raise ValueError(f"Invalid provider: {provider}")
-
-        self.model = model
-        self.client = None
-        self.messages: list[Message] = messages
-        self.history: list[Message] = []
-        self.context = context
-        self.max_interactions = max_interactions
-        self.functions_schema = []
-        self.functions = {}
-
-        # TODO: Remove this
-        self.pre_history = messages
-        self.history = []
-
-        if self.provider == APIProvider.OPENAI:
-            from openai import OpenAI
-
-            if self.model is None:
-                # Default to gpt-4-turbo
"gpt-4-turbo" - self.client = OpenAI( - base_url=( - f"{AUTOCHAT_HOST}/v1" - if AUTOCHAT_HOST - else "https://api.openai.com/v1" - ), - # We override because we have our own retry logic - max_retries=0, # default is 2 - ) - self.fetch = self.fetch_openai - elif self.provider == APIProvider.ANTHROPIC: - import anthropic - - if self.model is None: - self.model = "claude-3-5-sonnet-20240620" - self.client = anthropic.Anthropic( - default_headers={"anthropic-beta": "prompt-caching-2024-07-31"} - ) - self.fetch = self.fetch_anthropic - else: - raise ValueError(f"Invalid provider: {self.provider}") - - @classmethod - def from_template(cls, chat_template: str, **kwargs): - messages = parse_chat_template(chat_template) - return cls( - messages=messages, - **kwargs, - ) - - @classmethod - def from_instruction_and_examples( - cls, instruction: str, examples: list[dict], **kwargs - ): - messages = [Message(role="system", content=instruction)] + [ - Message(**example) for example in examples - ] - return cls( - messages=messages, - **kwargs, - ) - - @property - def last_message(self): - if not self.history: - return None - return self.history[-1].content - - def reset_history(self): - self.history: list[Message] = [] - - def load_history(self, messages: list[Message]): - # Check order of messages (based on createdAt) - # Oldest first (createdAt ASC) - # messages = sorted(messages, key=lambda x: x.createdAt) - self.history = messages # [message for message in messages] - - def add_function( - self, - function: typing.Callable, - function_schema: typing.Optional[dict] = None, - ): - if function_schema is None: - # We try to infer the function schema from the function - function_schema = inspect_schema(function) - - self.functions_schema.append(function_schema) - self.functions[function_schema["name"]] = function - - def compress_history(self): - """Try to make a summary of the history""" - # TODO: Implement - - def prepare_messages(self, transform_function) -> list[dict]: - """Prepare messages for API requests using a transformation function.""" - first_message = self.history[0] - if self.context: - first_message.content = self.context + "\n" + first_message.content - messages = self.pre_history + [first_message] + self.history[1:] - return [transform_function(m) for m in messages] - - def ask( - self, - message: typing.Union[Message, str, None] = None, - ) -> Message: - if message: - if isinstance(message, str): - # If message is instance of string, then convert to Message - message = Message( - role="user", - content=message, - ) - self.history.append(message) # Add the question to the history - - response = self.fetch() - self.history.append(response) - return response - - def run_conversation( - self, question: str = None - ) -> typing.Generator[Message, None, None]: - if question: - message = Message( - role="user", - content=question, - ) - yield message - else: - message = None - - for _ in range(self.max_interactions): - # TODO: Check if the user has stopped the query - response = self.ask(message) - - if not response.function_call: - # We stop the conversation if the response is not a function call - yield response - return - - function_name = response.function_call["name"] - function_arguments = response.function_call["arguments"] - - try: - try: - content = self.functions[function_name]( - **function_arguments, - from_response=response, - ) - except TypeError: - # If the function does not accept 'from_response', call it without that argument - content = 
-                    content = self.functions[function_name](**function_arguments)
-            except Exception as e:
-                if isinstance(e, StopLoopException):
-                    yield response
-                    return
-                # If function call failed, return the error message
-                # Flatten the error message
-                content = e.__repr__()
-
-            yield response
-
-            if content is None:
-                # If function call returns None, we continue the conversation without adding a message
-                # message = None
-                # continue
-                content = None
-            elif isinstance(content, list):  # If data is list of dicts, dumps to CSV
-                if not content:
-                    content = "[]"
-                elif isinstance(content[0], dict):
-                    try:
-                        content = csv_dumps(content, OUTPUT_SIZE_LIMIT)
-                    except Exception as e:
-                        print(e)
-                else:
-                    content = "\n".join(content)
-            elif isinstance(content, dict):
-                content = json.dumps(content)
-                if len(content) > OUTPUT_SIZE_LIMIT:
-                    content = (
-                        content[:OUTPUT_SIZE_LIMIT]
-                        + f"\n... ({len(content)} characters)"
-                    )
-            elif isinstance(content, str):
-                if len(content) > OUTPUT_SIZE_LIMIT:
-                    content = (
-                        content[:OUTPUT_SIZE_LIMIT]
-                        + f"\n... ({len(content)} characters)"
-                    )
-            else:
-                raise ValueError(f"Invalid content type: {type(content)}")
-
-            message = Message(
-                name=function_name,
-                role="function",
-                content=content,
-                function_call_id=response.function_call_id,
-            )
-            yield message
-
-    @retry(
-        stop=stop_after_attempt(4),
-        wait=wait_random_exponential(multiplier=2, max=10),
-        # If we get a context_length_exceeded error, we stop the conversation
-        retry=(
-            retry_if_not_exception_type(ContextLengthExceededError)
-            & retry_if_not_exception_type(InvalidRequestError)
-            & retry_if_not_exception_type(InsufficientQuotaError)
-        ),
-        # After 5 attempts, we throw the error
-        reraise=True,
-    )
-    def fetch_openai(self):
-        import openai
-
-        messages = self.prepare_messages(transform_function=Message.to_openai_dict)
-
-        try:
-            if self.functions_schema:
-                res = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    functions=self.functions_schema,
-                )
-            else:
-                res = self.client.chat.completions.create(
-                    model=self.model, messages=messages
-                )
-        except openai.BadRequestError as e:
-            if e.code == "context_length_exceeded":
-                raise ContextLengthExceededError(e)
-            if e.code == "invalid_request_error":
-                raise InvalidRequestError(e)
-            raise
-        except openai.RateLimitError as e:
-            if e.code == "insufficient_quota":
-                raise InsufficientQuotaError(e)
-            raise
-        except openai.APIError as e:
-            raise e
-
-        message = res.choices[0].message
-        return Message.from_openai_dict(
-            role=message.role,
-            content=message.content,
-            function_call=message.function_call,
-            id=res.id,  # We use the response id as the message id
-        )
-
-    @retry(
-        stop=stop_after_attempt(4),
-        wait=wait_random_exponential(multiplier=2, max=10),
-        # If we get a context_length_exceeded error, we stop the conversation
-        retry=(
-            retry_if_not_exception_type(ContextLengthExceededError)
-            & retry_if_not_exception_type(InvalidRequestError)
-            & retry_if_not_exception_type(InsufficientQuotaError)
-        ),
-        # After 5 attempts, we throw the error
-        reraise=True,
-    )
-    def fetch_anthropic(self):
-        # Anthropic fix for empty function call result
-        # Iterate on messages and check if the last message is a function call, and the following is a user text
-        # If so, we have to add an empty result of the function call before the user text
-        for i in range(len(self.history) - 1, 0, -1):
-            if (
-                self.history[i - 1].role == "assistant"
-                and self.history[i - 1].function_call
-                and self.history[i].role == "user"
-                and not self.history[i].function_call
-            ):
print(f"Inserting empty function result: {i}") - # Insert an empty function result - self.history.insert( - i, - Message( - role="function", - name=self.history[i - 1].function_call["name"], - content="", - function_call_id=self.history[i - 1].function_call_id, - ), - ) - - messages = self.prepare_messages( - transform_function=lambda m: m.to_anthropic_dict() - ) - - # Add cache control to the last message - if ( - messages - and isinstance(messages[-1]["content"], list) - and len(messages[-1]["content"]) > 1 - ): - messages[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"} - - system = messages[0]["content"] - messages = messages[1:] - - def merge_messages(messages): - """ - When two messages are in the same role, we merge the following message into the previous. - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "example_19", - "content": "" - } - ] - }, - { - "role": "user", - "content": "Plot distribution of stations per city" - } - """ - merged_messages = [] - for message in messages: - if merged_messages and merged_messages[-1]["role"] == message["role"]: - merged_messages[-1]["content"].append( - { - "type": "text", - "text": message["content"], - } - ) - else: - merged_messages.append(message) - return merged_messages - - messages = merge_messages(messages) - - # Need to map field "parameters" to "input_schema" - tools = [ - { - "name": s["name"], - "description": s["description"], - "input_schema": s["parameters"], - } - for s in self.functions_schema - ] - - res = self.client.messages.create( - model=self.model, - system=system, - messages=messages, - tools=tools, - max_tokens=2000, - ) - res_dict = res.to_dict() - return Message.from_anthropic_dict( - role=res_dict["role"], - content=res_dict["content"], - ) diff --git a/autochat/test.py b/autochat/test.py new file mode 100644 index 0000000..d61dcda --- /dev/null +++ b/autochat/test.py @@ -0,0 +1,66 @@ +import unittest +from unittest.mock import patch + +from autochat import APIProvider, Autochat, ContextLengthExceededError, Message + + +class TestAutochat(unittest.TestCase): + def test_autochat_initialization(self): + chat = Autochat(instruction="Test instruction", provider="openai") + self.assertEqual(chat.instruction, "Test instruction") + self.assertEqual(chat.provider, APIProvider.OPENAI) + self.assertEqual(chat.model, "gpt-4-turbo") + + def test_autochat_invalid_provider(self): + with self.assertRaises(ValueError): + Autochat(provider="invalid_provider") + + def test_add_function(self): + chat = Autochat() + + def test_function(arg1: str, arg2: int) -> str: + return f"Received {arg1} and {arg2}" + + chat.add_function(test_function) + self.assertEqual(len(chat.functions_schema), 1) + self.assertIn("test_function", chat.functions) + + @patch.object(Autochat, "fetch_openai") + def test_ask(self, mock_fetch_openai): + mock_fetch_openai.return_value = Message( + role="assistant", content="Test response" + ) + chat = Autochat(provider="openai") + + response = chat.ask("Test question") + self.assertEqual(response.role, "assistant") + self.assertEqual(response.content, "Test response") + self.assertEqual(len(chat.messages), 2) + + @patch.object(Autochat, "fetch_openai") + def test_run_conversation(self, mock_fetch_openai): + mock_fetch_openai.return_value = Message( + role="assistant", content="Final response" + ) + chat = Autochat(provider="openai") + + responses = list(chat.run_conversation("Test question")) + self.assertEqual(len(responses), 2) + self.assertEqual(responses[0].role, 
"user") + self.assertEqual(responses[0].content, "Test question") + self.assertEqual(responses[1].role, "assistant") + self.assertEqual(responses[1].content, "Final response") + + @patch.object(Autochat, "fetch_openai") + def test_context_length_exceeded(self, mock_fetch_openai): + mock_fetch_openai.side_effect = ContextLengthExceededError( + "Context length exceeded" + ) + chat = Autochat(provider="openai") + + with self.assertRaises(ContextLengthExceededError): + chat.ask("Test question") + + +if __name__ == "__main__": + unittest.main() diff --git a/autochat/utils.py b/autochat/utils.py index ea241f0..b79ccf7 100644 --- a/autochat/utils.py +++ b/autochat/utils.py @@ -159,21 +159,21 @@ def parse_chat_template(filename) -> list[Message]: pairs = [pair.split("\n", 1) for pair in pairs] # create a list of tuples - messages = [(pair[0], pair[1].strip()) for pair in pairs] + examples_pairs_str = [(pair[0], pair[1].strip()) for pair in pairs] - examples = [] + parsed_examples = [] instruction = None - for ind, message in enumerate(messages): + for ind, example in enumerate(examples_pairs_str): # If first message role is a system message, extract the example - if ind == 0 and message[0] == "system": - instruction = message[1] + if ind == 0 and example[0] == "system": + instruction = example[1] else: - role = message[0].strip().lower() - message = message[1] + role = example[0].strip().lower() + message = example[1] content, function_call_str = split_message(message) if function_call_str: - examples.append( + parsed_examples.append( { "role": role, "content": content if content else None, @@ -181,19 +181,15 @@ def parse_chat_template(filename) -> list[Message]: } ) else: - examples.append( + parsed_examples.append( { "role": role, "content": message, } ) - messages = [] - if instruction: - messages.append(Message(role="system", content=instruction)) - - # Simple loop - for ind, example in enumerate(examples): + examples: list[Message] = [] + for ind, example in enumerate(parsed_examples): # Herit name from message role message = Message( **example, @@ -204,9 +200,9 @@ def parse_chat_template(filename) -> list[Message]: message.function_call_id = "example_" + str(ind) if message.role == "function": message.function_call_id = "example_" + str(ind - 1) - messages.append(message) + examples.append(message) - return messages + return instruction, examples def inspect_schema(f): diff --git a/examples/demo_label.py b/examples/demo_label.py index cce687e..e14ac79 100644 --- a/examples/demo_label.py +++ b/examples/demo_label.py @@ -3,7 +3,7 @@ sys.path.append("..") -from autochat import ChatGPT, Message +from autochat import Autochat, Message def label_item(category: str, from_response: Message): @@ -14,7 +14,7 @@ def label_item(category: str, from_response: Message): with open("./function_label.json") as f: FUNCTION_LABEL_ITEM = json.load(f) -classifierGPT = ChatGPT.from_template("./classify_template.txt") +classifierGPT = Autochat.from_template("./classify_template.txt") classifierGPT.add_function(label_item, FUNCTION_LABEL_ITEM) diff --git a/setup.py b/setup.py index b82d8d0..a750bf2 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="autochat", - version="0.2.0", + version="0.3.0", packages=find_packages(), install_requires=["tenacity==8.3.0"], extras_require={ @@ -12,7 +12,7 @@ }, author="Benjamin Derville", author_email="benderville@gmail.com", - description="Small ChatGPT library to support chat templates, and function calls", + description="Small OpenAI/Anthropic library to support 
+    description="Small OpenAI/Anthropic library to support chat templates and function calls",
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
     url="https://github.com/benderv/autochat",
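Taken together, the renamed class, the constructor's new `instruction`/`examples` parameters, and the schema inference in `add_function` can be exercised end to end. The snippet below is a minimal sketch, not part of the diff: it assumes `inspect_schema` can infer a schema from plain type hints, and `multiply` is a hypothetical example function.

```python
from autochat import Autochat

def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

# The instruction plays the role of the old system message in pre_history
chat = Autochat(instruction="You are a calculator", provider="openai")
chat.add_function(multiply)  # no schema passed; inferred via inspect_schema

# run_conversation yields the user message, each function call and its
# result, then the final assistant message
for message in chat.run_conversation("What is 6 times 7?"):
    print(message.to_markdown())
```

Note that `run_conversation` first calls the function with a `from_response` keyword and falls back to a plain call on `TypeError`, so functions like `multiply` that don't accept `from_response` work unchanged.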