Skip to content

Commit

Permalink
moving everything inside tract
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Apr 10, 2023
1 parent f981c84 commit ad67de7
Show file tree
Hide file tree
Showing 9 changed files with 568 additions and 3 deletions.
4 changes: 2 additions & 2 deletions linalg/activations/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ mod test {
}

mod scalar {
use proptest::prelude::*;
use super::close_enough;
use proptest::prelude::*;

macro_rules! prop_activation {
($name: ident ( $($param:ident),* )) => {
Expand All @@ -186,8 +186,8 @@ mod test {
}

mod vector {
use proptest::prelude::*;
use super::close_enough;
use proptest::prelude::*;

macro_rules! prop_activation {
($name: ident ( $($param:ident),* )) => {
Expand Down
2 changes: 2 additions & 0 deletions linalg/src/frame.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#[macro_use]
pub mod activations;
#[macro_use]
pub mod element_wise;
#[macro_use]
pub mod lut;
Expand Down
143 changes: 143 additions & 0 deletions linalg/src/frame/activations.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
use std::fmt::Debug;
use std::marker::PhantomData;

use tract_data::TractResult;

use crate::LADatum;

use super::element_wise_helper::run_over_slice_with_alignment;

pub mod definitions;
pub mod reference;
#[macro_use]
pub mod tests;

/// Names one of the three registers that an activation [`Op`] can refer to.
///
/// The enum is stored as a single byte (`repr(u8)`) and every discriminant is
/// pinned explicitly.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RegisterId {
    A = 0,
    B = 1,
    C = 2,
}

/// Payload type of the constant-taking ops. The programs in `definitions.rs`
/// use it as an index into `Program::csts`.
type ConstantId = u8;

/// One instruction of an activation [`Program`].
///
/// `repr(C, u16)` requests a C-compatible layout with a 16-bit tag; the
/// `size_of_op` test in this file checks that an `Op` occupies 4 bytes.
///
/// NOTE(review): the per-variant notes below are read off the variant names and
/// the annotated programs in `definitions.rs` (registers are written `a`, `b`,
/// `c` there); confirm them against the reference implementation.
#[repr(C, u16)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Op {
/// Final instruction of every active program in `definitions.rs`; presumably ends execution.
Done,
/// `Move(dst, src)`: copy register `src` into register `dst`.
Move(RegisterId, RegisterId),
/// `Load(dst, cst)`: put constant `cst` into register `dst`.
Load(RegisterId, ConstantId),
// Operand-less arithmetic: presumably works on `a`, with `b` as second operand
// where one is needed (annotated usages: `Sub` gives a = a - b, `Mul` gives a = a * b).
Abs,
Recip,
Add,
Sub,
Mul,
Min,
Max,
// `*Const` variants: same operations with the indexed constant as second
// operand (annotated usage: `AddConst` gives a = a + cst).
AddConst(ConstantId),
SubConst(ConstantId),
MulConst(ConstantId),
MinConst(ConstantId),
MaxConst(ConstantId),
FMA(ConstantId), // a <- a * b + cst
/// Presumably "if positive then/else": selects between `b` and `c` based on `a` — TODO confirm.
IfPosTE,
/// Exchanges registers `b` and `c`.
SwapBC,
/// Annotated usage: a = integer part of `a`.
Floor,
/// Presumably a = 2^a for an integral `a` — TODO confirm.
TwoPowOfInt,
}

/// An activation function expressed as a sequence of [`Op`]s together with the
/// constants those ops refer to.
#[derive(Clone, Debug, PartialEq)]
pub struct Program<T: LADatum> {
/// Instruction sequence.
pub ops: Vec<Op>,
/// Constant table; the programs in `definitions.rs` address it by position
/// through the `ConstantId` payload of their ops.
pub csts: Vec<T>,
}

/// Runs activation [`Program`]s over buffers of `T`.
///
/// Used as a trait object: `ActivationKer::act` returns a `Box<dyn Activation<T>>`.
pub trait Activation<T: LADatum>: Send + Sync + Debug + dyn_clone::DynClone {
/// Runs `prog` over `vec`, which is borrowed mutably; returns a `TractResult`
/// so implementations can report failure.
fn run(&self, prog: &Program<T>, vec: &mut [T]) -> TractResult<()>;
}

/// Adapter exposing a kernel type `K` through the [`Activation`] trait.
///
/// It carries no data beyond a `PhantomData` marker for `K` and `T`. The `new`
/// derive provides the `new()` constructor that `ActivationKer::act` calls.
#[derive(Debug, Clone, new)]
pub struct ActivationImpl<K, T>
where
T: LADatum,
K: ActivationKer<T> + Clone,
{
// Marker only: ties the adapter to its kernel and element types.
phantom: PhantomData<(K, T)>,
}

impl<K, T> Activation<T> for ActivationImpl<K, T>
where
T: LADatum,
K: ActivationKer<T> + Clone,
{
fn run(&self, program: &Program<T>, vec: &mut [T]) -> TractResult<()> {
run_over_slice_with_alignment(
vec,
|slice| K::run(&program.ops, &*program.csts, slice),
K::nr(),
K::alignment_bytes(),
)
}
}

pub trait ActivationKer<T>: Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static
where
T: LADatum,
{
fn name() -> &'static str;
fn alignment_bytes() -> usize;
fn alignment_items() -> usize;
fn nr() -> usize;
fn run(ops: &[Op], csts: &[T], vec: &mut [T]);
fn act() -> Box<dyn Activation<T>> {
Box::new(ActivationImpl::<Self, T>::new())
}
}

/// Declares a unit struct named `$func` that implements `ActivationKer<$ti>` by
/// delegating to an external kernel of the same name.
///
/// * `$ti` — element type the kernel operates on.
/// * `$func` — name of the external kernel, reused for the generated struct.
/// * `$nr` — value returned by `nr()`.
/// * `$alignment_items` — value returned by `alignment_items()`;
///   `alignment_bytes()` is that count multiplied by `size_of::<$ti>()`.
///
/// The call site must have `paste!`, `extern_kernel!` and `ActivationKer` in
/// scope. `Op` is referenced through `$crate`, so it needs no import there.
macro_rules! act_impl {
    ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr) => {
        paste! {
            mod [<sys_ $func>] {
                #[allow(unused_imports)]
                use tract_data::prelude::f16;
                // The declaration lists the same four arguments `run` passes below:
                // instruction pointer, constant pointer, buffer pointer, buffer length.
                extern_kernel!(fn $func(
                    ops: *const $crate::frame::activations::Op,
                    csts: *const $ti,
                    ptr: *mut $ti,
                    count: usize
                ) -> ());
            }

            #[derive(Copy, Clone, Debug)]
            #[allow(non_camel_case_types)]
            pub struct $func;

            impl ActivationKer<$ti> for $func {
                #[inline(always)]
                fn name() -> &'static str {
                    stringify!($func)
                }
                #[inline(always)]
                fn nr() -> usize {
                    $nr
                }
                #[inline(always)]
                fn alignment_items() -> usize {
                    $alignment_items
                }
                #[inline(always)]
                fn alignment_bytes() -> usize {
                    $alignment_items * std::mem::size_of::<$ti>()
                }
                // Signature matches `ActivationKer::run`: slices of ops and of the
                // concrete element type (there is no generic `T` inside the macro).
                #[inline(never)]
                fn run(
                    ops: &[$crate::frame::activations::Op],
                    csts: &[$ti],
                    buf: &mut [$ti],
                ) {
                    // SAFETY: all three pointers come from live slices borrowed for
                    // the duration of the call, and `buf.len()` is the exact element
                    // count of `buf`. NOTE(review): this assumes the external kernel
                    // stays within `count` elements and within the `ops`/`csts`
                    // tables — confirm against the kernel's contract.
                    unsafe {
                        [<sys_ $func>]::$func(
                            ops.as_ptr(),
                            csts.as_ptr(),
                            buf.as_mut_ptr(),
                            buf.len(),
                        )
                    }
                }
            }
        }
    };
}

#[cfg(test)]
mod test {
    use super::Op;
    use std::mem::size_of;

    /// Pins the in-memory size of one instruction to 4 bytes.
    #[test]
    fn size_of_op() {
        assert_eq!(size_of::<Op>(), 4);
    }
}
175 changes: 175 additions & 0 deletions linalg/src/frame/activations/definitions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
use super::Op::*;
use super::RegisterId::*;
use super::*;

/// Builds the ReLU program: a single `MaxConst` against constant 0 (zero),
/// followed by `Done`.
pub fn relu<T: LADatum>() -> Program<T> {
    let csts = vec![T::zero()];
    let ops = vec![MaxConst(0), Done];
    Program { ops, csts }
}

/// Builds the affine program: multiply by constant 0 (`alpha`), then add
/// constant 1 (`beta`).
pub fn affine<T: LADatum>(alpha: T, beta: T) -> Program<T> {
    let csts = vec![alpha, beta];
    #[rustfmt::skip]
    let ops = vec![
        MulConst(0), // scale by alpha
        AddConst(1), // shift by beta
        Done,
    ];
    Program { ops, csts }
}

/// Builds the leaky-ReLU program with slope `alpha` (constant 0).
///
/// The input is kept in `b`, the scaled input in `c`, and `IfPosTE` makes the
/// final choice with the input back in `a` (presumably `b` when `a` is
/// positive, `c` otherwise).
pub fn leaky_relu<T: LADatum>(alpha: T) -> Program<T> {
    let csts = vec![alpha];
    #[rustfmt::skip]
    let ops = vec![
        Move(B, A),  // b = x
        MulConst(0), // a = x * alpha
        Move(C, A),  // c = x * alpha
        Move(A, B),  // a = x
        IfPosTE,
        Done,
    ];
    Program { ops, csts }
}

/// Builds the thresholded-ReLU program with threshold `alpha` (constant 1;
/// constant 0 is zero).
///
/// The input is kept in `b`, `a` becomes the input minus `alpha`, `c` is loaded
/// with zero, and `IfPosTE` makes the final choice (presumably `b` when `a` is
/// positive, `c` otherwise).
pub fn threshold_relu<T: LADatum>(alpha: T) -> Program<T> {
    let csts = vec![T::zero(), alpha];
    #[rustfmt::skip]
    let ops = vec![
        Move(B, A),  // b = x
        SubConst(1), // a = x - alpha
        Load(C, 0),  // c = 0
        IfPosTE,
        Done,
    ];
    Program { ops, csts }
}

/// Builds the softsign program; its only constant is one.
///
/// Read with the op meanings suggested by their names, the sequence computes
/// `x / (1 + |x|)`.
pub fn softsign<T: LADatum>() -> Program<T> {
    let csts = vec![T::one()];
    #[rustfmt::skip]
    let ops = vec![
        Move(B, A),  // b = x
        Abs,         // a = |x|
        AddConst(0), // a = |x| + 1
        Recip,       // a = 1 / (|x| + 1)
        Mul,         // a = a * b
        Done,
    ];
    Program { ops, csts }
}

/// Builds the hard-swish program.
///
/// Constants, by index: 0 = zero, 1 = one, 2 = 1/6, 3 = 1/2. Read with the op
/// meanings suggested by their names, the sequence computes
/// `x * max(0, min(1, x / 6 + 1 / 2))`.
pub fn hardswish<T: LADatum>() -> Program<T> {
    // Six and two are built by repeated addition of `T::one()`.
    let six = T::one() + T::one() + T::one() + T::one() + T::one() + T::one();
    let two = T::one() + T::one();
    let csts = vec![
        T::zero(),
        T::one(),
        T::one() / six, // 1/6
        T::one() / two, // 1/2
    ];
    #[rustfmt::skip]
    let ops = vec![
        Move(B, A),  // b = x
        MulConst(2), // a = x / 6
        AddConst(3), // a = x / 6 + 1 / 2
        MinConst(1), // upper bound: one
        MaxConst(0), // lower bound: zero
        Mul,         // a = a * b
        Done,
    ];
    Program { ops, csts }
}

/*
pub fn sigmoid() -> Program {
Program {
#[rustfmt::skip]
ops: vec![
MinConst(3),
MaxConst(2),
Move(B, A), // b = x
Move(C, A), // c = x
Mul, // a = x2
Move(B, A), // b = x2
MulConst(4),
AddConst(5), // a = x2 * a13 + a11
FMA(6),
FMA(7),
FMA(8),
FMA(9),
FMA(10),
SwapBC, // c = x2, b = x
Mul, // a = p(x)
Move(B, C), // b = x2
Move(C, A), // c = p(x)
Move(A, B), // a = x2
MulConst(11),
AddConst(12),
FMA(13),
FMA(1), // a = q(x)
Recip,
Move(B,C), // b = p(x)
Mul,
AddConst(14)
],
csts: vec![
-18.6, // const 2
18.6, // const 3
-4.433153405e-18, // const 4, also alpha_13
1.169974371e-14, // const 5, also a11
-1.875289645e-11,
4.257889523e-8,
0.00004811817576, // const 8
0.008163842030,
0.2499999971, // alpha_1
3.922935744e-6, // beta_6
0.001524872358, // const 12
0.1159886749,
0.5, //beta_0
],
}
}
pub fn exp2f() -> Program {
Program {
#[rustfmt::skip]
ops: vec![
MinConst(2),
MaxConst(3),
Move(B, A), // b = x
AddConst(4), // a = x + 0.5
Floor, // a = ipart
Move(C, A), // c = ipart
Move(A, B), // a = x
Move(B, C), // b = ipart
Sub, // a = fpart
Move(B, A), // b = fpart
Load(A, 5), // a = exp2p[0]
FMA(6),
FMA(7),
FMA(8),
FMA(9),
FMA(10),
FMA(1), // a = y
Move(B, A),
Move(A, C),
TwoPowOfInt,
Mul
],
csts: vec![
127f32,
-127f32,
0.5,
1.535336188319500e-4,
1.339887440266574e-3,
9.618437357674640e-3,
5.550332471162809e-2,
2.402264791363012e-1,
6.931472028550421e-1,
],
}
}
*/
Loading

0 comments on commit ad67de7

Please sign in to comment.