Skip to content

Commit

Permalink
works
Browse files Browse the repository at this point in the history
  • Loading branch information
PankajBhojwani committed Jun 7, 2024
1 parent b44216b commit e78e4d0
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 166 deletions.
176 changes: 174 additions & 2 deletions src/cascadia/QueryExtension/AzureLLMProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@
#include "LibraryResources.h"

#include "AzureLLMProvider.g.cpp"
#include "AzureResponse.g.cpp"

using namespace winrt::Windows::Foundation;
using namespace winrt::Windows::Foundation::Collections;
using namespace winrt::Windows::UI::Core;
using namespace winrt::Windows::UI::Xaml;
using namespace winrt::Windows::UI::Xaml::Controls;
using namespace winrt::Windows::System;
namespace WWH = ::winrt::Windows::Web::Http;
namespace WSS = ::winrt::Windows::Storage::Streams;
namespace WDJ = ::winrt::Windows::Data::Json;

static constexpr std::wstring_view acceptedModel{ L"gpt-35-turbo" };
static constexpr std::wstring_view acceptedSeverityLevel{ L"safe" };

const std::wregex azureOpenAIEndpointRegex{ LR"(^https.*openai\.azure\.com)" };

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
Expand All @@ -19,8 +35,164 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
_httpClient.DefaultRequestHeaders().Append(L"api-key", _AIKey);
}

void AzureLLMProvider::Initialize()
void AzureLLMProvider::ClearMessageHistory()
{
_jsonMessages.Clear();
}

void AzureLLMProvider::SetSystemPrompt(const winrt::hstring& systemPrompt)
{
WDJ::JsonObject systemMessageObject;
winrt::hstring systemMessageContent{ systemPrompt };
systemMessageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"system"));
systemMessageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(systemMessageContent));
_jsonMessages.Append(systemMessageObject);
}

void AzureLLMProvider::SetContext(Extension::IContext context)
{
_context = context;
}

winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> AzureLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt)
{
// Use a flag for whether the response the user receives is an error message
// we pass this flag back to the caller so they can handle it appropriately (specifically, ExtensionPalette will send the correct telemetry event)
// there is only one case downstream from here that sets this flag to false, so start with it being true
bool isError{ true };
hstring message{};

// If the AI key and endpoint is still empty, tell the user to fill them out in settings
if (_AIKey.empty() || _AIEndpoint.empty())
{
message = RS_(L"CouldNotFindKeyErrorMessage");
}
else if (!std::regex_search(_AIEndpoint.c_str(), azureOpenAIEndpointRegex))
{
message = RS_(L"InvalidEndpointMessage");
}

// If we don't have a message string, that means the endpoint exists and matches the regex
// that we allow - now we can actually make the http request
if (message.empty())
{
// Make a copy of the prompt because we are switching threads
const auto promptCopy{ userPrompt };

// Make sure we are on the background thread for the http request
co_await winrt::resume_background();

WWH::HttpRequestMessage request{ WWH::HttpMethod::Post(), Uri{ _AIEndpoint } };
request.Headers().Accept().TryParseAdd(L"application/json");

WDJ::JsonObject jsonContent;
WDJ::JsonObject messageObject;

// _ActiveCommandline should be set already, we request for it the moment we become visible
winrt::hstring engineeredPrompt{ promptCopy };
if (_context && !_context.ActiveCommandline().empty())
{
engineeredPrompt = promptCopy + L". The shell I am running is " + _context.ActiveCommandline();
}
messageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"user"));
messageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(engineeredPrompt));
_jsonMessages.Append(messageObject);
jsonContent.SetNamedValue(L"messages", _jsonMessages);
jsonContent.SetNamedValue(L"max_tokens", WDJ::JsonValue::CreateNumberValue(800));
jsonContent.SetNamedValue(L"temperature", WDJ::JsonValue::CreateNumberValue(0.7));
jsonContent.SetNamedValue(L"frequency_penalty", WDJ::JsonValue::CreateNumberValue(0));
jsonContent.SetNamedValue(L"presence_penalty", WDJ::JsonValue::CreateNumberValue(0));
jsonContent.SetNamedValue(L"top_p", WDJ::JsonValue::CreateNumberValue(0.95));
jsonContent.SetNamedValue(L"stop", WDJ::JsonValue::CreateStringValue(L"None"));
const auto stringContent = jsonContent.ToString();
WWH::HttpStringContent requestContent{
stringContent,
WSS::UnicodeEncoding::Utf8,
L"application/json"
};

request.Content(requestContent);

// Send the request
try
{
const auto response = _httpClient.SendRequestAsync(request).get();
// Parse out the suggestion from the response
const auto string{ response.Content().ReadAsStringAsync().get() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(L"error"))
{
const auto errorObject = jsonResult.GetNamedObject(L"error");
message = errorObject.GetNamedString(L"message");
}
else
{
if (_verifyModelIsValidHelper(jsonResult))
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(L"message");
message = messageObject.GetNamedString(L"content");
isError = false;
}
else
{
message = RS_(L"InvalidModelMessage");
}
}
}
catch (...)
{
message = RS_(L"UnknownErrorMessage");
}
}

// Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far
WDJ::JsonObject responseMessageObject;
responseMessageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"assistant"));
responseMessageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(message));
_jsonMessages.Append(responseMessageObject);

co_return winrt::make<AzureResponse>(message, isError);
}

bool AzureLLMProvider::_verifyModelIsValidHelper(const WDJ::JsonObject jsonResponse)
{
_Thing = L"ayy lmao";
if (jsonResponse.GetNamedString(L"model") != acceptedModel)
{
return false;
}
WDJ::JsonObject contentFiltersObject;
// For some reason, sometimes the content filter results are in a key called "prompt_filter_results"
// and sometimes they are in a key called "prompt_annotations". Check for either.
if (jsonResponse.HasKey(L"prompt_filter_results"))
{
contentFiltersObject = jsonResponse.GetNamedArray(L"prompt_filter_results").GetObjectAt(0);
}
else if (jsonResponse.HasKey(L"prompt_annotations"))
{
contentFiltersObject = jsonResponse.GetNamedArray(L"prompt_annotations").GetObjectAt(0);
}
else
{
return false;
}
const auto contentFilters = contentFiltersObject.GetNamedObject(L"content_filter_results");
if (Feature_TerminalChatJailbreakFilter::IsEnabled() && !contentFilters.HasKey(L"jailbreak"))
{
return false;
}
for (const auto filterPair : contentFilters)
{
const auto filterLevel = filterPair.Value().GetObjectW();
if (filterLevel.HasKey(L"severity"))
{
if (filterLevel.GetNamedString(L"severity") != acceptedSeverityLevel)
{
return false;
}
}
}
return true;
}
}
26 changes: 24 additions & 2 deletions src/cascadia/QueryExtension/AzureLLMProvider.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,48 @@
#pragma once

#include "AzureLLMProvider.g.h"
#include "AzureResponse.g.h"

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
struct AzureLLMProvider : AzureLLMProviderT<AzureLLMProvider>
{
AzureLLMProvider(winrt::hstring endpoint, winrt::hstring key);
void Initialize();

WINRT_PROPERTY(winrt::hstring, Thing);
void ClearMessageHistory();
void SetSystemPrompt(const winrt::hstring& systemPrompt);
void SetContext(Extension::IContext context);

winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> GetResponseAsync(const winrt::hstring& userPrompt);

private:
winrt::hstring _AIEndpoint;
winrt::hstring _AIKey;
winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr };

Extension::IContext _context;

winrt::Windows::Data::Json::JsonArray _jsonMessages;

bool _verifyModelIsValidHelper(const Windows::Data::Json::JsonObject jsonResponse);
};

struct AzureResponse : AzureResponseT<AzureResponse>
{
AzureResponse(winrt::hstring message, bool isError) :
_message{ message },
_isError{ isError } {}
winrt::hstring Message() { return _message; };
bool IsError() { return _isError; };

private:
winrt::hstring _message;
bool _isError;
};
}

namespace winrt::Microsoft::Terminal::Query::Extension::factory_implementation
{
BASIC_FACTORY(AzureLLMProvider);
BASIC_FACTORY(AzureResponse);
}
7 changes: 5 additions & 2 deletions src/cascadia/QueryExtension/AzureLLMProvider.idl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ namespace Microsoft.Terminal.Query.Extension
{
[default_interface] runtimeclass AzureLLMProvider : ILLMProvider

Check failure

Code scanning / check-spelling

Unrecognized Spelling Error

ILLM is not a recognized word. (unrecognized-spelling)
{
AzureLLMProvider(String endpt, String key);
AzureLLMProvider(String endpoint, String key);
}

String Thing();
[default_interface] runtimeclass AzureResponse : IResponse
{
AzureResponse(String message, Boolean isError);
}
}
Loading

1 comment on commit e78e4d0

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@check-spelling-bot Report

🔴 Please review

See the 📜action log or 📝 job summary for details.

Unrecognized words (2)

ILLM
llm

Previously acknowledged words that are now absent CRLFs Redir wcsicmp 🫥
To accept these unrecognized words as correct and remove the previously acknowledged and now absent words, you could run the following commands

... in a clone of the [email protected]:microsoft/terminal.git repository
on the dev/pabhoj/llm_provider_interface branch (ℹ️ how do I use this?):

curl -s -S -L 'https://raw.githubusercontent.com/check-spelling/check-spelling/v0.0.22/apply.pl' |
perl - 'https://github.com/microsoft/terminal/actions/runs/9421656986/attempts/1'
Available 📚 dictionaries could cover words (expected and unrecognized) not in the 📘 dictionary

This includes both expected items (2212) from .github/actions/spelling/expect/04cdb9b77d6827c0202f51acd4205b017015bfff.txt
.github/actions/spelling/expect/alphabet.txt
.github/actions/spelling/expect/expect.txt
.github/actions/spelling/expect/web.txt and unrecognized words (2)

Dictionary Entries Covers Uniquely
cspell:cpp/src/lang-jargon.txt 11 1 1
cspell:swift/src/swift.txt 53 1 1
cspell:gaming-terms/dict/gaming-terms.txt 59 1 1
cspell:monkeyc/src/monkeyc_keywords.txt 123 1 1
cspell:cryptocurrencies/cryptocurrencies.txt 125 1 1

Consider adding them (in .github/workflows/spelling2.yml) for uses: check-spelling/[email protected] in its with:

      with:
        extra_dictionaries:
          cspell:cpp/src/lang-jargon.txt
          cspell:swift/src/swift.txt
          cspell:gaming-terms/dict/gaming-terms.txt
          cspell:monkeyc/src/monkeyc_keywords.txt
          cspell:cryptocurrencies/cryptocurrencies.txt

To stop checking additional dictionaries, add (in .github/workflows/spelling2.yml) for uses: check-spelling/[email protected] in its with:

check_extra_dictionaries: ''
Errors (2)

See the 📜action log or 📝 job summary for details.

❌ Errors Count
❌ check-file-path 1
❌ ignored-expect-variant 3

See ❌ Event descriptions for more information.

✏️ Contributor please read this

By default the command suggestion will generate a file named based on your commit. That's generally ok as long as you add the file to your commit. Someone can reorganize it later.

If the listed items are:

  • ... misspelled, then please correct them instead of using the command.
  • ... names, please add them to .github/actions/spelling/allow/names.txt.
  • ... APIs, you can add them to a file in .github/actions/spelling/allow/.
  • ... just things you're using, please add them to an appropriate file in .github/actions/spelling/expect/.
  • ... tokens you only need in one place and shouldn't generally be used, you can add an item in an appropriate file in .github/actions/spelling/patterns/.

See the README.md in each directory for more information.

🔬 You can test your commits without appending to a PR by creating a new branch with that extra change and pushing it to your fork. The check-spelling action will run in response to your push -- it doesn't require an open pull request. By using such a branch, you can limit the number of typos your peers see you make. 😉

If the flagged items are 🤯 false positives

If items relate to a ...

  • binary file (or some other file you wouldn't want to check at all).

    Please add a file path to the excludes.txt file matching the containing file.

    File paths are Perl 5 Regular Expressions - you can test yours before committing to verify it will match your files.

    ^ refers to the file's path from the root of the repository, so ^README\.md$ would exclude README.md (on whichever branch you're using).

  • well-formed pattern.

    If you can write a pattern that would match it,
    try adding it to the patterns.txt file.

    Patterns are Perl 5 Regular Expressions - you can test yours before committing to verify it will match your lines.

    Note that patterns can't match multiline strings.

Please sign in to comment.