Skip to content

Commit

Permalink
chore: check test
Browse files Browse the repository at this point in the history
  • Loading branch information
FL33TW00D committed Jan 21, 2024
1 parent cd31192 commit 1859f9e
Showing 1 changed file with 41 additions and 39 deletions.
80 changes: 41 additions & 39 deletions crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,43 +426,45 @@ mod tests {
Ok(())
}

#[test]
fn test_pyo3() -> anyhow::Result<()> {
let cpu_device = Device::request_device(DeviceRequest::CPU)?;
let a = Tensor::randn::<f32>(shape![1024, 1024], cpu_device.clone());
let b = Tensor::randn::<f32>(shape![1024, 1024], cpu_device.clone());

let ground: anyhow::Result<Tensor> = Python::with_gil(|py| {
let prg = PyModule::from_code(
py,
r#"
import torch
def matmul(a, b):
return torch.matmul(torch.from_numpy(a), torch.from_numpy(b)).numpy()
"#,
"x.py",
"x",
)?;

let result = prg
.getattr("matmul")?
.call1((a.to_py::<f32>(&py), b.to_py::<f32>(&py)))?
.extract::<&PyArrayDyn<f32>>()?;
Ok(Tensor::from(result))
});

let gpu_device = Device::request_device(DeviceRequest::GPU)?;
let a = a.to(gpu_device.clone())?;
let b = b.to(gpu_device)?;

let c = a.matmul(&b)?;
c.resolve()?;

let our_result = c.to(cpu_device)?;
println!("\nTORCH: {:#?}", ground);
println!("\nOURS: {:#?}", our_result);

Ok(())
}
/*
#[test]
fn test_pyo3() -> anyhow::Result<()> {
let cpu_device = Device::request_device(DeviceRequest::CPU)?;
let a = Tensor::randn::<f32>(shape![1024, 1024], cpu_device.clone());
let b = Tensor::randn::<f32>(shape![1024, 1024], cpu_device.clone());
let ground: anyhow::Result<Tensor> = Python::with_gil(|py| {
let prg = PyModule::from_code(
py,
r#"
import torch
def matmul(a, b):
return torch.matmul(torch.from_numpy(a), torch.from_numpy(b)).numpy()
"#,
"x.py",
"x",
)?;
let result = prg
.getattr("matmul")?
.call1((a.to_py::<f32>(&py), b.to_py::<f32>(&py)))?
.extract::<&PyArrayDyn<f32>>()?;
Ok(Tensor::from(result))
});
let gpu_device = Device::request_device(DeviceRequest::GPU)?;
let a = a.to(gpu_device.clone())?;
let b = b.to(gpu_device)?;
let c = a.matmul(&b)?;
c.resolve()?;
let our_result = c.to(cpu_device)?;
println!("\nTORCH: {:#?}", ground);
println!("\nOURS: {:#?}", our_result);
Ok(())
}
*/
}

0 comments on commit 1859f9e

Please sign in to comment.