diff --git a/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl b/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl index fcd04311e6..d82df4bee3 100644 --- a/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl +++ b/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl @@ -339,8 +339,41 @@ fmax v7.4s, v7.4s, v24.4s b .inner_loop -.fma: - b .unsupported +.fma: + // a <- a * b + k + // vfma a,b,c does a <- a + b * c + // mov d,a ; mov a,#k ; vfma a, b, d + + and v24.16b, v0.16b, v0.16b + and v25.16b, v1.16b, v1.16b + and v26.16b, v2.16b, v2.16b + and v27.16b, v3.16b, v3.16b + and v28.16b, v4.16b, v4.16b + and v29.16b, v5.16b, v5.16b + and v30.16b, v6.16b, v6.16b + and v31.16b, v7.16b, v7.16b + + ins v0.s[0], w3 + add x5, x5, 4 + dup v0.4s, v0.s[0] + dup v1.4s, v0.s[0] + dup v2.4s, v0.s[0] + dup v3.4s, v0.s[0] + dup v4.4s, v0.s[0] + dup v5.4s, v0.s[0] + dup v6.4s, v0.s[0] + dup v7.4s, v0.s[0] + + fmla v0.4s, v24.4s, v8.4s + fmla v1.4s, v25.4s, v9.4s + fmla v2.4s, v26.4s, v10.4s + fmla v3.4s, v27.4s, v11.4s + fmla v4.4s, v28.4s, v12.4s + fmla v5.4s, v29.4s, v13.4s + fmla v6.4s, v30.4s, v14.4s + fmla v7.4s, v31.4s, v15.4s + + b .inner_loop .if_pos_then_else: fcmge v0.4s, v0.4s, #0.0 @@ -362,7 +395,34 @@ b .inner_loop .swap_b_c: - b .unsupported +// move d <- b + and v24.16b, v8.16b , v8.16b + and v25.16b, v9.16b , v9.16b + and v26.16b, v10.16b, v10.16b + and v27.16b, v11.16b, v11.16b + and v28.16b, v12.16b, v12.16b + and v29.16b, v13.16b, v13.16b + and v30.16b, v14.16b, v14.16b + and v31.16b, v15.16b, v15.16b +// move b <- c + and v8.16b , v16.16b, v16.16b + and v9.16b , v17.16b, v17.16b + and v10.16b, v18.16b, v18.16b + and v11.16b, v19.16b, v19.16b + and v12.16b, v20.16b, v20.16b + and v13.16b, v21.16b, v21.16b + and v14.16b, v22.16b, v22.16b + and v15.16b, v23.16b, v23.16b +// move c <- d + and v16.16b, v24.16b, v24.16b + and v17.16b, v25.16b, v25.16b + and v18.16b, v26.16b, v26.16b + and v19.16b, v27.16b, v27.16b + and v20.16b, v28.16b, v28.16b + and v21.16b, v29.16b, v29.16b + and v22.16b, v30.16b, v30.16b + and v23.16b, v31.16b, v31.16b + b .inner_loop .floor: b .unsupported diff --git a/linalg/benches/activations.rs b/linalg/benches/activations.rs index e76fb2672a..9869aea59b 100644 --- a/linalg/benches/activations.rs +++ b/linalg/benches/activations.rs @@ -1,9 +1,11 @@ use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; use tract_linalg::frame::activations::{definitions, reference, ActivationKer, Program}; +const SIZES:&[i32] = &[32, 256, 1024, 8192]; + fn crit(c: &mut Criterion, name: &str, r: impl Fn(f32) -> f32, prog: &Program) { let mut group = c.benchmark_group(name); - for size in [1i32, 32, 256, 1024, 8192].iter() { + for size in SIZES { group.throughput(criterion::Throughput::Elements(*size as u64)); group.bench_with_input(BenchmarkId::new("Reference", size), size, |b, size| { b.iter_batched( @@ -14,7 +16,7 @@ fn crit(c: &mut Criterion, name: &str, r: impl Fn(f32) -> f32, prog: &Program f32, prog: &Program f32, prog: &Program() -> Program { pub fn affine(alpha: T, beta: T) -> Program { Program { #[rustfmt::skip] - ops: vec![ - MulConst(alpha), - AddConst(beta), - ], + ops: vec![ + MulConst(alpha), + AddConst(beta), + ], } } pub fn leaky_relu(alpha: T) -> Program { Program { #[rustfmt::skip] - ops: vec![ - Move(B,A), - MulConst(alpha), - Move(C,A), - Move(A,B), - IfPosTE, - ], + ops: vec![ + Move(B,A), + MulConst(alpha), + Move(C,A), + Move(A,B), + IfPosTE, + ], } } pub fn threshold_relu(alpha: T) -> Program { Program { #[rustfmt::skip] - ops: vec![ - Move(B,A), - SubConst(alpha), - Load(C, T::zero()), - IfPosTE, - ], + ops: vec![ + Move(B,A), + SubConst(alpha), + Load(C, T::zero()), + IfPosTE, + ], } } pub fn hard_sigmoid(alpha: T, beta: T) -> Program { Program { #[rustfmt::skip] - ops: vec![ - MulConst(alpha), - AddConst(beta), - MinConst(T::one()), - MaxConst(T::zero()), - ], + ops: vec![ + MulConst(alpha), + AddConst(beta), + MinConst(T::one()), + MaxConst(T::zero()), + ], } } pub fn softsign() -> Program { Program { #[rustfmt::skip] - ops: vec![ - Move(B,A), - Abs, - AddConst(T::one()), - Recip, - Mul, - ], + ops: vec![ + Move(B,A), + Abs, + AddConst(T::one()), + Recip, + Mul, + ], } } @@ -71,104 +71,88 @@ pub fn hard_swish() -> Program { let one_half = T::one() / (T::one() + T::one()); Program { #[rustfmt::skip] - ops: vec![ - Move(B, A), - MulConst(one_sixth), - AddConst(one_half), - MinConst(T::one()), - MaxConst(T::zero()), - Mul, - ], + ops: vec![ + Move(B, A), + MulConst(one_sixth), + AddConst(one_half), + MinConst(T::one()), + MaxConst(T::zero()), + Mul, + ], } } -/* -pub fn sigmoid() -> Program { +pub fn sigmoid() -> Program { Program { - #[rustfmt::skip] - ops: vec![ - MinConst(3), - MaxConst(2), - Move(B, A), // b = x - Move(C, A), // c = x - Mul, // a = x2 - Move(B, A), // b = x2 - MulConst(4), - AddConst(5), // a = x2 * a13 + a11 - FMA(6), - FMA(7), - FMA(8), - FMA(9), - FMA(10), - SwapBC, // c = x2, b = x - Mul, // a = p(x) - Move(B, C), // b = x2 - Move(C, A), // c = p(x) - Move(A, B), // a = x2 - MulConst(11), - AddConst(12), - FMA(13), - FMA(1), // a = q(x) - Recip, - Move(B,C), // b = p(x) - Mul, - AddConst(14) - ], - csts: vec![ - -18.6, // const 2 - 18.6, // const 3 - -4.433153405e-18, // const 4, also alpha_13 - 1.169974371e-14, // const 5, also a11 - -1.875289645e-11, - 4.257889523e-8, - 0.00004811817576, // const 8 - 0.008163842030, - 0.2499999971, // alpha_1 - 3.922935744e-6, // beta_6 - 0.001524872358, // const 12 - 0.1159886749, - 0.5, //beta_0 + ops: vec![ + MaxConst(-18.6), // const 2 + MinConst(18.6), // const 3 + Move(B, A), // b = x + Move(C, A), // c = x + Mul, // a = x2 + Move(B, A), // b = x2 + MulConst(-4.433153405e-18), // const 4, also alpha_13 + AddConst(1.169974371e-14), // const 5, also a11 + FMA(-1.875289645e-11), + FMA(4.257889523e-8), + FMA(0.00004811817576), // const 8 + FMA(0.008163842030), + FMA(0.2499999971), // alpha_1 + SwapBC, // c = x2, b = x + Mul, // a = p(x) + Move(B, C), // b = x2 + Move(C, A), // c = p(x) + Move(A, B), // a = x2 + MulConst(3.922935744e-6), // beta_6 + AddConst(0.001524872358), // const 12 + FMA(0.1159886749), + FMA(1.0), // a = q(x) + Recip, + Move(B, C), // b = p(x) + Mul, + AddConst(0.5), //beta_0 ], } } +/* pub fn exp2f() -> Program { - Program { - #[rustfmt::skip] - ops: vec![ - MinConst(2), - MaxConst(3), - Move(B, A), // b = x - AddConst(4), // a = x + 0.5 - Floor, // a = ipart - Move(C, A), // c = ipart - Move(A, B), // a = x - Move(B, C), // b = ipart - Sub, // a = fpart - Move(B, A), // b = fpart - Load(A, 5), // a = exp2p[0] - FMA(6), - FMA(7), - FMA(8), - FMA(9), - FMA(10), - FMA(1), // a = y - Move(B, A), - Move(A, C), - TwoPowOfInt, - Mul - ], - csts: vec![ - 127f32, - -127f32, - 0.5, - 1.535336188319500e-4, - 1.339887440266574e-3, - 9.618437357674640e-3, - 5.550332471162809e-2, - 2.402264791363012e-1, - 6.931472028550421e-1, - ], - } +Program { +#[rustfmt::skip] +ops: vec![ +MinConst(2), +MaxConst(3), +Move(B, A), // b = x +AddConst(4), // a = x + 0.5 +Floor, // a = ipart +Move(C, A), // c = ipart +Move(A, B), // a = x +Move(B, C), // b = ipart +Sub, // a = fpart +Move(B, A), // b = fpart +Load(A, 5), // a = exp2p[0] +FMA(6), +FMA(7), +FMA(8), +FMA(9), +FMA(10), +FMA(1), // a = y +Move(B, A), +Move(A, C), +TwoPowOfInt, +Mul +], +csts: vec![ +127f32, +-127f32, +0.5, +1.535336188319500e-4, +1.339887440266574e-3, +9.618437357674640e-3, +5.550332471162809e-2, +2.402264791363012e-1, +6.931472028550421e-1, +], +} } */ diff --git a/linalg/src/frame/activations/tests.rs b/linalg/src/frame/activations/tests.rs index f4331d6e55..2987a0f774 100644 --- a/linalg/src/frame/activations/tests.rs +++ b/linalg/src/frame/activations/tests.rs @@ -189,10 +189,28 @@ macro_rules! act_tests { run_kernel_test::<$ti, $ker>(&x, &[Load(RegisterId::B, 2 as _), Load(RegisterId::C, 3 as _), IfPosTE], |x| if x >= <$ti>::zero() { 2 as _ } else { 3 as _ }); - } } } + #[test] + fn swapbc_prop(x in x_strat()) { + if $cond { + run_kernel_test::<$ti, $ker>(&x, + &[Load(RegisterId::B, 2 as _), Load(RegisterId::C, 3 as _), SwapBC, IfPosTE], + |x| if x >= <$ti>::zero() { 3 as _ } else { 2 as _ }); + } + } + + #[test] + fn fma_prop(x in x_strat(), b in any::<$ti>(), k in any::<$ti>()) { + if $cond { + run_kernel_test::<$ti, $ker>(&x, + &[Load(RegisterId::B, b), FMA(k)], + |x| x * b + k); + } + } + } + #[test] fn max_const_zero() { if $cond { @@ -272,7 +290,7 @@ macro_rules! act_tests { } #[test] - fn hard_sigmoid(x in x_strat(), alpha in any::<$ti>(), beta in any::<$ti>()) { + fn hard_sigmoid_prop(x in x_strat(), alpha in any::<$ti>(), beta in any::<$ti>()) { if $cond { run_kernel_test::<$ti, $ker>( &x, @@ -283,7 +301,7 @@ macro_rules! act_tests { } #[test] - fn softsign(x in x_strat()) { + fn softsign_prop(x in x_strat()) { if $cond { run_kernel_test::<$ti, $ker>( &x, @@ -294,7 +312,7 @@ macro_rules! act_tests { } #[test] - fn hard_swish(x in x_strat()) { + fn hard_swish_prop(x in x_strat()) { if $cond { run_kernel_test::<$ti, $ker>( &x, @@ -303,6 +321,17 @@ macro_rules! act_tests { ); } } + + #[test] + fn sigmoid_prop(x in x_strat()) { + if $cond { + run_kernel_test::<$ti, $ker>( + &x, + &$crate::frame::activations::definitions::sigmoid().ops, + crate::generic::sigmoid::ssigmoid + ); + } + } } } };