Skip to content

Commit

Permalink
split ew helper from ew
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Apr 10, 2023
1 parent 0478196 commit f981c84
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 73 deletions.
1 change: 1 addition & 0 deletions linalg/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod pack;
pub mod sigmoid;
#[macro_use]
pub mod tanh;
pub mod element_wise_helper;

pub use pack::Packer;
pub use pack::PackingWriter;
Expand Down
81 changes: 8 additions & 73 deletions linalg/src/frame/element_wise.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::alloc::*;
use std::fmt::Debug;
use std::marker::PhantomData;
use tract_data::anyhow;

use tract_data::TractResult;

use crate::LADatum;

use super::element_wise_helper::run_over_slice_with_alignment;

macro_rules! ew_impl {
($ti: ident, $func: ident, $nr: expr, $alignment_items: expr) => {
paste! {
Expand Down Expand Up @@ -44,53 +46,11 @@ macro_rules! ew_impl {
};
}

struct TempBuffer {
layout: Layout,
buffer: *mut u8,
}

impl Default for TempBuffer {
fn default() -> Self {
TempBuffer { layout: Layout::new::<()>(), buffer: std::ptr::null_mut() }
}
}

impl TempBuffer {
fn ensure(&mut self, size: usize, alignment: usize) {
unsafe {
if size > self.layout.size() || alignment > self.layout.align() {
let size = size.max(self.layout.size());
let alignment = alignment.max(self.layout.align());
if !self.buffer.is_null() {
std::alloc::dealloc(self.buffer, self.layout);
}
self.layout = Layout::from_size_align_unchecked(size, alignment);
self.buffer = std::alloc::alloc(self.layout);
assert!(!self.buffer.is_null());
}
}
}
}

impl Drop for TempBuffer {
fn drop(&mut self) {
unsafe {
if !self.buffer.is_null() {
std::alloc::dealloc(self.buffer, self.layout);
}
}
}
}

std::thread_local! {
static TMP: std::cell::RefCell<TempBuffer> = std::cell::RefCell::new(TempBuffer::default());
}

pub trait ElementWise<T>: Send + Sync + Debug + dyn_clone::DynClone
where
T: Copy + Debug + PartialEq + Send + Sync,
{
fn run(&self, vec: &mut [T]) -> anyhow::Result<()>;
fn run(&self, vec: &mut [T]) -> TractResult<()>;
}

dyn_clone::clone_trait_object!(<T> ElementWise<T> where T: Copy);
Expand All @@ -109,37 +69,12 @@ where
T: LADatum,
K: ElementWiseKer<T> + Clone,
{
fn run(&self, vec: &mut [T]) -> anyhow::Result<()> {
if vec.is_empty() {
return Ok(());
}
unsafe {
TMP.with(|buffer| {
let mut buffer = buffer.borrow_mut();
buffer.ensure(K::nr() * T::datum_type().size_of(), K::alignment_bytes());
let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, K::nr());
let mut compute_via_temp_buffer = |slice: &mut [T]| {
tmp[..slice.len()].copy_from_slice(slice);
K::run(tmp);
slice.copy_from_slice(&tmp[..slice.len()])
};
let prefix_len = vec.as_ptr().align_offset(K::alignment_bytes()).min(vec.len());
if prefix_len > 0 {
compute_via_temp_buffer(&mut vec[..prefix_len]);
}
let aligned_len = (vec.len() - prefix_len) / K::nr() * K::nr();
if aligned_len > 0 {
K::run(&mut vec[prefix_len..][..aligned_len]);
}
if prefix_len + aligned_len < vec.len() {
compute_via_temp_buffer(&mut vec[prefix_len + aligned_len..]);
}
})
}
Ok(())
fn run(&self, vec: &mut [T]) -> TractResult<()> {
run_over_slice_with_alignment(vec, K::run, K::nr(), K::alignment_bytes())
}
}


pub trait ElementWiseKer<T>: Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static
where
T: LADatum,
Expand Down
83 changes: 83 additions & 0 deletions linalg/src/frame/element_wise_helper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use crate::LADatum;
use std::alloc::*;
use tract_data::TractResult;

pub(crate) fn run_over_slice_with_alignment<T>(
vec: &mut [T],
f: impl Fn(&mut [T]),
nr: usize,
alignment_bytes: usize,
) -> TractResult<()>
where
T: LADatum,
{
if vec.is_empty() {
return Ok(());
}
unsafe {
TMP.with(|buffer| {
let mut buffer = buffer.borrow_mut();
buffer.ensure(nr * T::datum_type().size_of(), alignment_bytes);
let tmp = std::slice::from_raw_parts_mut(buffer.buffer as *mut T, nr);
let mut compute_via_temp_buffer = |slice: &mut [T]| {
tmp[..slice.len()].copy_from_slice(slice);
f(tmp);
slice.copy_from_slice(&tmp[..slice.len()])
};
let prefix_len = vec.as_ptr().align_offset(alignment_bytes).min(vec.len());
if prefix_len > 0 {
compute_via_temp_buffer(&mut vec[..prefix_len]);
}
let aligned_len = (vec.len() - prefix_len) / nr * nr;
if aligned_len > 0 {
f(&mut vec[prefix_len..][..aligned_len]);
}
if prefix_len + aligned_len < vec.len() {
compute_via_temp_buffer(&mut vec[prefix_len + aligned_len..]);
}
})
}
Ok(())
}

std::thread_local! {
static TMP: std::cell::RefCell<TempBuffer> = std::cell::RefCell::new(TempBuffer::default());
}

pub struct TempBuffer {
pub layout: Layout,
pub buffer: *mut u8,
}

impl Default for TempBuffer {
fn default() -> Self {
TempBuffer { layout: Layout::new::<()>(), buffer: std::ptr::null_mut() }
}
}

impl TempBuffer {
pub fn ensure(&mut self, size: usize, alignment: usize) {
unsafe {
if size > self.layout.size() || alignment > self.layout.align() {
let size = size.max(self.layout.size());
let alignment = alignment.max(self.layout.align());
if !self.buffer.is_null() {
std::alloc::dealloc(self.buffer, self.layout);
}
self.layout = Layout::from_size_align_unchecked(size, alignment);
self.buffer = std::alloc::alloc(self.layout);
assert!(!self.buffer.is_null());
}
}
}
}

impl Drop for TempBuffer {
fn drop(&mut self) {
unsafe {
if !self.buffer.is_null() {
std::alloc::dealloc(self.buffer, self.layout);
}
}
}
}
3 changes: 3 additions & 0 deletions linalg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ pub fn generic() -> Ops {
tanh_f32: Box::new(|| generic::STanh4::ew()),
erf_f32: Box::new(|| generic::SErf4::ew()),
lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::<generic::GenericLut8>::new(table))),
/*
activation_f32: Box::new(|microcode| generic::SActivation::new(microcode))
*/
}
}

Expand Down

0 comments on commit f981c84

Please sign in to comment.