diff --git a/linalg/activations/src/lib.rs b/linalg/activations/src/lib.rs index f0a55d4638..dbb0f5c44a 100644 --- a/linalg/activations/src/lib.rs +++ b/linalg/activations/src/lib.rs @@ -161,8 +161,8 @@ mod test { } mod scalar { - use proptest::prelude::*; use super::close_enough; + use proptest::prelude::*; macro_rules! prop_activation { ($name: ident ( $($param:ident),* )) => { @@ -186,8 +186,8 @@ mod test { } mod vector { - use proptest::prelude::*; use super::close_enough; + use proptest::prelude::*; macro_rules! prop_activation { ($name: ident ( $($param:ident),* )) => { diff --git a/linalg/src/frame.rs b/linalg/src/frame.rs index d4c9dbc14f..28f6761da8 100644 --- a/linalg/src/frame.rs +++ b/linalg/src/frame.rs @@ -1,4 +1,6 @@ #[macro_use] +pub mod activations; +#[macro_use] pub mod element_wise; #[macro_use] pub mod lut; diff --git a/linalg/src/frame/activations.rs b/linalg/src/frame/activations.rs new file mode 100644 index 0000000000..8b604b4998 --- /dev/null +++ b/linalg/src/frame/activations.rs @@ -0,0 +1,143 @@ +use std::fmt::Debug; +use std::marker::PhantomData; + +use tract_data::TractResult; + +use crate::LADatum; + +use super::element_wise_helper::run_over_slice_with_alignment; + +pub mod definitions; +pub mod reference; +#[macro_use] +pub mod tests; + +#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(u8)] +pub enum RegisterId { + A = 0, + B = 1, + C = 2, +} + +type ConstantId = u8; + +#[repr(C, u16)] +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Op { + Done, + Move(RegisterId, RegisterId), + Load(RegisterId, ConstantId), + Abs, + Recip, + Add, + Sub, + Mul, + Min, + Max, + AddConst(ConstantId), + SubConst(ConstantId), + MulConst(ConstantId), + MinConst(ConstantId), + MaxConst(ConstantId), + FMA(ConstantId), // a <- a * b + cst + IfPosTE, + SwapBC, + Floor, + TwoPowOfInt, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Program { + pub ops: Vec, + pub csts: Vec, +} + +pub trait Activation: Send + Sync + Debug + dyn_clone::DynClone { + fn run(&self, prog: &Program, vec: &mut [T]) -> TractResult<()>; +} + +#[derive(Debug, Clone, new)] +pub struct ActivationImpl +where + T: LADatum, + K: ActivationKer + Clone, +{ + phantom: PhantomData<(K, T)>, +} + +impl Activation for ActivationImpl +where + T: LADatum, + K: ActivationKer + Clone, +{ + fn run(&self, program: &Program, vec: &mut [T]) -> TractResult<()> { + run_over_slice_with_alignment( + vec, + |slice| K::run(&program.ops, &*program.csts, slice), + K::nr(), + K::alignment_bytes(), + ) + } +} + +pub trait ActivationKer: Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static +where + T: LADatum, +{ + fn name() -> &'static str; + fn alignment_bytes() -> usize; + fn alignment_items() -> usize; + fn nr() -> usize; + fn run(ops: &[Op], csts: &[T], vec: &mut [T]); + fn act() -> Box> { + Box::new(ActivationImpl::::new()) + } +} + +macro_rules! act_impl { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr) => { + paste! { + mod [] { + #[allow(unused_imports)] + use tract_data::prelude::f16; + extern_kernel!(fn $func(ptr: *mut $ti, count: usize) -> ()); + } + + #[derive(Copy, Clone, Debug)] + #[allow(non_camel_case_types)] + pub struct $func; + + impl ActivationKer<$ti> for $func { + #[inline(always)] + fn name() -> &'static str { + stringify!($func) + } + #[inline(always)] + fn nr() -> usize { + $nr + } + #[inline(always)] + fn alignment_items() -> usize { + $alignment_items + } + #[inline(always)] + fn alignment_bytes() -> usize { + $alignment_items * std::mem::size_of::<$ti>() + } + #[inline(never)] + fn run(ops: &Op, csts:&[T], buf: &mut [$ti]) { + unsafe { []::$func(ops.as_ptr(), csts.as_ptr(), buf.as_mut_ptr(), buf.len()) } + } + } + } + }; +} + +#[cfg(test)] +mod test { + #[test] + fn size_of_op() { + assert_eq!(std::mem::size_of::(), 4); + } +} diff --git a/linalg/src/frame/activations/definitions.rs b/linalg/src/frame/activations/definitions.rs new file mode 100644 index 0000000000..e1f0ff8e59 --- /dev/null +++ b/linalg/src/frame/activations/definitions.rs @@ -0,0 +1,175 @@ +use super::Op::*; +use super::RegisterId::*; +use super::*; + +pub fn relu() -> Program { + Program { ops: vec![MaxConst(0), Done], csts: vec![T::zero()] } +} + +pub fn affine(alpha: T, beta: T) -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + MulConst(0), + AddConst(1), + Done, + ], + csts: vec![alpha, beta], + } +} + +pub fn leaky_relu(alpha: T) -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + Move(B,A), + MulConst(0), + Move(C,A), + Move(A,B), + IfPosTE, + Done, + ], + csts: vec![alpha], + } +} + +pub fn threshold_relu(alpha: T) -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + Move(B,A), + SubConst(1), + Load(C,0), + IfPosTE, + Done, + ], + csts: vec![T::zero(), alpha], + } +} + +pub fn softsign() -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + Move(B,A), + Abs, + AddConst(0), + Recip, + Mul, + Done, + ], + csts: vec![T::one()], + } +} + +pub fn hardswish() -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + Move(B, A), + MulConst(2), + AddConst(3), + MinConst(1), + MaxConst(0), + Mul, + Done, + ], + csts: vec![ + T::zero(), + T::one(), + T::one() / (T::one() + T::one() + T::one() + T::one() + T::one() + T::one()), // 1/6 + T::one() / (T::one() + T::one()), // 1/2 + ], + } +} + +/* +pub fn sigmoid() -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + MinConst(3), + MaxConst(2), + Move(B, A), // b = x + Move(C, A), // c = x + Mul, // a = x2 + Move(B, A), // b = x2 + MulConst(4), + AddConst(5), // a = x2 * a13 + a11 + FMA(6), + FMA(7), + FMA(8), + FMA(9), + FMA(10), + SwapBC, // c = x2, b = x + Mul, // a = p(x) + Move(B, C), // b = x2 + Move(C, A), // c = p(x) + Move(A, B), // a = x2 + MulConst(11), + AddConst(12), + FMA(13), + FMA(1), // a = q(x) + Recip, + Move(B,C), // b = p(x) + Mul, + AddConst(14) + ], + csts: vec![ + -18.6, // const 2 + 18.6, // const 3 + -4.433153405e-18, // const 4, also alpha_13 + 1.169974371e-14, // const 5, also a11 + -1.875289645e-11, + 4.257889523e-8, + 0.00004811817576, // const 8 + 0.008163842030, + 0.2499999971, // alpha_1 + 3.922935744e-6, // beta_6 + 0.001524872358, // const 12 + 0.1159886749, + 0.5, //beta_0 + ], + } +} + +pub fn exp2f() -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + MinConst(2), + MaxConst(3), + Move(B, A), // b = x + AddConst(4), // a = x + 0.5 + Floor, // a = ipart + Move(C, A), // c = ipart + Move(A, B), // a = x + Move(B, C), // b = ipart + Sub, // a = fpart + Move(B, A), // b = fpart + Load(A, 5), // a = exp2p[0] + FMA(6), + FMA(7), + FMA(8), + FMA(9), + FMA(10), + FMA(1), // a = y + Move(B, A), + Move(A, C), + TwoPowOfInt, + Mul + ], + csts: vec![ + 127f32, + -127f32, + 0.5, + 1.535336188319500e-4, + 1.339887440266574e-3, + 9.618437357674640e-3, + 5.550332471162809e-2, + 2.402264791363012e-1, + 6.931472028550421e-1, + ], + } +} +*/ diff --git a/linalg/src/frame/activations/reference.rs b/linalg/src/frame/activations/reference.rs new file mode 100644 index 0000000000..525fd849f0 --- /dev/null +++ b/linalg/src/frame/activations/reference.rs @@ -0,0 +1,110 @@ + +pub fn relu(x: f32) -> f32 { + x.max(0f32) +} + +pub fn affine(x: f32, alpha: f32, beta: f32) -> f32 { + alpha * x + beta +} + +pub fn leaky_relu(x: f32, alpha: f32) -> f32 { + if x > 0f32 { + x + } else { + alpha * x + } +} + +pub fn threshold_relu(x: f32, alpha: f32) -> f32 { + if x >= alpha { + x + } else { + 0f32 + } +} + +pub fn softsign(x: f32) -> f32 { + x / (1. + x.abs()) +} + +pub fn hardswish(x: f32) -> f32 { + x * 0f32.max(1f32.min((1. / 6.) * x + 0.5)) +} + +pub fn sigmoid(x: f32) -> f32 { + ssigmoid(x) +} + +pub fn ref_exp2f(x: f32) -> f32 { + 2f32.powf(x) +} + +pub fn cm_exp2f(x: f32) -> f32 { + exp2f(x) +} + +fn ssigmoid(x: f32) -> f32 { + const LOW: f32 = -18.6; + const HIGH: f32 = -LOW; + + const ALPHA_13: f32 = -4.433153405e-18; + const ALPHA_11: f32 = 1.169974371e-14; + const ALPHA_9: f32 = -1.875289645e-11; + const ALPHA_7: f32 = 4.257889523e-8; + const ALPHA_5: f32 = 0.00004811817576; + const ALPHA_3: f32 = 0.008163842030; + const ALPHA_1: f32 = 0.2499999971; + const BETA_6: f32 = 3.922935744e-6; + const BETA_4: f32 = 0.001524872358; + const BETA_2: f32 = 0.1159886749; + const BETA_0: f32 = 1.0; + + let x = x.clamp(LOW, HIGH); + + let x2 = x * x; + + let p = ALPHA_13; + let p = x2 * p + ALPHA_11; + let p = x2 * p + ALPHA_9; + let p = x2 * p + ALPHA_7; + let p = x2 * p + ALPHA_5; + let p = x2 * p + ALPHA_3; + let p = x2 * p + ALPHA_1; + let p = p * x; + + let q = BETA_6; + let q = x2 * q + BETA_4; + let q = x2 * q + BETA_2; + let q = x2 * q + BETA_0; + + p / q + 0.5 +} + +pub fn exp2f(x: f32) -> f32 { + const EXP2P: [f32; 7] = [ + 1.535336188319500e-4, + 1.339887440266574e-3, + 9.618437357674640e-3, + 5.550332471162809e-2, + 2.402264791363012e-1, + 6.931472028550421e-1, + 1.000000000000000, + ]; + + let x = x.min(127f32).max(-127f32); + + let ipart = (x + 0.5).floor(); + let fpart = x - ipart; + + // 2^ipart + let two_pow_ipart = f32::from_bits((((ipart as i32) + 127) as u32) << 23); + + let mut y = EXP2P[0]; + y = y * fpart + EXP2P[1]; + y = y * fpart + EXP2P[2]; + y = y * fpart + EXP2P[3]; + y = y * fpart + EXP2P[4]; + y = y * fpart + EXP2P[5]; + y = y * fpart + EXP2P[6]; + y * two_pow_ipart +} diff --git a/linalg/src/frame/activations/tests.rs b/linalg/src/frame/activations/tests.rs new file mode 100644 index 0000000000..197608ff2e --- /dev/null +++ b/linalg/src/frame/activations/tests.rs @@ -0,0 +1,35 @@ +macro_rules! prop_activation { + ($cond:expr, $ti: ty, $ker: ty, $name: ident ( $($param:ident),* )) => { + proptest::proptest! { + #[test] + fn $name(x in proptest::prelude::any::<$ti>(), repeat in 1usize..4, $($param in proptest::prelude::any::<$ti>()),*) { + if $cond { + let mut input = tract_data::prelude::Tensor::zero_aligned::<$ti>(&[<$ker>::nr() * repeat], <$ker>::alignment_bytes()).unwrap(); + input.fill_t::<$ti>(x).unwrap(); + let prog = crate::frame::activations::definitions::$name($($param),*); + <$ker>::run(&prog.ops, &prog.csts, &mut input.as_slice_mut::<$ti>().unwrap()); + let expected = crate::frame::activations::reference::$name(x, $($param),*); + let mut output = tract_data::prelude::Tensor::zero_aligned::<$ti>(&[<$ker>::nr() * repeat], <$ker>::alignment_bytes()).unwrap(); + output.fill_t::<$ti>(expected).unwrap(); + output.close_enough(&input, true).unwrap(); + } + } + } + } +} + +#[macro_export] +macro_rules! act_frame_tests { + ($cond:expr, $ker:ty, $ti:ty) => { + prop_activation!($cond, $ti, $ker, relu()); + prop_activation!($cond, $ti, $ker, affine(alpha, beta)); + prop_activation!($cond, $ti, $ker, leaky_relu(alpha)); + prop_activation!($cond, $ti, $ker, threshold_relu(alpha)); + prop_activation!($cond, $ti, $ker, softsign()); + prop_activation!($cond, $ti, $ker, hardswish()); + /* + prop_activation!($cond, $ti, $ker, sigmoid()); + prop_activation!($cond, $ti, $ker, exp2f()); + */ + }; +} diff --git a/linalg/src/generic.rs b/linalg/src/generic.rs index 583ce45816..485c7d32f7 100644 --- a/linalg/src/generic.rs +++ b/linalg/src/generic.rs @@ -1,3 +1,4 @@ +pub mod activations; pub mod erf; pub mod lut; pub mod mmm; diff --git a/linalg/src/generic/activations.rs b/linalg/src/generic/activations.rs new file mode 100644 index 0000000000..c76eace069 --- /dev/null +++ b/linalg/src/generic/activations.rs @@ -0,0 +1,98 @@ +use crate::frame::activations::{ActivationKer, Op, RegisterId}; + +// TODO make the inner loop tighter +unsafe fn compute_slice(ops: *const Op, constants: *const f32, xs: *mut f32, len: usize) { + let mut a = std::slice::from_raw_parts_mut(xs, len); + let mut b = vec![0.0f32; a.len()]; + let mut c = vec![0.0f32; a.len()]; + let mut pc = ops; + loop { + match *pc { + Op::Done => break, + Op::Move(dst, src) => { + let mut regs = [&mut a, &mut *b, &mut c]; + let dst = dst as usize; + let src = src as usize; + if dst < src { + let (left, right) = regs.split_at_mut(src); + let d = &mut *left[dst]; + let s = &*right[0]; + d.copy_from_slice(s) + } else { + let (left, right) = regs.split_at_mut(dst); + let s = &*left[src]; + let d = &mut *right[0]; + d.copy_from_slice(s) + } + } + Op::Load(dst, cst) if dst == RegisterId::A => { + a.iter_mut().for_each(|x| *x = *constants.add(cst as usize)) + } + Op::Load(dst, cst) if dst == RegisterId::B => { + b.iter_mut().for_each(|x| *x = *constants.add(cst as usize)) + } + Op::Load(_dst, cst) => c.iter_mut().for_each(|x| *x = *constants.add(cst as usize)), + Op::Abs => a.iter_mut().for_each(|x| *x = x.abs()), + Op::Recip => a.iter_mut().for_each(|x| *x = x.recip()), + Op::Add => a.iter_mut().zip(&b).for_each(|(x, y)| *x += *y), + Op::Sub => a.iter_mut().zip(&b).for_each(|(x, y)| *x -= *y), + Op::Mul => a.iter_mut().zip(&b).for_each(|(x, y)| *x *= *y), + Op::Min => a.iter_mut().zip(&b).for_each(|(x, y)| *x = x.min(*y)), + Op::Max => a.iter_mut().zip(&b).for_each(|(x, y)| *x = x.max(*y)), + Op::AddConst(cst) => a.iter_mut().for_each(|x| *x += *constants.add(cst as usize)), + Op::SubConst(cst) => a.iter_mut().for_each(|x| *x -= *constants.add(cst as usize)), + Op::MulConst(cst) => a.iter_mut().for_each(|x| *x *= *constants.add(cst as usize)), + Op::MinConst(cst) => { + a.iter_mut().for_each(|x| *x = x.min(*constants.add(cst as usize))) + } + Op::MaxConst(cst) => { + a.iter_mut().for_each(|x| *x = x.max(*constants.add(cst as usize))) + } + Op::IfPosTE => a + .iter_mut() + .zip(&b) + .zip(&c) + .for_each(|((x, y), z)| *x = if *x >= 0f32 { *y } else { *z }), + Op::FMA(cst) => { + a.iter_mut().zip(&b).for_each(|(x, y)| *x = *x * *y + *constants.add(cst as usize)) + } + Op::SwapBC => b.iter_mut().zip(c.iter_mut()).for_each(|(b, c)| std::mem::swap(b, c)), + Op::Floor => a.iter_mut().for_each(|x| *x = x.floor()), + Op::TwoPowOfInt => { + a.iter_mut().for_each(|x| *x = f32::from_bits((((*x as i32) + 127) as u32) << 23)) + } + } + pc = pc.add(1); + } +} + +#[derive(Clone, Debug)] +pub struct SActivations; + +impl ActivationKer for SActivations { + fn name() -> &'static str { + "generic" + } + + fn alignment_bytes() -> usize { + 16 + } + + fn alignment_items() -> usize { + 4 + } + + fn nr() -> usize { + 4 + } + + fn run(ops: &[Op], csts: &[f32], xs: &mut [f32]) { + debug_assert!(xs.len() % Self::nr() == 0); + debug_assert!(xs.as_ptr() as usize % Self::alignment_bytes() == 0); + unsafe { compute_slice(ops.as_ptr(), csts.as_ptr(), xs.as_mut_ptr(), xs.len()) }; + } +} + +#[cfg(test)] +act_frame_tests!(true, SActivations, f32); + diff --git a/linalg/src/lib.rs b/linalg/src/lib.rs index ea66629e8e..e085baf4f1 100644 --- a/linalg/src/lib.rs +++ b/linalg/src/lib.rs @@ -156,7 +156,8 @@ pub trait LADatum: + 'static + Add + Sub - + Mul + + Mul + + Div + AddAssign + PartialOrd + Bounded