diff --git a/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl b/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl new file mode 100644 index 0000000000..2462e373b7 --- /dev/null +++ b/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl @@ -0,0 +1,29 @@ +// vim: ft=arm + +// C tile regs: v16 to v31, (scratch) +// - x19-x29 to preserve (but x19, x28, x29 not used) +// - d8..d15 to preserve +// - v16 to v31, no need to preserve + +.text +.align 4 + +.cpu generic+fp+simd +.global {{G}}arm64simd_act_f32_32n_{{suffix}} +{{G}}arm64simd_act_f32_32n_{{suffix}}: + + stp d8, d9, [sp, #-16]! + stp d10, d11, [sp, #-16]! + stp d12, d13, [sp, #-16]! + stp d14, d15, [sp, #-16]! + + mov x0, 0 +// b .return + +.return: + ldp d14, d15, [sp], #16 + ldp d12, d13, [sp], #16 + ldp d10, d11, [sp], #16 + ldp d8, d9, [sp], #16 + + ret diff --git a/linalg/src/arm64/arm64simd.rs b/linalg/src/arm64/arm64simd.rs index ab03e9654b..8ec9d8618e 100644 --- a/linalg/src/arm64/arm64simd.rs +++ b/linalg/src/arm64/arm64simd.rs @@ -44,3 +44,6 @@ sigmoid_impl!(f32, arm64simd_sigmoid_f32_4n, 4, 4, true); tanh_impl!(f16, arm64fp16_tanh_f16_8n, 8, 8, crate::arm64::has_fp16()); #[cfg(not(feature="no_fp16"))] sigmoid_impl!(f16, arm64fp16_sigmoid_f16_8n, 8, 8, crate::arm64::has_fp16()); + +act_impl!(f32, arm64simd_act_f32_32n, 32, 4, true); + diff --git a/linalg/src/frame/activations.rs b/linalg/src/frame/activations.rs index 8b604b4998..5586c16973 100644 --- a/linalg/src/frame/activations.rs +++ b/linalg/src/frame/activations.rs @@ -96,19 +96,20 @@ where } macro_rules! act_impl { - ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr) => { + ($ti: ident, $func: ident, $nr: expr, $alignment_items: expr, $cond: expr) => { paste! { mod [] { #[allow(unused_imports)] use tract_data::prelude::f16; - extern_kernel!(fn $func(ptr: *mut $ti, count: usize) -> ()); + use crate::frame::activations::Op; + extern_kernel!(fn $func(ops: *const Op, constants: *const $ti, xs: *mut $ti, len: usize) -> usize); } #[derive(Copy, Clone, Debug)] #[allow(non_camel_case_types)] pub struct $func; - impl ActivationKer<$ti> for $func { + impl $crate::frame::activations::ActivationKer<$ti> for $func { #[inline(always)] fn name() -> &'static str { stringify!($func) @@ -126,10 +127,14 @@ macro_rules! act_impl { $alignment_items * std::mem::size_of::<$ti>() } #[inline(never)] - fn run(ops: &Op, csts:&[T], buf: &mut [$ti]) { - unsafe { []::$func(ops.as_ptr(), csts.as_ptr(), buf.as_mut_ptr(), buf.len()) } + fn run(ops: &[$crate::frame::activations::Op], csts:&[$ti], buf: &mut [$ti]) { + let err = unsafe { []::$func(ops.as_ptr(), csts.as_ptr(), buf.as_mut_ptr(), buf.len()) }; + assert_eq!(err, 0); } } + + #[cfg(test)] + act_tests!($cond, $func, $ti); } }; } diff --git a/linalg/src/frame/activations/tests.rs b/linalg/src/frame/activations/tests.rs index 197608ff2e..fb916e856c 100644 --- a/linalg/src/frame/activations/tests.rs +++ b/linalg/src/frame/activations/tests.rs @@ -1,8 +1,22 @@ -macro_rules! prop_activation { +use crate::LADatum; + +use super::{Program, Op}; +use Op::*; + +pub fn noop() -> Program { + Program { ops: vec![Done], csts: vec![] } +} + +macro_rules! prop_act_e2e { ($cond:expr, $ti: ty, $ker: ty, $name: ident ( $($param:ident),* )) => { proptest::proptest! { #[test] - fn $name(x in proptest::prelude::any::<$ti>(), repeat in 1usize..4, $($param in proptest::prelude::any::<$ti>()),*) { + fn $name( + x in proptest::prelude::any::<$ti>(), + repeat in 1usize..4, + $($param in proptest::prelude::any::<$ti>()),*) + { + use crate::frame::activations::ActivationKer; if $cond { let mut input = tract_data::prelude::Tensor::zero_aligned::<$ti>(&[<$ker>::nr() * repeat], <$ker>::alignment_bytes()).unwrap(); input.fill_t::<$ti>(x).unwrap(); @@ -18,18 +32,47 @@ macro_rules! prop_activation { } } +macro_rules! prop_act_unit { + ($cond:expr, $ti: ty, $ker: ty, $name: ident ( $($param:ident),* ), $refer: expr) => { + proptest::proptest! { + #[test] + fn $name( + x in proptest::prelude::any::<$ti>(), + repeat in 1usize..4, + $($param in proptest::prelude::any::<$ti>()),*) + { + use crate::frame::activations::ActivationKer; + if $cond { + let mut input = tract_data::prelude::Tensor::zero_aligned::<$ti>(&[<$ker>::nr() * repeat], <$ker>::alignment_bytes()).unwrap(); + input.fill_t::<$ti>(x).unwrap(); + let refer2: fn($ti) -> $ti = $refer; + let expected:Vec<$ti> = input.as_slice::<$ti>().unwrap().iter().cloned().map(refer2).collect(); + let prog = crate::frame::activations::tests::$name($($param),*); + <$ker>::run(&prog.ops, &prog.csts, &mut input.as_slice_mut::<$ti>().unwrap()); + + let expected = tract_data::prelude::tensor1(&expected); + expected.close_enough(&input, true).unwrap(); + } + } + } + } +} + #[macro_export] -macro_rules! act_frame_tests { +macro_rules! act_tests { ($cond:expr, $ker:ty, $ti:ty) => { - prop_activation!($cond, $ti, $ker, relu()); - prop_activation!($cond, $ti, $ker, affine(alpha, beta)); - prop_activation!($cond, $ti, $ker, leaky_relu(alpha)); - prop_activation!($cond, $ti, $ker, threshold_relu(alpha)); - prop_activation!($cond, $ti, $ker, softsign()); - prop_activation!($cond, $ti, $ker, hardswish()); + prop_act_unit!($cond, $ti, $ker, noop(), |x| x); + + prop_act_e2e!($cond, $ti, $ker, relu()); + prop_act_e2e!($cond, $ti, $ker, affine(alpha, beta)); + prop_act_e2e!($cond, $ti, $ker, leaky_relu(alpha)); + prop_act_e2e!($cond, $ti, $ker, threshold_relu(alpha)); + prop_act_e2e!($cond, $ti, $ker, softsign()); + prop_act_e2e!($cond, $ti, $ker, hardswish()); /* prop_activation!($cond, $ti, $ker, sigmoid()); prop_activation!($cond, $ti, $ker, exp2f()); */ }; } + diff --git a/linalg/src/generic/activations.rs b/linalg/src/generic/activations.rs index c76eace069..48fccc8cf3 100644 --- a/linalg/src/generic/activations.rs +++ b/linalg/src/generic/activations.rs @@ -94,5 +94,20 @@ impl ActivationKer for SActivations { } #[cfg(test)] -act_frame_tests!(true, SActivations, f32); +act_tests!(true, SActivations, f32); +#[cfg(test)] +mod tests { + use crate::frame::activations::Op; + use crate::frame::activations::ActivationKer; + + use super::SActivations; + + #[test] + fn act_noop() { + let mut xs = vec!(1f32; SActivations::nr()); + let expect = xs.clone(); + SActivations::run(&[Op::Done], &[], &mut *xs); + assert_eq!(expect, xs); + } +}