Skip to content

Commit

Permalink
Standardizing environment for completions and embeddings (#46)
Browse files Browse the repository at this point in the history
* standardizing environment for completions and embeddings

* pr feedback

* build pipeline fixes

* renaming default_values module

* renaming tests
  • Loading branch information
rlundeen2 authored Feb 16, 2024
1 parent 6491292 commit b4beb27
Show file tree
Hide file tree
Showing 12 changed files with 217 additions and 80 deletions.
42 changes: 11 additions & 31 deletions examples/code/azure_embeddings_and_completions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,15 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-27T18:17:24.179820Z",
"start_time": "2023-07-27T18:17:24.142442Z"
},
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
Expand All @@ -66,34 +57,23 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-27T20:00:49.932378Z",
"start_time": "2023-07-27T20:00:48.444786Z"
},
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Completion(id='cmpl-8XxIbg1skkdrHYQEJYDWrA8jE9VyI', choices=[CompletionChoice(finish_reason='length', index=0, logprobs=None, text=' mineral compositionum granite this stone is composed of feldspar,quartz and')], created=1703103461, model='davinci-002', object='text_completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=16, prompt_tokens=3, total_tokens=19))\n"
]
}
],
"outputs": [],
"source": [
"from pprint import pprint\n",
"from pyrit.completion.azure_completions import AzureCompletion\n",
"\n",
"\n",
"prompt = \"hello world!\"\n",
"\n",
"davinci_engine = AzureCompletion(\n",
" api_key=api_key,\n",
" api_base=api_base,\n",
" model=deployment_name)\n",
"davinci_engine = AzureCompletion()\n",
"text_response = davinci_engine.complete_text(text=prompt)"
]
},
Expand All @@ -109,7 +89,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2023-07-27T20:00:50.600130Z",
Expand All @@ -122,7 +102,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"PromptResponse(completion=' mineral compositionum granite this stone is composed of feldspar,quartz and', prompt='hello world!', id='cmpl-8XxIbg1skkdrHYQEJYDWrA8jE9VyI', completion_tokens=16, prompt_tokens=3, total_tokens=19, model='davinci-002', object='text_completion', created_at=0, logprobs=False, index=0, finish_reason='', api_request_time_to_complete_ns=0, metadata={})\n"
"PromptResponse(completion=' i’m kai the bass slayer. It’s been 16 years of', prompt='hello world!', id='cmpl-8saOvxBdORlWpGjDTCsJZl5HKsYTB', completion_tokens=16, prompt_tokens=3, total_tokens=19, model='davinci-002', object='text_completion', created_at=0, logprobs=False, index=0, finish_reason='', api_request_time_to_complete_ns=0, metadata={})\n"
]
}
],
Expand Down Expand Up @@ -158,8 +138,8 @@
"input_text = \"hello\"\n",
"ada_embedding_engine = AzureTextEmbedding(\n",
" api_key=api_key,\n",
" api_base=api_base,\n",
" model=os.environ.get(\"AZURE_OPENAI_EMBEDDING_DEPLOYMENT\"))\n",
" endpoint=api_base,\n",
" deployment=os.environ.get(\"AZURE_OPENAI_EMBEDDING_DEPLOYMENT\"))\n",
"embedding_response = ada_embedding_engine.generate_text_embedding(text=input_text)"
]
},
Expand Down Expand Up @@ -302,9 +282,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "pyrit_kernel",
"display_name": "pyrit-dev",
"language": "python",
"name": "pyrit_kernel"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down
10 changes: 6 additions & 4 deletions pyrit/chat/aml_online_endpoint_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import asyncio
import logging

from pyrit.common import environment_variables, net_utility
from pyrit.common import default_values, net_utility
from pyrit.interfaces import ChatSupport
from pyrit.models import ChatMessage

Expand Down Expand Up @@ -34,10 +34,12 @@ def __init__(
endpoint_uri: AML online endpoint URI.
api_key: api key for the endpoint
"""
self.endpoint_uri: str = environment_variables.get_required_value(
self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, endpoint_uri
self.endpoint_uri: str = default_values.get_required_value(
env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint_uri
)
self.api_key: str = default_values.get_required_value(
env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
)
self.api_key: str = environment_variables.get_required_value(self.API_KEY_ENVIRONMENT_VARIABLE, api_key)

def complete_chat(
self,
Expand Down
10 changes: 7 additions & 3 deletions pyrit/chat/azure_openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.
from openai import AsyncAzureOpenAI, AzureOpenAI
from openai.types.chat import ChatCompletion
from pyrit.common import environment_variables
from pyrit.common import default_values

from pyrit.interfaces import ChatSupport
from pyrit.models import ChatMessage
Expand All @@ -22,8 +22,12 @@ def __init__(
) -> None:
self._deployment_name = deployment_name

endpoint = environment_variables.get_required_value(self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, endpoint)
api_key = environment_variables.get_required_value(self.API_KEY_ENVIRONMENT_VARIABLE, api_key)
endpoint = default_values.get_required_value(
env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint
)
api_key = default_values.get_required_value(
env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
)

self._client = AzureOpenAI(
api_key=api_key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os


def get_required_value(environment_variable_name: str, passed_value: str) -> str:
def get_required_value(*, env_var_name: str, passed_value: str) -> str:
"""
Gets a required value from an environment variable or a passed value,
prefering the passed value
Expand All @@ -22,8 +22,8 @@ def get_required_value(environment_variable_name: str, passed_value: str) -> str
if passed_value:
return passed_value

value = os.environ.get(environment_variable_name)
value = os.environ.get(env_var_name)
if value:
return value

raise ValueError(f"Environment variable {environment_variable_name} is required")
raise ValueError(f"Environment variable {env_var_name} is required")
36 changes: 31 additions & 5 deletions pyrit/completion/azure_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,46 @@

from openai import AzureOpenAI

from pyrit.common import default_values
from pyrit.interfaces import CompletionSupport
from pyrit.models import PromptResponse


class AzureCompletion(CompletionSupport):
def __init__(self, api_key: str, api_base: str, model: str, api_version: str = "2023-05-15"):
self._model = model
self._api_version = api_version
self._api_base = api_base
API_KEY_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_COMPLETION_KEY"
ENDPOINT_URI_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_COMPLETION_ENDPOINT"
DEPLOYMENT_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_COMPLETION_DEPLOYMENT"

def __init__(
self, api_key: str = None, endpoint: str = None, deployment: str = None, api_version: str = "2023-05-15"
):
"""
Initializes an instance of the AzureCompletions class.
Args:
api_key (str, optional): The API key for accessing the Azure OpenAI service.
Defaults to environment variable.
endpoint (str, optional): The endpoint URL for the Azure OpenAI service.
Defaults to environment variable.
deployment (str, optional): The deployment name for the Azure OpenAI service.
Defaults to environment variable.
api_version (str, optional): The API version for the Azure OpenAI service. Defaults to "2023-05-15".
"""

api_key = default_values.get_required_value(
env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
)
endpoint = default_values.get_required_value(
env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint
)
self._model = default_values.get_required_value(
env_var_name=self.DEPLOYMENT_ENVIRONMENT_VARIABLE, passed_value=deployment
)

self._client = AzureOpenAI(
api_key=api_key,
api_version=api_version,
azure_endpoint=api_base,
azure_endpoint=endpoint,
)

def complete_text(self, text: str, **kwargs) -> PromptResponse:
Expand Down
35 changes: 27 additions & 8 deletions pyrit/embedding/azure_text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,43 @@

from openai import AzureOpenAI

from pyrit.common import default_values
from pyrit.embedding._text_embedding import _TextEmbedding


class AzureTextEmbedding(_TextEmbedding):
def __init__(self, *, api_key: str, api_base: str, api_version: str = "2023-05-15", model: str) -> None:
API_KEY_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_EMBEDDING_KEY"
ENDPOINT_URI_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_EMBEDDING_ENDPOINT"
DEPLOYMENT_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_EMBEDDING_DEPLOYMENT"

def __init__(
self, *, api_key: str = None, endpoint: str = None, deployment: str = None, api_version: str = "2023-05-15"
) -> None:
"""Generate embedding using the Azure API
Args:
api_key: The API key to use
api_base: The API base to use
model: The engine to use, usually name of the deployment
api_version: The API version to use
api_key: The API key to use. Defaults to environment variable.
endpoint: The API base to use, sometimes referred to as the api_base. Defaults to environment variable.
deployment: The engine to use, in AOAI referred to as deployment, in some APIs referred to as model.
Defaults to environment variable.
api_version: The API version to use. Defaults to "2023-05-15".
"""

api_key = default_values.get_required_value(
env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
)
endpoint = default_values.get_required_value(
env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint
)
deployment = default_values.get_required_value(
env_var_name=self.DEPLOYMENT_ENVIRONMENT_VARIABLE, passed_value=deployment
)

self._client = AzureOpenAI(
api_key=api_key,
api_version=api_version,
azure_endpoint=api_base,
azure_deployment=model,
azure_endpoint=endpoint,
azure_deployment=deployment,
)
self._model = model
self._model = deployment
super().__init__()
2 changes: 1 addition & 1 deletion pyrit/memory/memory_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def default_memory_embedding_factory(embedding_model: EmbeddingSupport = None) -
api_base = os.environ.get("AZURE_OPENAI_EMBEDDING_ENDPOINT")
deployment = os.environ.get("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
if api_key and api_base and deployment:
model = AzureTextEmbedding(api_key=api_key, api_base=api_base, model=deployment)
model = AzureTextEmbedding(api_key=api_key, endpoint=api_base, deployment=deployment)
return MemoryEmbedding(embedding_model=model)
else:
return None
4 changes: 2 additions & 2 deletions tests/memory/test_file_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def test_file_memory_labels_included(memory: FileMemory):


def test_explicit_embedding_model_set():
embedding = AzureTextEmbedding(api_key="testkey", api_base="testbase", model="deployment")
embedding = AzureTextEmbedding(api_key="testkey", endpoint="testbase", deployment="deployment")

with NamedTemporaryFile(suffix=".json.memory") as tmp:
memory = FileMemory(filepath=tmp.name, embedding_model=embedding)
Expand All @@ -227,7 +227,7 @@ def test_default_embedding_model_set_none():


def test_default_embedding_model_set_correctly():
embedding = AzureTextEmbedding(api_key="testkey", api_base="testbase", model="deployment")
embedding = AzureTextEmbedding(api_key="testkey", endpoint="testbase", deployment="deployment")

with (
NamedTemporaryFile(suffix=".json.memory") as tmp,
Expand Down
53 changes: 53 additions & 0 deletions tests/test_azure_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import pytest

from pyrit.completion import AzureCompletion


def test_valid_init():
os.environ[AzureCompletion.API_KEY_ENVIRONMENT_VARIABLE] = ""
completion = AzureCompletion(api_key="xxxxx", endpoint="https://mock.azure.com/", deployment="gpt-4")

assert completion is not None


def test_valid_init_env():
os.environ[AzureCompletion.API_KEY_ENVIRONMENT_VARIABLE] = "xxxxx"
os.environ[AzureCompletion.ENDPOINT_URI_ENVIRONMENT_VARIABLE] = "https://testcompletionendpoint"
os.environ[AzureCompletion.DEPLOYMENT_ENVIRONMENT_VARIABLE] = "testcompletiondeployment"

completion = AzureCompletion()
assert completion is not None


def test_invalid_key_raises():
os.environ[AzureCompletion.API_KEY_ENVIRONMENT_VARIABLE] = ""
with pytest.raises(ValueError):
AzureCompletion(
api_key="",
endpoint="https://mock.azure.com/",
deployment="gpt-4",
api_version="some_version",
)


def test_invalid_endpoint_raises():
os.environ[AzureCompletion.ENDPOINT_URI_ENVIRONMENT_VARIABLE] = ""
with pytest.raises(ValueError):
AzureCompletion(
api_key="xxxxxx",
deployment="gpt-4",
api_version="some_version",
)


def test_invalid_deployment_raises():
os.environ[AzureCompletion.DEPLOYMENT_ENVIRONMENT_VARIABLE] = ""
with pytest.raises(ValueError):
AzureCompletion(
api_key="",
endpoint="https://mock.azure.com/",
)
Loading

0 comments on commit b4beb27

Please sign in to comment.