NTT’s 6th optimization #1275

Open

wants to merge 30 commits into base: dev/3.0

Changes from 24 commits

Commits (30)
6db55e9
add ranked expand func
Dec 3, 2024
f56e9c9
add ranked expand opt
Dec 3, 2024
361fa80
add 4d transpose benchmark cases
Dec 3, 2024
74bc603
add transpose ctest & ranked func
Dec 3, 2024
96fc976
revise bug for benchmark transpose
Dec 3, 2024
1ae8099
start ci
Dec 3, 2024
939dbe6
framework for transpose opt
Dec 3, 2024
345b773
tidy some code for common func
Dec 3, 2024
5f81188
Optimize rvv cast from bool to fp32 and update benchmark/roofline.
zhangyang2057 Dec 3, 2024
0f70676
Simplify rvv cast both float and bool and update roofline.
zhangyang2057 Dec 4, 2024
a9611fd
Fix and optimize rvv cast.
zhangyang2057 Dec 4, 2024
3bb9e3d
Update benchmark_ntt_transpose config for rvv.
zhangyang2057 Dec 5, 2024
a5407ae
add just for l1 limit
Dec 5, 2024
a3b853c
add extra warmup for test
Dec 5, 2024
0cdd3cc
change test orders
Dec 5, 2024
8f035d3
change more warm up
Dec 5, 2024
a98aa83
fallback some code
Dec 5, 2024
a390ad4
opt to memcpy again
Dec 5, 2024
b1fa9bc
test for extra warmup
Dec 5, 2024
894bcc9
Update benchmark_ntt_transpose and avx/rvv roofline.
zhangyang2057 Dec 5, 2024
809d07e
Merge branch 'dev/3.0' into feature/ntt_benchmark_roofline_6
zhangyang2057 Dec 5, 2024
882a9a3
Add 2D/3D/4D + unpacked/packed + fixed_shape/ranked_shape ctest for t…
zhangyang2057 Dec 5, 2024
7dda957
add some opt code for trans
Dec 9, 2024
d300eb6
fix build for macos
Dec 9, 2024
7227154
fallback for 1d transpose
Dec 9, 2024
6c943f1
opt some code to para pack
Dec 9, 2024
b916402
fix build error for clang
Dec 9, 2024
01d08e1
avoid inline for Debug build mode
Dec 9, 2024
6447520
Merge branch 'dev/3.0' into feature/ntt_benchmark_roofline_6
Dec 10, 2024
d77f5ae
Fix rvv transpose nchw took 0 cycle issue.
zhangyang2057 Dec 10, 2024
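
For context on the transpose commits above: the new benchmark and ctest cases exercise 2D/3D/4D permutations (e.g. NCHW layouts), packed and unpacked, with both fixed_shape and ranked_shape. A minimal scalar reference for one such permutation is sketched below; the function name, loop order, and the NCHW-to-NHWC choice are illustrative only, not the NTT kernel this PR optimizes.

#include <cstddef>

// Scalar reference for a 4D NCHW -> NHWC transpose (illustrative sketch).
// dst[n][h][w][c] = src[n][c][h][w]
void transpose_nchw_to_nhwc(const float *src, float *dst, size_t N, size_t C,
                            size_t H, size_t W) {
    for (size_t n = 0; n < N; n++)
        for (size_t c = 0; c < C; c++)
            for (size_t h = 0; h < H; h++)
                for (size_t w = 0; w < W; w++)
                    dst[((n * H + h) * W + w) * C + c] =
                        src[((n * C + c) * H + h) * W + w];
}

An optimized version blocks these loops, vectorizes the inner copies, and falls back to plain memcpy when the permutation leaves the innermost data contiguous — the kind of tuning the commits above iterate on.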
4 changes: 2 additions & 2 deletions ntt/cmake/compile_flags.cmake
@@ -14,13 +14,13 @@ if (MSVC)
string(REGEX REPLACE "/GR" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /GR-")
else()
add_compile_options(-Wno-multichar -Wno-unused-value -fno-common -ffunction-sections -fno-exceptions -fdata-sections -fno-unwind-tables -fno-asynchronous-unwind-tables -fno-stack-protector)
add_compile_options(-Wno-multichar -Wno-unused-value -fno-common -ffunction-sections -fno-exceptions -fdata-sections -fno-unwind-tables -fno-asynchronous-unwind-tables -fno-stack-protector -finline-functions)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti")

if (APPLE)
add_compile_options(-fno-stack-check -Wno-c++11-narrowing)
else()
add_compile_options(-Wnarrowing)
add_compile_options(-Wnarrowing -finline-limit=500)
endif()
endif()

76 changes: 46 additions & 30 deletions ntt/include/nncase/ntt/arch/riscv64/primitive_ops.h
@@ -1417,38 +1417,10 @@ REGISTER_RVV_CLAMP_OP(float, clamp_float32)
return __riscv_vfcvt_f_xu_v_f32m##lmul(v, vl); \
}

#define CAST_FLOAT32_BOOL(lmul1, lmul2, mlen) \
inline vuint8m##lmul2##_t cast_float32_bool(const vfloat32m##lmul1##_t &v, \
const size_t vl) { \
auto zero = __riscv_vmv_v_x_u8m##lmul2(0, vl); \
auto mask = __riscv_vmfne_vf_f32m##lmul1##_b##mlen(v, 0.f, vl); \
return __riscv_vmerge_vxm_u8m##lmul2(zero, 1, mask, vl); \
}

inline vuint8m1_t cast_float32_bool(const vfloat32m1_t &v0,
const vfloat32m1_t &v1,
const vfloat32m1_t &v2,
const vfloat32m1_t &v3, const size_t vl) {
auto v = __riscv_vcreate_v_f32m1_f32m4(v0, v1, v2, v3);
auto zero = __riscv_vmv_v_x_u8m1(0, vl);
auto mask = __riscv_vmfne_vf_f32m4_b8(v, 0.f, vl);
return __riscv_vmerge_vxm_u8m1(zero, 1, mask, vl);
}

#define CAST_BOOL_FLOAT32(lmul1, lmul2, mlen) \
inline vfloat32m##lmul2##_t cast_bool_float32(const vuint8m##lmul1##_t &v, \
const size_t vl) { \
auto zero = __riscv_vfmv_v_f_f32m##lmul2(0.f, vl); \
auto mask = __riscv_vmsne_vx_u8m##lmul1##_b##mlen(v, 0, vl); \
return __riscv_vfmerge_vfm_f32m##lmul2(zero, 1.f, mask, vl); \
}

REGISTER_RVV_KERNEL(CAST_FLOAT32_INT32)
REGISTER_RVV_KERNEL(CAST_INT32_FLOAT32)
REGISTER_RVV_KERNEL(CAST_FLOAT32_UINT32)
REGISTER_RVV_KERNEL(CAST_UINT32_FLOAT32)
REGISTER_RVV_KERNEL_4_1(CAST_FLOAT32_BOOL)
REGISTER_RVV_KERNEL_1_4(CAST_BOOL_FLOAT32)

// register cast op
#define RVV_CAST_OP_1_1(from_dtype, to_dtype, vl, kernel) \
@@ -1492,7 +1464,51 @@ REGISTER_RVV_CAST_OP(float, int, cast_float32_int32)
REGISTER_RVV_CAST_OP(int, float, cast_int32_float32)
REGISTER_RVV_CAST_OP(float, unsigned int, cast_float32_uint32)
REGISTER_RVV_CAST_OP(unsigned int, float, cast_uint32_float32)
REGISTER_RVV_CAST_OP_4_1(float, bool, cast_float32_bool)
REGISTER_RVV_CAST_OP_1_4(bool, float, cast_bool_float32)

// cast float to bool
template <>
struct cast<ntt::vector<float, NTT_VL(sizeof(float) * 8, *, 1)>,
ntt::vector<bool, NTT_VL(sizeof(bool) * 8, *, 1)>> {
auto
operator()(const ntt::vector<float, NTT_VL(sizeof(float) * 8, *, 1)> &v0,
const ntt::vector<float, NTT_VL(sizeof(float) * 8, *, 1)> &v1,
const ntt::vector<float, NTT_VL(sizeof(float) * 8, *, 1)> &v2,
const ntt::vector<float, NTT_VL(sizeof(float) * 8, *, 1)> &v3)
const noexcept {
constexpr auto vl = NTT_VL(sizeof(bool) * 8, *, 1);
#if 0
auto src = __riscv_vcreate_v_f32m1_f32m4(v0, v1, v2, v3);
#else
auto src = __riscv_vundefined_f32m4();
src = __riscv_vset_v_f32m1_f32m4(src, 0, v0);
src = __riscv_vset_v_f32m1_f32m4(src, 1, v1);
src = __riscv_vset_v_f32m1_f32m4(src, 2, v2);
src = __riscv_vset_v_f32m1_f32m4(src, 3, v3);
#endif
auto zero = __riscv_vmv_v_x_u8m1(0, vl);
auto mask = __riscv_vmfne_vf_f32m4_b8(src, 0.f, vl);
return __riscv_vmerge_vxm_u8m1(zero, 1, mask, vl);
}
};

// cast bool to float
template <>
struct cast<ntt::vector<bool, NTT_VL(sizeof(bool) * 8, *, 1)>,
ntt::vector<float, NTT_VL(sizeof(float) * 8, *, 1)>> {
auto operator()(const ntt::vector<bool, NTT_VL(sizeof(bool) * 8, *, 1)> &v)
const noexcept {
constexpr auto vl = NTT_VL(sizeof(float) * 8, *, 1);
auto mask = __riscv_vreinterpret_v_u8m1_b8(v);
ntt::vector<float, 4, vl> output;
auto zero = __riscv_vfmv_v_f_f32m4(0.f, vl);
auto dst = __riscv_vfmerge_vfm_f32m4(zero, 1.f, mask, vl);
output(0) = __riscv_vget_v_f32m4_f32m1(dst, 0);
output(1) = __riscv_vget_v_f32m4_f32m1(dst, 1);
output(2) = __riscv_vget_v_f32m4_f32m1(dst, 2);
output(3) = __riscv_vget_v_f32m4_f32m1(dst, 3);
return output;
}
};

#endif
} // namespace nncase::ntt::ops
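
The two specializations added above replace the macro-generated CAST_FLOAT32_BOOL / CAST_BOOL_FLOAT32 kernels removed earlier in this hunk: four f32m1 inputs are grouped into one f32m4 register, compared against zero with a single vmfne, and the resulting mask is merged into 0/1 bytes of a u8m1 vector (the bool-to-float direction reinterprets the byte vector as a b8 mask and expands it back into four f32m1 parts with vfmerge). Below is a minimal standalone sketch of the same float-to-bool pattern on raw buffers — outside the ntt::vector wrappers, with an illustrative function name — assuming the standard RVV intrinsics:

#include <riscv_vector.h>
#include <stddef.h>
#include <stdint.h>

// Cast a float buffer to 0/1 bytes, one e32m4 compare per iteration.
// e32m4 and e8m1 hold the same number of elements (VLEN / 8), which is why
// four f32m1 inputs can be paired with a single u8m1 output.
void cast_f32_to_bool(const float *src, uint8_t *dst, size_t n) {
    while (n > 0) {
        size_t vl = __riscv_vsetvl_e32m4(n);
        vfloat32m4_t v = __riscv_vle32_v_f32m4(src, vl);
        vbool8_t mask = __riscv_vmfne_vf_f32m4_b8(v, 0.f, vl); // v != 0.0f
        vuint8m1_t zero = __riscv_vmv_v_x_u8m1(0, vl);
        vuint8m1_t out = __riscv_vmerge_vxm_u8m1(zero, 1, mask, vl);
        __riscv_vse8_v_u8m1(dst, out, vl);
        src += vl;
        dst += vl;
        n -= vl;
    }
}

The #if 0 branch in the new specialization shows the equivalent __riscv_vcreate_v_f32m1_f32m4 form; the vundefined/vset sequence used instead builds the same m4 group piecewise.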
2 changes: 1 addition & 1 deletion ntt/include/nncase/ntt/arch/riscv64/ukernels.h
@@ -85,7 +85,7 @@ template <reduce_op Op, class T> struct u_reduce_policy<Op, T, true> {
};

// cast
template <> struct u_cast_policy<true> { static constexpr size_t unroll = 8; };
template <> struct u_cast_policy<true> { static constexpr size_t unroll = 4; };

// matmul
template <>
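
The one-line change above drops the vectorized cast ukernel's unroll factor from 8 to 4. The real u_cast ukernel is not part of this diff; purely as a hypothetical illustration of what such a policy constant controls (every name below is made up for the sketch):

#include <cstddef>

// Hypothetical sketch: a compile-time unroll factor driving a cast loop.
// `Unroll` plays the role that u_cast_policy<true>::unroll plays in the
// real ukernel; `cast_op` stands in for the per-element/per-vector cast.
template <size_t Unroll, class From, class To, class CastOp>
void unrolled_cast(const From *src, To *dst, size_t count, CastOp cast_op) {
    size_t i = 0;
    // Main loop: Unroll independent casts per iteration so several
    // operations can be in flight at once.
    for (; i + Unroll <= count; i += Unroll)
        for (size_t u = 0; u < Unroll; u++)
            dst[i + u] = cast_op(src[i + u]);
    // Tail loop for the remainder.
    for (; i < count; i++)
        dst[i] = cast_op(src[i]);
}

A plausible reading — not stated in the PR — is that the new m4-based cast kernels already consume four vector registers per call, so an unroll of 4 keeps register pressure in check; the diff itself only changes the constant.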