Skip to content

Commit

Permalink
Merge pull request #4920 from systeminit/jkeiser/asset-sprayer-prompts
Browse files Browse the repository at this point in the history
Add asset-sprayer-prompts binary to show or run prompts from the command
  • Loading branch information
jkeiser authored Nov 5, 2024
2 parents c62f022 + a84ef3a commit 3c67e5b
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 183 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ debug = true
[workspace]
resolver = "2"
members = [
"bin/asset-sprayer-prompts",
"bin/cyclone",
"bin/forklift",
"bin/module-index",
Expand Down
18 changes: 18 additions & 0 deletions bin/asset-sprayer-prompts/BUCK
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
load(
"@prelude-si//:macros.bzl",
"rust_binary",
)

rust_binary(
name = "asset-sprayer-prompts",
deps = [
"//lib/asset-sprayer:asset-sprayer",
"//third-party/rust:async-openai",
"//third-party/rust:clap",
"//third-party/rust:color-eyre",
"//third-party/rust:serde_yaml",
"//third-party/rust:tokio",
],
srcs = glob(["src/**/*.rs"]),
env = {"CARGO_BIN_NAME": "asset-sprayer-prompts"},
)
21 changes: 21 additions & 0 deletions bin/asset-sprayer-prompts/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[package]
name = "asset-sprayer-prompts"
version.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
edition.workspace = true
rust-version.workspace = true
publish.workspace = true

[[bin]]
name = "asset-sprayer-prompts"
path = "src/main.rs"

[dependencies]
asset-sprayer = { path = "../../lib/asset-sprayer" }
async-openai = { workspace = true }
clap = { workspace = true }
color-eyre = { workspace = true }
serde_yaml = { workspace = true }
tokio = { workspace = true }
59 changes: 59 additions & 0 deletions bin/asset-sprayer-prompts/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use asset_sprayer::{config::AssetSprayerConfig, prompt::Prompt, AssetSprayer};
use clap::{Parser, ValueEnum};
use color_eyre::Result;

const NAME: &str = "asset-sprayer-prompts";

#[derive(Parser, Debug)]
#[command(name = NAME, max_term_width = 100)]
pub(crate) struct Args {
/// The action to take with the prompt.
#[arg(index = 1, value_enum)]
pub action: Action,
/// The AWS command to generate an asset schema for.
#[arg(index = 2)]
pub aws_command: String,
/// The AWS subcommand to generate an asset schema for.
#[arg(index = 3)]
pub aws_subcommand: String,
/// Directory to load prompts from.
#[arg(long)]
pub prompts_dir: Option<String>,
}

/// The action to take with the prompt.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)]
pub enum Action {
/// Show the prompt.
Show,
/// Run the prompt.
Run,
}

#[tokio::main]
async fn main() -> Result<()> {
color_eyre::install()?;
let args = Args::parse();

let asset_sprayer = AssetSprayer::new(
async_openai::Client::new(),
AssetSprayerConfig {
prompts_dir: args.prompts_dir,
},
);
let prompt = Prompt::AwsAssetSchema {
command: args.aws_command.clone(),
subcommand: args.aws_subcommand.clone(),
};
match args.action {
Action::Show => {
let prompt = asset_sprayer.prompt(&prompt).await?;
println!("{}", serde_yaml::to_string(&prompt)?);
}
Action::Run => {
let asset_schema = asset_sprayer.run(&prompt).await?;
println!("{}", asset_schema);
}
}
Ok(())
}
51 changes: 21 additions & 30 deletions lib/asset-sprayer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@
while_true
)]

use async_openai::config::OpenAIConfig;
use std::path::PathBuf;

use async_openai::{config::OpenAIConfig, types::CreateChatCompletionRequest};
use config::AssetSprayerConfig;
use prompts::{Prompt, Prompts};
use prompt::{Prompt, PromptKind};
use telemetry::prelude::*;
use thiserror::Error;

pub mod config;
pub mod prompts;
pub mod prompt;

#[remain::sorted]
#[derive(Debug, Error)]
Expand All @@ -42,7 +44,7 @@ pub enum AssetSprayerError {
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
#[error("Missing end {{/FETCH}} after {{FETCH}}: {0}")]
MissingEndFetch(Prompt),
MissingEndFetch(PromptKind),
#[error("No choices were returned from the AI.")]
NoChoices,
#[error("OpenAI error: {0}")]
Expand All @@ -55,12 +57,12 @@ pub enum AssetSprayerError {
Unreachable,
}

pub type AssetSprayerResult<T> = Result<T, AssetSprayerError>;
pub type Result<T> = std::result::Result<T, AssetSprayerError>;

#[derive(Debug, Clone)]
pub struct AssetSprayer {
openai_client: async_openai::Client<OpenAIConfig>,
prompts: Prompts,
pub openai_client: async_openai::Client<OpenAIConfig>,
pub prompts_dir: Option<PathBuf>,
}

impl AssetSprayer {
Expand All @@ -70,32 +72,18 @@ impl AssetSprayer {
) -> Self {
Self {
openai_client,
prompts: Prompts::new(config.prompts_dir.map(Into::into)),
prompts_dir: config.prompts_dir.map(Into::into),
}
}

pub async fn aws_asset_schema(
&self,
aws_command: &str,
aws_subcommand: &str,
) -> AssetSprayerResult<String> {
debug!(
"Generating asset schema for 'aws {} {}'",
aws_command, aws_subcommand
);
self.run(
Prompt::AssetSchema,
&[
("{AWS_COMMAND}", aws_command),
("{AWS_SUBCOMMAND}", aws_subcommand),
],
)
.await
pub async fn prompt(&self, prompt: &Prompt) -> Result<CreateChatCompletionRequest> {
prompt.prompt(&self.prompts_dir).await
}

async fn run(&self, prompt: Prompt, replace: &[(&str, &str)]) -> AssetSprayerResult<String> {
let request = self.prompts.create_request(prompt, replace).await?;
let response = self.openai_client.chat().create(request).await?;
pub async fn run(&self, prompt: &Prompt) -> Result<String> {
debug!("Generating {}", prompt);
let prompt = self.prompt(prompt).await?;
let response = self.openai_client.chat().create(prompt).await?;
let choice = response
.choices
.into_iter()
Expand All @@ -111,13 +99,16 @@ impl AssetSprayer {

#[ignore = "You must have OPENAI_API_KEY set to run this test"]
#[tokio::test]
async fn test_do_ai() -> AssetSprayerResult<()> {
async fn test_do_ai() -> Result<()> {
let asset_sprayer =
AssetSprayer::new(async_openai::Client::new(), AssetSprayerConfig::default());
println!(
"Done: {}",
asset_sprayer
.aws_asset_schema("sqs", "create-queue")
.run(&Prompt::AwsAssetSchema {
command: "sqs".into(),
subcommand: "create-queue".into()
})
.await?
);
Ok(())
Expand Down
139 changes: 139 additions & 0 deletions lib/asset-sprayer/src/prompt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use std::{borrow::Cow, path::PathBuf};

use async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
};
use telemetry::prelude::*;

use crate::{AssetSprayerError, Result};

#[derive(Debug, Clone, strum::Display, strum::EnumDiscriminants)]
#[strum_discriminants(name(PromptKind))]
#[strum_discriminants(derive(strum::Display))]
pub enum Prompt {
AwsAssetSchema { command: String, subcommand: String },
}

impl Prompt {
pub fn kind(&self) -> PromptKind {
self.into()
}

pub async fn prompt(
&self,
prompts_dir: &Option<PathBuf>,
) -> Result<CreateChatCompletionRequest> {
let raw_prompt = self.kind().raw_prompt(prompts_dir).await?;
self.replace_prompt(raw_prompt).await
}

async fn replace_prompt(
&self,
request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionRequest> {
let mut request = request;
for message in request.messages.iter_mut() {
*message = self.replace_prompt_message(message.clone()).await?;
}
Ok(request)
}

async fn replace_prompt_message(
&self,
message: ChatCompletionRequestMessage,
) -> Result<ChatCompletionRequestMessage> {
let mut message = message;
match &mut message {
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(text),
..
})
| ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(text),
..
}) => *text = self.replace_prompt_text(text).await?,
_ => {}
};
Ok(message)
}

async fn replace_prompt_text(&self, text: &str) -> Result<String> {
let text = match self {
Self::AwsAssetSchema {
command,
subcommand,
} => text
.replace("{AWS_COMMAND}", command)
.replace("{AWS_SUBCOMMAND}", subcommand),
};
self.fetch_prompt_text(&text).await
}

async fn fetch_prompt_text(&self, text: &str) -> Result<String> {
// Fetch things between {FETCH} and {/FETCH}
let mut result = String::new();
let mut text = text;
while let Some(fetch_start) = text.find("{FETCH}") {
// Copy up to {FETCH}
result.push_str(&text[..fetch_start]);
text = &text[(fetch_start + "{FETCH}".len())..];

if let Some(url_end) = text.find("{/FETCH}") {
// Fetch the URL between {FETCH}...{/FETCH}
result.push_str(&Self::get(&text[..url_end]).await?);
text = &text[(url_end + "{/FETCH}".len())..];
} else {
return Err(AssetSprayerError::MissingEndFetch(self.kind()));
}
}

// Copy the remainder of the text
result.push_str(text);

Ok(result)
}

async fn get(url: &str) -> reqwest::Result<String> {
info!("Fetching: {}", url);
let client = reqwest::ClientBuilder::new()
.user_agent("Wget/1.21.2")
.build()?;
let response = client.get(url).send().await?;
response.error_for_status()?.text().await
}
}

impl PromptKind {
pub async fn raw_prompt(
&self,
prompts_dir: &Option<PathBuf>,
) -> Result<CreateChatCompletionRequest> {
Ok(serde_yaml::from_str(&self.yaml(prompts_dir).await?)?)
}

async fn yaml(&self, prompts_dir: &Option<PathBuf>) -> Result<Cow<'static, str>> {
if let Some(ref prompts_dir) = prompts_dir {
// Read from disk if prompts_dir is available (faster dev cycle)
let path = prompts_dir.join(self.yaml_relative_path());
info!("Loading prompt for {} from disk at {:?}", self, path);
Ok(tokio::fs::read_to_string(path).await?.into())
} else {
info!("Loading embedded prompt for {}", self);
Ok(self.yaml_embedded().into())
}
}

fn yaml_relative_path(&self) -> &str {
match self {
Self::AwsAssetSchema => "aws/asset_schema.yaml",
}
}

fn yaml_embedded(&self) -> &'static str {
match self {
Self::AwsAssetSchema => include_str!("../prompts/aws/asset_schema.yaml"),
}
}
}
Loading

0 comments on commit 3c67e5b

Please sign in to comment.