Skip to content

Commit

Permalink
chore: first pass pytorch integration
Browse files Browse the repository at this point in the history
  • Loading branch information
FL33TW00D committed Jan 21, 2024
1 parent 50f7b5b commit 1e3d725
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ Cargo.lock

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
.python-version
12 changes: 11 additions & 1 deletion crates/ratchet-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ version = "0.1.0"
edition = "2021"

[features]
default = ["rand"]
default = ["rand", "pyo3"]
pyo3 = ["dep:pyo3", "dep:numpy", "dep:ndarray"]
gpu_profiling = []
rand = ["dep:rand", "dep:rand_distr"]

Expand Down Expand Up @@ -36,5 +37,14 @@ rand_distr = { version = "0.4.3", optional = true }
rand = { version = "0.8.4", optional = true }
lazy_static = "1.4.0"

# Python bindings
pyo3 = { version = "0.20.2", features=["auto-initialize"], optional = true }
numpy = { version = "0.20.0", optional = true }
ndarray = { version = "0.15.6", optional = true }

[dev-dependencies]
rand = "0.8.4"
pyo3 = { version = "0.20.2", features=["auto-initialize"] }
numpy = { version = "0.20.0" }
ndarray = { version = "0.15.6" }

10 changes: 10 additions & 0 deletions crates/ratchet-core/src/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ impl Storage {
}
}

pub fn try_cpu(&self) -> Result<&RawCPUBuffer, DeviceError> {
match self.raw.as_ref() {
Some(RawStorage::CPU(raw)) => Ok(raw),
_ => Err(DeviceError::DeviceMismatch(
"CPU".to_string(),
"GPU".to_string(),
)),
}
}

pub fn dump(&self, dtype: DType, full: bool) -> String {
self.raw
.as_ref()
Expand Down
113 changes: 111 additions & 2 deletions crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::gpu::{CpuUniform, WgpuDevice};
use crate::{
ops::*, CompiledOp, DType, Device, DeviceStorage, Executable, Operation, OperationError,
RawStorage, Shape, Storage, Strides, TensorDType, TensorId,
RawCPUBuffer, RawStorage, Shape, Storage, Strides, TensorDType, TensorId,
};
use crate::{BinaryOp, LazyOp};

Expand All @@ -12,6 +12,12 @@ use std::sync::Arc;
#[cfg(feature = "rand")]
use {rand::prelude::*, rand_distr::StandardNormal};

#[cfg(feature = "pyo3")]
use {
ndarray::{ArrayD, ArrayViewD},
numpy::PyArrayDyn,
};

// thiserror error for Tensor
#[derive(thiserror::Error, Debug)]
pub enum TensorError {
Expand Down Expand Up @@ -295,16 +301,87 @@ impl Tensor {
_ => Ok(self.clone()),
}
}

#[cfg(feature = "pyo3")]
pub fn into_ndarray<T: TensorDType>(&self) -> ArrayD<T> {
assert!(self.device().is_cpu());
let storage = self.storage().try_read().unwrap();
let raw_cpu = storage.try_cpu().unwrap();
let shape = self.shape().to_vec();
if self.num_bytes() != 0 {
let ptr = raw_cpu.inner().0 as *const T;
unsafe { ArrayViewD::from_shape_ptr(shape, ptr).to_owned() }
} else {
ArrayViewD::from_shape(shape, &[]).unwrap().to_owned()
}
}

#[cfg(feature = "pyo3")]
pub fn to_py<'s, 'p: 's, T: TensorDType + numpy::Element>(
&'s self,
py: &'p pyo3::Python<'p>,
) -> &PyArrayDyn<T> {
use numpy::PyArray;
PyArray::from_owned_array(*py, self.clone().into_ndarray::<T>())
}
}

#[cfg(feature = "pyo3")]
impl<T: TensorDType> From<ArrayD<T>> for Tensor {
fn from(it: ArrayD<T>) -> Self {
if it.as_slice().is_some() {
let layout = std::alloc::Layout::from_size_align(
it.len() * std::mem::size_of::<T>(),
std::mem::align_of::<T>(),
)
.unwrap();
let shape = it.shape().to_vec().into();
let strides = Strides::from(&shape);
let vec = it.into_raw_vec().into_boxed_slice();
let ptr = Box::into_raw(vec) as *mut u8;

let raw_buf = RawCPUBuffer::new(ptr, layout);
let storage = Storage::from(RawStorage::CPU(raw_buf));
let meta = StorageView::new(shape, T::dt(), strides);
Tensor::new(LazyOp::Const, meta, storage, Device::CPU)
} else {
panic!("Cannot convert numpy array with non-contiguous memory layout to tensor");
}
}
}

#[cfg(feature = "pyo3")]
impl<T: TensorDType + numpy::Element> From<&PyArrayDyn<T>> for Tensor {
fn from(array: &PyArrayDyn<T>) -> Self {
Self::from(array.to_owned_array())
}
}

#[cfg(test)]
mod tests {
use pyo3::{types::PyModule, Python};

use crate::{shape, DeviceRequest};

use super::*;

#[test]
fn test_cfg() -> anyhow::Result<()> {
fn test_matmul() -> anyhow::Result<()> {
let device = Device::request_device(DeviceRequest::GPU)?;
let a = Tensor::randn::<f32>(shape![1024, 1024], device.clone());
let b = Tensor::randn::<f32>(shape![1024, 1024], device.clone());
let c = a.matmul(&b)?;
c.resolve()?;
println!("\nA: {:#?}", a);
println!("\nB: {:#?}", b);
println!("\nC: {:#?}", c);
let d = c.to(Device::CPU)?;
println!("\nD: {:#?}", d);
Ok(())
}

#[test]
fn test_pyo3() -> anyhow::Result<()> {
let device = Device::request_device(DeviceRequest::GPU)?;
let a = Tensor::randn::<f32>(shape![1024, 1024], device.clone());
let b = Tensor::randn::<f32>(shape![1024, 1024], device.clone());
Expand All @@ -315,6 +392,38 @@ mod tests {
println!("\nC: {:#?}", c);
let d = c.to(Device::CPU)?;
println!("\nD: {:#?}", d);

let a = a.to(Device::CPU)?;
let b = b.to(Device::CPU)?;
let c = Python::with_gil(|py| {
let npy_a = a.to_py::<f32>(&py);
let npy_b = b.to_py::<f32>(&py);

let activators = PyModule::from_code(
py,
r#"
import numpy as np
import torch
def matmul(a, b):
return torch.matmul(torch.from_numpy(a), torch.from_numpy(b)).numpy()
"#,
"x.py",
"x",
)
.unwrap();

let result = activators
.getattr("matmul")
.unwrap()
.call1((npy_a, npy_b))
.unwrap()
.extract::<&PyArrayDyn<f32>>()
.unwrap();
Tensor::from(result)
});
println!("\nC: {:#?}", c);

Ok(())
}
}
6 changes: 6 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
line-count:
cd ./crates/ratchet-core && scc -irs --exclude-file kernels
install-pyo3:
env PYTHON_CONFIGURE_OPTS="--enable-shared" pyenv install --verbose 3.10.6
echo "Please PYO3_PYTHON to your .bashrc or .zshrc"
wasm CRATE:
RUSTFLAGS=--cfg=web_sys_unstable_apis wasm-pack build --target web -d `pwd`/target/pkg/{{CRATE}} --out-name {{CRATE}} ./crates/{{CRATE}} --release

0 comments on commit 1e3d725

Please sign in to comment.