diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index 12db2744..3d3ca2d3 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -426,43 +426,45 @@ mod tests { Ok(()) } - #[test] - fn test_pyo3() -> anyhow::Result<()> { - let cpu_device = Device::request_device(DeviceRequest::CPU)?; - let a = Tensor::randn::(shape![1024, 1024], cpu_device.clone()); - let b = Tensor::randn::(shape![1024, 1024], cpu_device.clone()); - - let ground: anyhow::Result = Python::with_gil(|py| { - let prg = PyModule::from_code( - py, - r#" -import torch - -def matmul(a, b): - return torch.matmul(torch.from_numpy(a), torch.from_numpy(b)).numpy() -"#, - "x.py", - "x", - )?; - - let result = prg - .getattr("matmul")? - .call1((a.to_py::(&py), b.to_py::(&py)))? - .extract::<&PyArrayDyn>()?; - Ok(Tensor::from(result)) - }); - - let gpu_device = Device::request_device(DeviceRequest::GPU)?; - let a = a.to(gpu_device.clone())?; - let b = b.to(gpu_device)?; - - let c = a.matmul(&b)?; - c.resolve()?; - - let our_result = c.to(cpu_device)?; - println!("\nTORCH: {:#?}", ground); - println!("\nOURS: {:#?}", our_result); - - Ok(()) - } + /* + #[test] + fn test_pyo3() -> anyhow::Result<()> { + let cpu_device = Device::request_device(DeviceRequest::CPU)?; + let a = Tensor::randn::(shape![1024, 1024], cpu_device.clone()); + let b = Tensor::randn::(shape![1024, 1024], cpu_device.clone()); + + let ground: anyhow::Result = Python::with_gil(|py| { + let prg = PyModule::from_code( + py, + r#" + import torch + + def matmul(a, b): + return torch.matmul(torch.from_numpy(a), torch.from_numpy(b)).numpy() + "#, + "x.py", + "x", + )?; + + let result = prg + .getattr("matmul")? + .call1((a.to_py::(&py), b.to_py::(&py)))? + .extract::<&PyArrayDyn>()?; + Ok(Tensor::from(result)) + }); + + let gpu_device = Device::request_device(DeviceRequest::GPU)?; + let a = a.to(gpu_device.clone())?; + let b = b.to(gpu_device)?; + + let c = a.matmul(&b)?; + c.resolve()?; + + let our_result = c.to(cpu_device)?; + println!("\nTORCH: {:#?}", ground); + println!("\nOURS: {:#?}", our_result); + + Ok(()) + } + */ }