diff --git a/Cargo.toml b/Cargo.toml index b7d3449e75..7469edfdf1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ debug = true [workspace] resolver = "2" members = [ + "bin/asset-sprayer-prompts", "bin/cyclone", "bin/forklift", "bin/module-index", diff --git a/bin/asset-sprayer-prompts/BUCK b/bin/asset-sprayer-prompts/BUCK new file mode 100644 index 0000000000..c8ab04d485 --- /dev/null +++ b/bin/asset-sprayer-prompts/BUCK @@ -0,0 +1,18 @@ +load( + "@prelude-si//:macros.bzl", + "rust_binary", +) + +rust_binary( + name = "asset-sprayer-prompts", + deps = [ + "//lib/asset-sprayer:asset-sprayer", + "//third-party/rust:async-openai", + "//third-party/rust:clap", + "//third-party/rust:color-eyre", + "//third-party/rust:serde_yaml", + "//third-party/rust:tokio", + ], + srcs = glob(["src/**/*.rs"]), + env = {"CARGO_BIN_NAME": "asset-sprayer-prompts"}, +) diff --git a/bin/asset-sprayer-prompts/Cargo.toml b/bin/asset-sprayer-prompts/Cargo.toml new file mode 100644 index 0000000000..ee209b5848 --- /dev/null +++ b/bin/asset-sprayer-prompts/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "asset-sprayer-prompts" +version.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +edition.workspace = true +rust-version.workspace = true +publish.workspace = true + +[[bin]] +name = "asset-sprayer-prompts" +path = "src/main.rs" + +[dependencies] +asset-sprayer = { path = "../../lib/asset-sprayer" } +async-openai = { workspace = true } +clap = { workspace = true } +color-eyre = { workspace = true } +serde_yaml = { workspace = true } +tokio = { workspace = true } diff --git a/bin/asset-sprayer-prompts/src/main.rs b/bin/asset-sprayer-prompts/src/main.rs new file mode 100644 index 0000000000..49dd3486ef --- /dev/null +++ b/bin/asset-sprayer-prompts/src/main.rs @@ -0,0 +1,59 @@ +use asset_sprayer::{config::AssetSprayerConfig, prompt::Prompt, AssetSprayer}; +use clap::{Parser, ValueEnum}; +use color_eyre::Result; + +const NAME: &str = "asset-sprayer-prompts"; + +#[derive(Parser, Debug)] +#[command(name = NAME, max_term_width = 100)] +pub(crate) struct Args { + /// The action to take with the prompt. + #[arg(index = 1, value_enum)] + pub action: Action, + /// The AWS command to generate an asset schema for. + #[arg(index = 2)] + pub aws_command: String, + /// The AWS subcommand to generate an asset schema for. + #[arg(index = 3)] + pub aws_subcommand: String, + /// Directory to load prompts from. + #[arg(long)] + pub prompts_dir: Option, +} + +/// The action to take with the prompt. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)] +pub enum Action { + /// Show the prompt. + Show, + /// Run the prompt. + Run, +} + +#[tokio::main] +async fn main() -> Result<()> { + color_eyre::install()?; + let args = Args::parse(); + + let asset_sprayer = AssetSprayer::new( + async_openai::Client::new(), + AssetSprayerConfig { + prompts_dir: args.prompts_dir, + }, + ); + let prompt = Prompt::AwsAssetSchema { + command: args.aws_command.clone(), + subcommand: args.aws_subcommand.clone(), + }; + match args.action { + Action::Show => { + let prompt = asset_sprayer.prompt(&prompt).await?; + println!("{}", serde_yaml::to_string(&prompt)?); + } + Action::Run => { + let asset_schema = asset_sprayer.run(&prompt).await?; + println!("{}", asset_schema); + } + } + Ok(()) +} diff --git a/lib/asset-sprayer/src/lib.rs b/lib/asset-sprayer/src/lib.rs index 9ac008104a..f480caeb7e 100644 --- a/lib/asset-sprayer/src/lib.rs +++ b/lib/asset-sprayer/src/lib.rs @@ -25,14 +25,16 @@ while_true )] -use async_openai::config::OpenAIConfig; +use std::path::PathBuf; + +use async_openai::{config::OpenAIConfig, types::CreateChatCompletionRequest}; use config::AssetSprayerConfig; -use prompts::{Prompt, Prompts}; +use prompt::{Prompt, PromptKind}; use telemetry::prelude::*; use thiserror::Error; pub mod config; -pub mod prompts; +pub mod prompt; #[remain::sorted] #[derive(Debug, Error)] @@ -42,7 +44,7 @@ pub enum AssetSprayerError { #[error("I/O error: {0}")] Io(#[from] std::io::Error), #[error("Missing end {{/FETCH}} after {{FETCH}}: {0}")] - MissingEndFetch(Prompt), + MissingEndFetch(PromptKind), #[error("No choices were returned from the AI.")] NoChoices, #[error("OpenAI error: {0}")] @@ -55,12 +57,12 @@ pub enum AssetSprayerError { Unreachable, } -pub type AssetSprayerResult = Result; +pub type Result = std::result::Result; #[derive(Debug, Clone)] pub struct AssetSprayer { - openai_client: async_openai::Client, - prompts: Prompts, + pub openai_client: async_openai::Client, + pub prompts_dir: Option, } impl AssetSprayer { @@ -70,32 +72,18 @@ impl AssetSprayer { ) -> Self { Self { openai_client, - prompts: Prompts::new(config.prompts_dir.map(Into::into)), + prompts_dir: config.prompts_dir.map(Into::into), } } - pub async fn aws_asset_schema( - &self, - aws_command: &str, - aws_subcommand: &str, - ) -> AssetSprayerResult { - debug!( - "Generating asset schema for 'aws {} {}'", - aws_command, aws_subcommand - ); - self.run( - Prompt::AssetSchema, - &[ - ("{AWS_COMMAND}", aws_command), - ("{AWS_SUBCOMMAND}", aws_subcommand), - ], - ) - .await + pub async fn prompt(&self, prompt: &Prompt) -> Result { + prompt.prompt(&self.prompts_dir).await } - async fn run(&self, prompt: Prompt, replace: &[(&str, &str)]) -> AssetSprayerResult { - let request = self.prompts.create_request(prompt, replace).await?; - let response = self.openai_client.chat().create(request).await?; + pub async fn run(&self, prompt: &Prompt) -> Result { + debug!("Generating {}", prompt); + let prompt = self.prompt(prompt).await?; + let response = self.openai_client.chat().create(prompt).await?; let choice = response .choices .into_iter() @@ -111,13 +99,16 @@ impl AssetSprayer { #[ignore = "You must have OPENAI_API_KEY set to run this test"] #[tokio::test] -async fn test_do_ai() -> AssetSprayerResult<()> { +async fn test_do_ai() -> Result<()> { let asset_sprayer = AssetSprayer::new(async_openai::Client::new(), AssetSprayerConfig::default()); println!( "Done: {}", asset_sprayer - .aws_asset_schema("sqs", "create-queue") + .run(&Prompt::AwsAssetSchema { + command: "sqs".into(), + subcommand: "create-queue".into() + }) .await? ); Ok(()) diff --git a/lib/asset-sprayer/src/prompt.rs b/lib/asset-sprayer/src/prompt.rs new file mode 100644 index 0000000000..113b47437c --- /dev/null +++ b/lib/asset-sprayer/src/prompt.rs @@ -0,0 +1,139 @@ +use std::{borrow::Cow, path::PathBuf}; + +use async_openai::types::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage, + ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, +}; +use telemetry::prelude::*; + +use crate::{AssetSprayerError, Result}; + +#[derive(Debug, Clone, strum::Display, strum::EnumDiscriminants)] +#[strum_discriminants(name(PromptKind))] +#[strum_discriminants(derive(strum::Display))] +pub enum Prompt { + AwsAssetSchema { command: String, subcommand: String }, +} + +impl Prompt { + pub fn kind(&self) -> PromptKind { + self.into() + } + + pub async fn prompt( + &self, + prompts_dir: &Option, + ) -> Result { + let raw_prompt = self.kind().raw_prompt(prompts_dir).await?; + self.replace_prompt(raw_prompt).await + } + + async fn replace_prompt( + &self, + request: CreateChatCompletionRequest, + ) -> Result { + let mut request = request; + for message in request.messages.iter_mut() { + *message = self.replace_prompt_message(message.clone()).await?; + } + Ok(request) + } + + async fn replace_prompt_message( + &self, + message: ChatCompletionRequestMessage, + ) -> Result { + let mut message = message; + match &mut message { + ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { + content: ChatCompletionRequestUserMessageContent::Text(text), + .. + }) + | ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { + content: ChatCompletionRequestSystemMessageContent::Text(text), + .. + }) => *text = self.replace_prompt_text(text).await?, + _ => {} + }; + Ok(message) + } + + async fn replace_prompt_text(&self, text: &str) -> Result { + let text = match self { + Self::AwsAssetSchema { + command, + subcommand, + } => text + .replace("{AWS_COMMAND}", command) + .replace("{AWS_SUBCOMMAND}", subcommand), + }; + self.fetch_prompt_text(&text).await + } + + async fn fetch_prompt_text(&self, text: &str) -> Result { + // Fetch things between {FETCH} and {/FETCH} + let mut result = String::new(); + let mut text = text; + while let Some(fetch_start) = text.find("{FETCH}") { + // Copy up to {FETCH} + result.push_str(&text[..fetch_start]); + text = &text[(fetch_start + "{FETCH}".len())..]; + + if let Some(url_end) = text.find("{/FETCH}") { + // Fetch the URL between {FETCH}...{/FETCH} + result.push_str(&Self::get(&text[..url_end]).await?); + text = &text[(url_end + "{/FETCH}".len())..]; + } else { + return Err(AssetSprayerError::MissingEndFetch(self.kind())); + } + } + + // Copy the remainder of the text + result.push_str(text); + + Ok(result) + } + + async fn get(url: &str) -> reqwest::Result { + info!("Fetching: {}", url); + let client = reqwest::ClientBuilder::new() + .user_agent("Wget/1.21.2") + .build()?; + let response = client.get(url).send().await?; + response.error_for_status()?.text().await + } +} + +impl PromptKind { + pub async fn raw_prompt( + &self, + prompts_dir: &Option, + ) -> Result { + Ok(serde_yaml::from_str(&self.yaml(prompts_dir).await?)?) + } + + async fn yaml(&self, prompts_dir: &Option) -> Result> { + if let Some(ref prompts_dir) = prompts_dir { + // Read from disk if prompts_dir is available (faster dev cycle) + let path = prompts_dir.join(self.yaml_relative_path()); + info!("Loading prompt for {} from disk at {:?}", self, path); + Ok(tokio::fs::read_to_string(path).await?.into()) + } else { + info!("Loading embedded prompt for {}", self); + Ok(self.yaml_embedded().into()) + } + } + + fn yaml_relative_path(&self) -> &str { + match self { + Self::AwsAssetSchema => "aws/asset_schema.yaml", + } + } + + fn yaml_embedded(&self) -> &'static str { + match self { + Self::AwsAssetSchema => include_str!("../prompts/aws/asset_schema.yaml"), + } + } +} diff --git a/lib/asset-sprayer/src/prompts.rs b/lib/asset-sprayer/src/prompts.rs deleted file mode 100644 index 2545c9ef09..0000000000 --- a/lib/asset-sprayer/src/prompts.rs +++ /dev/null @@ -1,150 +0,0 @@ -use std::path::PathBuf; - -use async_openai::types::{ - ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageContent, - ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, -}; -use telemetry::prelude::*; - -use crate::{AssetSprayerError, AssetSprayerResult}; - -#[derive(Debug, Clone)] -pub struct Prompts { - prompts_dir: Option, -} - -#[derive(Debug, Clone, Copy, strum::Display)] -pub enum Prompt { - AssetSchema, -} - -impl Prompt { - fn yaml_relative_path(&self) -> &str { - match self { - Prompt::AssetSchema => "aws/asset_schema.yaml", - } - } - - fn yaml_embedded(&self) -> &'static str { - match self { - Prompt::AssetSchema => include_str!("../prompts/aws/asset_schema.yaml"), - } - } -} - -impl Prompts { - pub fn new(prompts_dir: Option) -> Self { - Self { - prompts_dir: prompts_dir.map(Into::into), - } - } - - pub async fn create_request( - &self, - prompt: Prompt, - replace: &[(&str, &str)], - ) -> AssetSprayerResult { - let request = self.raw_request(prompt).await?; - Self::replace_prompt_request(request, replace, prompt).await - } - - async fn raw_request(&self, prompt: Prompt) -> AssetSprayerResult { - Ok(serde_yaml::from_str(&self.yaml(prompt).await?)?) - } - - async fn yaml(&self, prompt: Prompt) -> AssetSprayerResult { - if let Some(ref prompts_dir) = self.prompts_dir { - // Read from disk if prompts_dir is available (faster dev cycle) - let path = prompts_dir.join(prompt.yaml_relative_path()); - info!("Loading prompt for {} from disk at {:?}", prompt, path); - Ok(tokio::fs::read_to_string(path).await?) - } else { - info!("Loading embedded prompt for {}", prompt); - Ok(prompt.yaml_embedded().to_string()) - } - } - - async fn replace_prompt_request( - request: CreateChatCompletionRequest, - replace: &[(&str, &str)], - prompt: Prompt, - ) -> AssetSprayerResult { - let mut request = request; - for message in request.messages.iter_mut() { - *message = - Self::replace_prompt_request_message(message.clone(), replace, prompt).await?; - } - Ok(request) - } - - async fn replace_prompt_request_message( - message: ChatCompletionRequestMessage, - replace: &[(&str, &str)], - prompt: Prompt, - ) -> AssetSprayerResult { - let mut message = message; - match &mut message { - ChatCompletionRequestMessage::User(message) => { - if let ChatCompletionRequestUserMessageContent::Text(text) = &mut message.content { - *text = Self::replace_prompt_text(text.clone(), replace, prompt).await?; - } - } - ChatCompletionRequestMessage::System(message) => { - if let ChatCompletionRequestSystemMessageContent::Text(text) = &mut message.content - { - *text = Self::replace_prompt_text(text.clone(), replace, prompt).await?; - } - } - _ => (), - } - Ok(message) - } - - async fn replace_prompt_text( - text: String, - replace: &[(&str, &str)], - prompt: Prompt, - ) -> AssetSprayerResult { - let mut text = text; - - // Replace {KEY} with value - for (from, to) in replace { - text = text.replace(from, to); - } - - Self::fetch_prompt_text(&text, prompt).await - } - - async fn fetch_prompt_text(text: &str, prompt: Prompt) -> AssetSprayerResult { - // Fetch things between {FETCH} and {/FETCH} - let mut result = String::new(); - let mut text = text; - while let Some(fetch_start) = text.find("{FETCH}") { - // Copy up to {FETCH} - result.push_str(&text[..fetch_start]); - text = &text[(fetch_start + "{FETCH}".len())..]; - - if let Some(url_end) = text.find("{/FETCH}") { - // Fetch the URL between {FETCH}...{/FETCH} - result.push_str(&Self::get(&text[..url_end]).await?); - text = &text[(url_end + "{/FETCH}".len())..]; - } else { - return Err(AssetSprayerError::MissingEndFetch(prompt)); - } - } - - // Copy the remainder of the text - result.push_str(text); - - Ok(result) - } - - async fn get(url: &str) -> reqwest::Result { - info!("Fetching: {}", url); - let client = reqwest::ClientBuilder::new() - .user_agent("Wget/1.21.2") - .build()?; - let response = client.get(url).send().await?; - response.error_for_status()?.text().await - } -} diff --git a/lib/sdf-server/src/service/v2/variant/generate_aws_asset_schema.rs b/lib/sdf-server/src/service/v2/variant/generate_aws_asset_schema.rs index b4c370a47f..43a7594063 100644 --- a/lib/sdf-server/src/service/v2/variant/generate_aws_asset_schema.rs +++ b/lib/sdf-server/src/service/v2/variant/generate_aws_asset_schema.rs @@ -1,3 +1,4 @@ +use asset_sprayer::prompt::Prompt; use axum::extract::{Host, OriginalUri, Path, Query}; use dal::{ schema::variant::authoring::VariantAuthoringClient, ChangeSet, ChangeSetId, SchemaVariant, @@ -41,9 +42,11 @@ pub async fn generate_aws_asset_schema( let force_change_set_id = ChangeSet::force_new(&mut ctx).await?; // Generate the code - let code = asset_sprayer - .aws_asset_schema(&aws_command.command, &aws_command.subcommand) - .await?; + let prompt = Prompt::AwsAssetSchema { + command: aws_command.command.clone(), + subcommand: aws_command.subcommand.clone(), + }; + let code = asset_sprayer.run(&prompt).await?; // Update the function let schema_variant = SchemaVariant::get_by_id_or_error(&ctx, schema_variant_id).await?;