-
Notifications
You must be signed in to change notification settings - Fork 262
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4920 from systeminit/jkeiser/asset-sprayer-prompts
Add asset-sprayer-prompts binary to show or run prompts from the command
- Loading branch information
Showing
8 changed files
with
265 additions
and
183 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
load( | ||
"@prelude-si//:macros.bzl", | ||
"rust_binary", | ||
) | ||
|
||
rust_binary( | ||
name = "asset-sprayer-prompts", | ||
deps = [ | ||
"//lib/asset-sprayer:asset-sprayer", | ||
"//third-party/rust:async-openai", | ||
"//third-party/rust:clap", | ||
"//third-party/rust:color-eyre", | ||
"//third-party/rust:serde_yaml", | ||
"//third-party/rust:tokio", | ||
], | ||
srcs = glob(["src/**/*.rs"]), | ||
env = {"CARGO_BIN_NAME": "asset-sprayer-prompts"}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
[package] | ||
name = "asset-sprayer-prompts" | ||
version.workspace = true | ||
authors.workspace = true | ||
license.workspace = true | ||
repository.workspace = true | ||
edition.workspace = true | ||
rust-version.workspace = true | ||
publish.workspace = true | ||
|
||
[[bin]] | ||
name = "asset-sprayer-prompts" | ||
path = "src/main.rs" | ||
|
||
[dependencies] | ||
asset-sprayer = { path = "../../lib/asset-sprayer" } | ||
async-openai = { workspace = true } | ||
clap = { workspace = true } | ||
color-eyre = { workspace = true } | ||
serde_yaml = { workspace = true } | ||
tokio = { workspace = true } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
use asset_sprayer::{config::AssetSprayerConfig, prompt::Prompt, AssetSprayer}; | ||
use clap::{Parser, ValueEnum}; | ||
use color_eyre::Result; | ||
|
||
const NAME: &str = "asset-sprayer-prompts"; | ||
|
||
#[derive(Parser, Debug)] | ||
#[command(name = NAME, max_term_width = 100)] | ||
pub(crate) struct Args { | ||
/// The action to take with the prompt. | ||
#[arg(index = 1, value_enum)] | ||
pub action: Action, | ||
/// The AWS command to generate an asset schema for. | ||
#[arg(index = 2)] | ||
pub aws_command: String, | ||
/// The AWS subcommand to generate an asset schema for. | ||
#[arg(index = 3)] | ||
pub aws_subcommand: String, | ||
/// Directory to load prompts from. | ||
#[arg(long)] | ||
pub prompts_dir: Option<String>, | ||
} | ||
|
||
/// The action to take with the prompt. | ||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)] | ||
pub enum Action { | ||
/// Show the prompt. | ||
Show, | ||
/// Run the prompt. | ||
Run, | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<()> { | ||
color_eyre::install()?; | ||
let args = Args::parse(); | ||
|
||
let asset_sprayer = AssetSprayer::new( | ||
async_openai::Client::new(), | ||
AssetSprayerConfig { | ||
prompts_dir: args.prompts_dir, | ||
}, | ||
); | ||
let prompt = Prompt::AwsAssetSchema { | ||
command: args.aws_command.clone(), | ||
subcommand: args.aws_subcommand.clone(), | ||
}; | ||
match args.action { | ||
Action::Show => { | ||
let prompt = asset_sprayer.prompt(&prompt).await?; | ||
println!("{}", serde_yaml::to_string(&prompt)?); | ||
} | ||
Action::Run => { | ||
let asset_schema = asset_sprayer.run(&prompt).await?; | ||
println!("{}", asset_schema); | ||
} | ||
} | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
use std::{borrow::Cow, path::PathBuf}; | ||
|
||
use async_openai::types::{ | ||
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, | ||
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage, | ||
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, | ||
}; | ||
use telemetry::prelude::*; | ||
|
||
use crate::{AssetSprayerError, Result}; | ||
|
||
#[derive(Debug, Clone, strum::Display, strum::EnumDiscriminants)] | ||
#[strum_discriminants(name(PromptKind))] | ||
#[strum_discriminants(derive(strum::Display))] | ||
pub enum Prompt { | ||
AwsAssetSchema { command: String, subcommand: String }, | ||
} | ||
|
||
impl Prompt { | ||
pub fn kind(&self) -> PromptKind { | ||
self.into() | ||
} | ||
|
||
pub async fn prompt( | ||
&self, | ||
prompts_dir: &Option<PathBuf>, | ||
) -> Result<CreateChatCompletionRequest> { | ||
let raw_prompt = self.kind().raw_prompt(prompts_dir).await?; | ||
self.replace_prompt(raw_prompt).await | ||
} | ||
|
||
async fn replace_prompt( | ||
&self, | ||
request: CreateChatCompletionRequest, | ||
) -> Result<CreateChatCompletionRequest> { | ||
let mut request = request; | ||
for message in request.messages.iter_mut() { | ||
*message = self.replace_prompt_message(message.clone()).await?; | ||
} | ||
Ok(request) | ||
} | ||
|
||
async fn replace_prompt_message( | ||
&self, | ||
message: ChatCompletionRequestMessage, | ||
) -> Result<ChatCompletionRequestMessage> { | ||
let mut message = message; | ||
match &mut message { | ||
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { | ||
content: ChatCompletionRequestUserMessageContent::Text(text), | ||
.. | ||
}) | ||
| ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { | ||
content: ChatCompletionRequestSystemMessageContent::Text(text), | ||
.. | ||
}) => *text = self.replace_prompt_text(text).await?, | ||
_ => {} | ||
}; | ||
Ok(message) | ||
} | ||
|
||
async fn replace_prompt_text(&self, text: &str) -> Result<String> { | ||
let text = match self { | ||
Self::AwsAssetSchema { | ||
command, | ||
subcommand, | ||
} => text | ||
.replace("{AWS_COMMAND}", command) | ||
.replace("{AWS_SUBCOMMAND}", subcommand), | ||
}; | ||
self.fetch_prompt_text(&text).await | ||
} | ||
|
||
async fn fetch_prompt_text(&self, text: &str) -> Result<String> { | ||
// Fetch things between {FETCH} and {/FETCH} | ||
let mut result = String::new(); | ||
let mut text = text; | ||
while let Some(fetch_start) = text.find("{FETCH}") { | ||
// Copy up to {FETCH} | ||
result.push_str(&text[..fetch_start]); | ||
text = &text[(fetch_start + "{FETCH}".len())..]; | ||
|
||
if let Some(url_end) = text.find("{/FETCH}") { | ||
// Fetch the URL between {FETCH}...{/FETCH} | ||
result.push_str(&Self::get(&text[..url_end]).await?); | ||
text = &text[(url_end + "{/FETCH}".len())..]; | ||
} else { | ||
return Err(AssetSprayerError::MissingEndFetch(self.kind())); | ||
} | ||
} | ||
|
||
// Copy the remainder of the text | ||
result.push_str(text); | ||
|
||
Ok(result) | ||
} | ||
|
||
async fn get(url: &str) -> reqwest::Result<String> { | ||
info!("Fetching: {}", url); | ||
let client = reqwest::ClientBuilder::new() | ||
.user_agent("Wget/1.21.2") | ||
.build()?; | ||
let response = client.get(url).send().await?; | ||
response.error_for_status()?.text().await | ||
} | ||
} | ||
|
||
impl PromptKind { | ||
pub async fn raw_prompt( | ||
&self, | ||
prompts_dir: &Option<PathBuf>, | ||
) -> Result<CreateChatCompletionRequest> { | ||
Ok(serde_yaml::from_str(&self.yaml(prompts_dir).await?)?) | ||
} | ||
|
||
async fn yaml(&self, prompts_dir: &Option<PathBuf>) -> Result<Cow<'static, str>> { | ||
if let Some(ref prompts_dir) = prompts_dir { | ||
// Read from disk if prompts_dir is available (faster dev cycle) | ||
let path = prompts_dir.join(self.yaml_relative_path()); | ||
info!("Loading prompt for {} from disk at {:?}", self, path); | ||
Ok(tokio::fs::read_to_string(path).await?.into()) | ||
} else { | ||
info!("Loading embedded prompt for {}", self); | ||
Ok(self.yaml_embedded().into()) | ||
} | ||
} | ||
|
||
fn yaml_relative_path(&self) -> &str { | ||
match self { | ||
Self::AwsAssetSchema => "aws/asset_schema.yaml", | ||
} | ||
} | ||
|
||
fn yaml_embedded(&self) -> &'static str { | ||
match self { | ||
Self::AwsAssetSchema => include_str!("../prompts/aws/asset_schema.yaml"), | ||
} | ||
} | ||
} |
Oops, something went wrong.