Skip to content

Commit

Permalink
Merge pull request #242 from huggingface/chore/linux
Browse files Browse the repository at this point in the history
chore: minor
  • Loading branch information
FL33TW00D authored Aug 12, 2024
2 parents e87830d + 57f5340 commit d07fcdf
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 42 deletions.
1 change: 1 addition & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ env:
DXC_FILENAME: "dxc_2023_08_14.zip"
WASM_BINDGEN_TEST_TIMEOUT: 300 # 5 minutes
CI_BINARY_BUILD: "build18" # Corresponds to https://github.com/gfx-rs/ci-build/releases
RATCHET_FORCE_F32: 1

jobs:
check:
Expand Down
6 changes: 5 additions & 1 deletion crates/ratchet-core/src/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ def cast(a):

fn run_cast_trial(prob: CastProblem) -> anyhow::Result<()> {
let device = Device::request_device(DeviceRequest::GPU).unwrap();
let device_precision = device.compute_precision();
if matches!(device_precision, DType::F32) {
return Ok(())
}
let CastProblem { dst_dt, B, M, N } = prob;
let input = Tensor::randn::<f32>(shape![B, M, N], Device::CPU);
let ground = ground_truth(&input, dst_dt)?;
Expand All @@ -256,7 +260,7 @@ def cast(a):
}

#[proptest(cases = 256)]
fn test_type_cast(prob: CastProblem) {
fn test_cast(prob: CastProblem) {
run_cast_trial(prob).unwrap();
}
}
6 changes: 3 additions & 3 deletions crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1152,8 +1152,8 @@ mod tests {
#[test]
fn has_nan_works() {
let device = Device::request_device(crate::DeviceRequest::GPU).unwrap();
let rand = Tensor::randn::<f16>(shape![1, 1500, 384], device.clone());
let nans = Tensor::from_data(vec![f16::NAN; 1500 * 384], shape![1, 1500, 384], device);
let rand = Tensor::randn::<f32>(shape![1, 1500, 384], device.clone());
let nans = Tensor::from_data(vec![f32::NAN; 1500 * 384], shape![1, 1500, 384], device);

let bingo = Tensor::cat(rvec![rand, nans], 2)
.unwrap()
Expand All @@ -1162,6 +1162,6 @@ mod tests {

let result = bingo.to(&Device::CPU).unwrap();
println!("RESULT: {:?}", result);
assert!(result.has_nan::<f16>());
assert!(result.has_nan::<f32>());
}
}
34 changes: 0 additions & 34 deletions crates/ratchet-loader/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,37 +141,3 @@ impl GgmlDType {
}
}

#[cfg(test)]
mod tests {
use crate::gguf::gguf::Header;
use half::f16;
use ratchet::{shape, Device, DeviceRequest, Tensor};

#[test]
fn test_q4_dequant() -> anyhow::Result<()> {
let _ = env_logger::builder().is_test(true).try_init();
let model_path = "../../fixtures/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf";
let mut reader = std::io::BufReader::new(std::fs::File::open(model_path)?);
let device = Device::request_device(DeviceRequest::GPU)?;
let content = Header::read(&mut reader)?;
let t = content.tensor(&mut reader, "blk.0.attn_k.weight", &device)?;

let rhs = Tensor::randn::<f16>(shape![2048, 2048], device.clone());
let out = t.matmul(rhs, false, false)?.resolve()?;
let out_cpu = out.to(&Device::CPU)?;

println!("{:#?}", out_cpu);

let ground = "../../fixtures/tinyllama_blk_0_attn_k_weight_f32.npy";
let ground_t = Tensor::read_npy::<f32, _>(ground, &device)?
.half()?
.resolve()?
.to(&Device::CPU)?;

println!("{:#?}", ground_t);

//out_cpu.all_close(&ground_t, f16::from_f32(1e-3), f16::from_f32(1e-3))?;

Ok(())
}
}
5 changes: 1 addition & 4 deletions crates/ratchet-models/tests/whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,7 @@ async fn tiny_decoder() -> Result<(), JsValue> {

let device = Device::request_device(DeviceRequest::GPU).await.unwrap();

let audio_ctx = Tensor::from_npy_bytes::<f32>(&hs_data.to_vec(), &device)
.unwrap()
.cast(device.compute_precision())
.unwrap();
let audio_ctx = Tensor::from_npy_bytes::<f32>(&hs_data.to_vec(), &device).unwrap();
let mut decoder = WhisperDecoder::load(&header, &config, &mut reader, &device).unwrap();

let mut tokens = vec![50258, 50259, 50359];
Expand Down
12 changes: 12 additions & 0 deletions crates/ratchet-models/webdriver.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"goog:chromeOptions": {
"args": [
"--no-sandbox",
"--headless=new",
"--use-angle=vulkan",
"--enable-features=Vulkan",
"--disable-vulkan-surface",
"--enable-unsafe-webgpu"
]
}
}

0 comments on commit d07fcdf

Please sign in to comment.