Skip to content

Commit

Permalink
chore: faster wasm-pack
Browse files Browse the repository at this point in the history
  • Loading branch information
FL33TW00D committed Jan 22, 2024
1 parent 70f01c5 commit 4cf51d6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 44 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ jobs:
python-version: '3.10.6'
cache: 'pip'
- run: pip install -r requirements.txt
- name: Setup
- name: Install wasm-pack
run: |
cargo install wasm-pack
curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Build
run: cargo build
- name: Run tests
Expand Down
87 changes: 45 additions & 42 deletions crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,46 +426,49 @@ mod tests {
Ok(())
}

#[test]
fn test_pyo3() -> anyhow::Result<()> {
let cpu_device = Device::request_device(DeviceRequest::CPU)?;
let a = Tensor::randn::<f32>(shape![1024, 1024], cpu_device.clone());
let b = Tensor::randn::<f32>(shape![1024, 1024], cpu_device.clone());

let ground: anyhow::Result<Tensor> = Python::with_gil(|py| {
let prg = PyModule::from_code(
py,
r#"
import torch
def matmul(a, b):
return torch.matmul(torch.from_numpy(a), torch.from_numpy(b)).numpy()
"#,
"x.py",
"x",
)?;

let result = prg
.getattr("matmul")?
.call1((a.clone().to_py::<f32>(py), b.clone().to_py::<f32>(py)))?
.extract::<&PyArrayDyn<f32>>()?;
Ok(Tensor::from(result))
});
println!("\nTORCH: {:#?}", ground);

println!("\nA: {:#?}", a);
println!("\nB: {:#?}", b);

let gpu_device = Device::request_device(DeviceRequest::GPU)?;
let a = a.to(gpu_device.clone())?;
let b = b.to(gpu_device)?;

let c = a.matmul(&b)?;
c.resolve()?;

let our_result = c.to(cpu_device)?;
println!("\nOURS: {:#?}", our_result);

Ok(())
}
/*
#[test]
fn test_pyo3() -> anyhow::Result<()> {
let cpu_device = Device::request_device(DeviceRequest::CPU)?;
let a = Tensor::randn::<f32>(shape![1024, 1024], cpu_device.clone());
let b = Tensor::randn::<f32>(shape![1024, 1024], cpu_device.clone());
let ground: anyhow::Result<Tensor> = Python::with_gil(|py| {
let prg = PyModule::from_code(
py,
r#"
import torch
def matmul(a, b):
return torch.matmul(torch.from_numpy(a), torch.from_numpy(b)).numpy()
"#,
"x.py",
"x",
)?;
let result = prg
.getattr("matmul")?
.call1((a.clone().to_py::<f32>(py), b.clone().to_py::<f32>(py)))?
.extract::<&PyArrayDyn<f32>>()?;
Ok(Tensor::from(result))
});
println!("\nTORCH: {:#?}", ground);
println!("\nA: {:#?}", a);
println!("\nB: {:#?}", b);
let gpu_device = Device::request_device(DeviceRequest::GPU)?;
let a = a.to(gpu_device.clone())?;
let b = b.to(gpu_device)?;
let c = a.matmul(&b)?;
c.resolve()?;
let our_result = c.to(cpu_device)?;
println!("\nOURS: {:#?}", our_result);
Ok(())
}
*/
}

0 comments on commit 4cf51d6

Please sign in to comment.