From 0afb27d46ffe2ddb2521ea8c7fc69a5f1dada6b2 Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Sun, 21 Jan 2024 11:51:15 +0000 Subject: [PATCH] chore: cleaning --- crates/ratchet-core/src/tensor.rs | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index da2e77ea..97f57639 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -395,14 +395,10 @@ mod tests { let a = a.to(Device::CPU)?; let b = b.to(Device::CPU)?; - let c = Python::with_gil(|py| { - let npy_a = a.to_py::(&py); - let npy_b = b.to_py::(&py); - - let activators = PyModule::from_code( + let c: anyhow::Result = Python::with_gil(|py| { + let prg = PyModule::from_code( py, r#" -import numpy as np import torch def matmul(a, b): @@ -410,19 +406,15 @@ def matmul(a, b): "#, "x.py", "x", - ) - .unwrap(); + )?; - let result = activators - .getattr("matmul") - .unwrap() - .call1((npy_a, npy_b)) - .unwrap() - .extract::<&PyArrayDyn>() - .unwrap(); - Tensor::from(result) + let result = prg + .getattr("matmul")? + .call1((a.to_py::(&py), b.to_py::(&py)))? + .extract::<&PyArrayDyn>()?; + Ok(Tensor::from(result)) }); - println!("\nC: {:#?}", c); + println!("\nTORCH: {:#?}", c); Ok(()) }