Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Bedrock Cloud Model provider #21092

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
47 changes: 47 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ members = [
"crates/assistant_tool",
"crates/audio",
"crates/auto_update",
"crates/bedrock",
"crates/auto_update_ui",
"crates/breadcrumbs",
"crates/call",
Expand Down Expand Up @@ -170,7 +171,6 @@ members = [
#
# Tooling
#

"tooling/xtask",
]
default-members = ["crates/zed"]
Expand All @@ -191,6 +191,7 @@ assistant_tool = { path = "crates/assistant_tool" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
auto_update_ui = { path = "crates/auto_update_ui" }
bedrock = { path = "crates/bedrock" }
breadcrumbs = { path = "crates/breadcrumbs" }
call = { path = "crates/call" }
channel = { path = "crates/channel" }
Expand Down Expand Up @@ -335,6 +336,9 @@ async-trait = "0.1"
async-tungstenite = "0.28"
async-watch = "0.3.1"
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] }
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-sdk-bedrockruntime = { version = "1.57.0", features = ["behavior-version-latest"]}
base64 = "0.22"
bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
Expand Down
6 changes: 6 additions & 0 deletions crates/assistant/src/assistant_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ pub enum AssistantProviderContentV1 {
default_model: Option<OllamaModel>,
api_url: Option<String>,
},
#[serde(rename = "bedrock")]
Bedrock {
default_model: Option<CloudModel>,
region: Option<String>,
},
}

#[derive(Debug, Default)]
Expand Down Expand Up @@ -415,6 +420,7 @@ pub struct LanguageModelSelection {
fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
schemars::schema::SchemaObject {
enum_values: Some(vec![
"bedrock".into(),
"anthropic".into(),
"google".into(),
"ollama".into(),
Expand Down
33 changes: 33 additions & 0 deletions crates/bedrock/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
[package]
name = "bedrock"
version = "0.1.0"
edition = "2021"
publish = false
license = "AGPL-3.0-or-later"

[features]
default = []
schemars = ["dep:schemars"]

[lints]
workspace = true

[lib]
path = "src/bedrock.rs"

[dependencies]
anyhow.workspace = true
chrono.workspace = true
futures.workspace = true
http_client.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
util.workspace = true
aws-sdk-bedrockruntime = { workspace = true, features = ["behavior-version-latest"]}
aws-config = {workspace = true, features = ["behavior-version-latest"]}

[dev-dependencies]
tokio.workspace = true
133 changes: 133 additions & 0 deletions crates/bedrock/src/bedrock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
mod models;

use std::{str::FromStr};
use std::any::Any;
use anyhow::{Context, Result};
use aws_sdk_bedrockruntime::types::{ContentBlockDeltaEvent, ContentBlockStartEvent, ConverseStreamMetadataEvent, ConverseStreamOutput, Message, MessageStartEvent, MessageStopEvent};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt};
use serde::{Deserialize, Serialize};
use thiserror::Error;

use aws_sdk_bedrockruntime as bedrock;
pub use aws_sdk_bedrockruntime as bedrock_client;
pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
pub use bedrock::types::ContentBlock as BedrockRequestContent;
use bedrock::types::ConverseOutput as Response;
pub use bedrock::types::Message as BedrockMessage;
pub use bedrock::types::ConversationRole as BedrockRole;
pub use bedrock::types::ResponseStream as BedrockResponseStream;

// TODO: Re-export the Bedrock SDK types from this crate so callers do not depend on the SDK directly.
// https://doc.rust-lang.org/rustdoc/write-documentation/re-exports.html

pub use models::*;

pub async fn complete(
client: &bedrock::Client,
request: Request,
) -> Result<Response, BedrockError> {
let mut response = bedrock::Client::converse(client)
.model_id(request.model.clone())
.set_messages(request.messages.into())
.send().await.context("Failed to send request to Bedrock");

match response {
Ok(output) => {
Ok(output.output.unwrap())
}
Err(err) => {
Err(BedrockError::Other(err))
}
}
}

pub async fn stream_completion(
client: &bedrock::Client,
request: Request,
) -> Result<BoxStream<'static, Result<BedrockEvent, BedrockError>>, BedrockError> { // There is no generic Bedrock event Type?

let response = bedrock::Client::converse_stream(client)
.model_id(request.model)
.set_messages(request.messages.into()).send().await;


let mut stream = match response {
Ok(output) => Ok(output.stream),
Err(e) => Err(
BedrockError::SdkError(e.as_service_error().unwrap())
),
}?;

loop {
let token = stream.recv().await;
match token {
Ok(Some(text)) => {
let next = get_converse_output_text(text)?;
print!("{}", next);
Ok(())
}
Ok(None) => break,
Err(e) => Err(e
.as_service_error()
.map(BedrockConverseStreamError::from)
.unwrap_or(BedrockConverseStreamError(
"Unknown error receiving stream".into(),
))),
}?
}
}

/// Extracts the text carried by a streamed Converse output.
///
/// Returns the delta's text for a `ContentBlockDelta` output whose delta is a
/// text delta, and an empty string for every other output. This function
/// never returns `Err`.
fn get_converse_output_text(
    output: ConverseStreamOutput,
) -> Result<String, BedrockError> {
    let text = if let ConverseStreamOutput::ContentBlockDelta(event) = output {
        event
            .delta()
            .and_then(|delta| delta.as_text().ok())
            .cloned()
            .unwrap_or_default()
    } else {
        String::new()
    };
    Ok(text)
}
// TODO: Many of the types below should re-export the Bedrock SDK types instead of defining custom ones.

/// Parameters for a Bedrock Converse request.
///
/// `complete` and `stream_completion` currently forward only `model` and
/// `messages`; the remaining fields are carried on the request but are not yet
/// sent to the service.
///
/// This type does not derive `Serialize`/`Deserialize`: `messages` holds the
/// AWS SDK `Message` type, which does not provide serde implementations.
#[derive(Debug)]
pub struct Request {
    /// Identifier passed to the SDK as the model id.
    pub model: String,
    pub max_tokens: u32,
    /// Conversation messages passed to the SDK.
    pub messages: Vec<Message>,
    pub system: Option<String>,
    pub metadata: Option<Metadata>,
    pub stop_sequences: Vec<String>,
    pub temperature: Option<f32>,
    pub top_k: Option<u32>,
    pub top_p: Option<f32>,
}

/// Optional metadata attached to a `Request` through its `metadata` field.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
// NOTE(review): presumably identifies the end user on whose behalf the
// request is made; nothing in this file reads it yet. Confirm with callers.
pub user_id: Option<String>,
}

#[derive(Error, Debug)]
pub enum BedrockError {
SdkError(bedrock::Error),
Other(anyhow::Error)
}

/// Item type of the stream declared in `stream_completion`'s return type.
///
/// Each variant wraps the AWS SDK event type of the same name.
pub enum BedrockEvent {
/// Wraps the SDK's `ContentBlockDeltaEvent`.
ContentBlockDelta(ContentBlockDeltaEvent),
/// Wraps the SDK's `ContentBlockStartEvent`.
ContentBlockStart(ContentBlockStartEvent),
/// Wraps the SDK's `MessageStartEvent`.
MessageStart(MessageStartEvent),
/// Wraps the SDK's `MessageStopEvent`.
MessageStop(MessageStopEvent),
/// Wraps the SDK's `ConverseStreamMetadataEvent`.
Metadata(ConverseStreamMetadataEvent),
}
Loading