Skip to content

Commit

Permalink
gpu: sycl: binary: add support for remaining post ops
Browse files Browse the repository at this point in the history
  • Loading branch information
t4c1 committed May 9, 2024
1 parent 7bdb0e1 commit 170bfaf
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 17 deletions.
111 changes: 107 additions & 4 deletions src/gpu/sycl/binary_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,24 @@ struct binary_kernel_vec_t {
xpu::sycl::in_memory_arg_t &src0, xpu::sycl::in_memory_arg_t &src1,
xpu::sycl::out_memory_arg_t &dst,
xpu::sycl::in_memory_arg_t &src0_scale,
xpu::sycl::in_memory_arg_t &src1_scale, data_type_t scales_dt)
xpu::sycl::in_memory_arg_t &src1_scale, data_type_t scales_dt,
xpu::sycl::in_memory_arg_t &po1_src,
xpu::sycl::in_memory_arg_t &po2_src,
xpu::sycl::in_memory_arg_t &po3_src,
xpu::sycl::in_memory_arg_t &po4_src,
xpu::sycl::in_memory_arg_t &po5_src)
: conf_(conf)
, src0_(src0)
, src1_(src1)
, dst_(dst)
, src0_scale_(src0_scale)
, src1_scale_(src1_scale)
, scales_dt_(scales_dt) {}
, scales_dt_(scales_dt)
, po1_src_(po1_src)
, po2_src_(po2_src)
, po3_src_(po3_src)
, po4_src_(po4_src)
, po5_src_(po5_src) {}

void operator()(::sycl::nd_item<1> item) const {
auto sg = item.get_sub_group();
Expand Down Expand Up @@ -73,7 +83,7 @@ struct binary_kernel_vec_t {
any_broadcast |= conf_.broadcast_dims[i];
}
}
if (!any_broadcast
if (!any_broadcast && conf_.post_ops.get_post_op() == 0
&& sg_base_idx + (sg.get_local_range()[0] * conf_.block_size)
< conf_.wk_size) {
for (int i = 0; i < conf_.block_size / vec_len; i++) {
Expand Down Expand Up @@ -123,7 +133,8 @@ struct binary_kernel_vec_t {
if (conf_.do_scale_src1) src1 *= sm_1;

auto acc = compute_alg_n(src0, src1, conf_.alg_kind);
acc = conf_.post_ops.apply(acc, dst);
::sycl::vec<float, 16> post_po_sr = post_op_src_val(idx);
acc = conf_.post_ops.apply(acc, dst, post_po_sr);
store_float_value(
dst_md().data_type(), acc, dst_ptr(), idx);
}
Expand All @@ -146,6 +157,93 @@ struct binary_kernel_vec_t {
return static_cast<float *>(src1_scale_.get_pointer());
}

inline ::sycl::vec<float, 16> post_op_src_val(dim_t data_l_off) const {
::sycl::vec<float, 16> post_po_sr;
const auto maxPostPo = conf_.post_ops.get_post_op();

for (dim_t po_idx = 0; po_idx < maxPostPo; po_idx++) {
float res = 0.0f;
if (po_idx == 0)
res = get_post_op_val(po1_src_, po_idx, data_l_off);
else if (po_idx == 1)
res = get_post_op_val(po2_src_, po_idx, data_l_off);
else if (po_idx == 2)
res = get_post_op_val(po3_src_, po_idx, data_l_off);
else if (po_idx == 3)
res = get_post_op_val(po4_src_, po_idx, data_l_off);
else if (po_idx == 4)
res = get_post_op_val(po5_src_, po_idx, data_l_off);

post_po_sr[po_idx] = res;
}
return post_po_sr;
}

float get_post_op_val(const xpu::sycl::in_memory_arg_t &bin_src_op,
dim_t &idx, dim_t offset) const {
auto src1_desc = conf_.binary_src_arr[idx];

const auto off = get_binary_src1_off(
src1_desc, offset, dst_md().dims(), dst_md().ndims());

auto dst = load_float_value(
src1_desc.data_type(), bin_src_op.get_pointer(), off);
return dst;
}

dim_t get_binary_src1_off(const xpu::sycl::md_t &src1_md, dim_t l_offset,
const xpu::sycl::md_t::dims32_t &dst_dims,
const xpu::sycl::md_t::dim32_t &dst_ndims) const {
const dim_t mask_binary_po
= get_dims_mask(dst_dims, src1_md.dims(), dst_ndims);
return get_po_tensor_off(
src1_md, l_offset, dst_dims, dst_ndims, mask_binary_po);
}

inline dim_t get_dims_mask(const xpu::sycl::md_t::dims32_t &dims1,
const xpu::sycl::md_t::dims32_t &dims2, const dim_t &ndims,
bool skip_dim_of_one = false) const {
dim_t mask = 0;
for (dim_t d = 0; d < ndims; ++d) {
// Disable mask_bit for dimensions of `1` by request.
dim_t mask_bit = skip_dim_of_one && dims1[d] == 1 ? 0 : (1 << d);
mask += dims1[d] == dims2[d] ? mask_bit : 0;
}
return mask;
}

inline dim_t get_po_tensor_off(const xpu::sycl::md_t &tensor_md,
dim_t l_offset, const xpu::sycl::md_t::dims32_t &dst_dims,
const dim_t &dst_ndims, const dim_t &mask) const {
dims_t l_dims_po {};
get_l_dims_po(l_dims_po, l_offset, dst_dims, dst_ndims, mask);

return tensor_md.off_v(l_dims_po);
}

inline void get_l_dims_po(dims_t l_dims_po, dim_t l_offset,
const xpu::sycl::md_t::dims32_t &dst_dims, const dim_t &dst_ndims,
const dim_t &mask) const {

l_dims_by_l_offset(l_dims_po, l_offset, dst_dims, dst_ndims);
utils::apply_mask_on_dims(l_dims_po, dst_ndims, mask);
}

inline void l_dims_by_l_offset(dims_t dims_pos, dim_t l_offset,
const xpu::sycl::md_t::dims32_t &dims, const dim_t &ndims) const {
for (dim_t rd = 0; rd < ndims; ++rd) {
const dim_t d = ndims - 1 - rd;
/* switch to faster 32-bit division when possible. */
if (l_offset <= INT32_MAX && dims[d] <= INT32_MAX) {
dims_pos[d] = (int32_t)l_offset % (int32_t)dims[d];
l_offset = (int32_t)l_offset / (int32_t)dims[d];
} else {
dims_pos[d] = l_offset % dims[d];
l_offset /= dims[d];
}
}
}

template <int width>
::sycl::vec<float, width> compute_alg(::sycl::vec<float, width> src0,
::sycl::vec<float, width> src1, alg_kind_t alg) const {
Expand Down Expand Up @@ -199,6 +297,11 @@ struct binary_kernel_vec_t {
xpu::sycl::in_memory_arg_t src0_scale_;
xpu::sycl::in_memory_arg_t src1_scale_;
data_type_t scales_dt_;
xpu::sycl::in_memory_arg_t po1_src_;
xpu::sycl::in_memory_arg_t po2_src_;
xpu::sycl::in_memory_arg_t po3_src_;
xpu::sycl::in_memory_arg_t po4_src_;
xpu::sycl::in_memory_arg_t po5_src_;
};

} // namespace sycl
Expand Down
22 changes: 21 additions & 1 deletion src/gpu/sycl/ref_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ status_t ref_binary_t::pd_t::init_conf() {

conf_.post_ops = sycl_post_ops_t(attr());

for (auto i = 0; i < conf_.post_ops.get_post_op(); ++i) {
const auto &e = attr()->post_ops_.entry_[i];
if (e.is_binary() || e.is_prelu()) {
conf_.binary_src_arr[i] = xpu::sycl::md_t(
arg_md(DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1));
}
}
return status::success;
}

Expand All @@ -62,6 +69,7 @@ status_t ref_binary_t::init(engine_t *engine) {
}

status_t ref_binary_t::execute(const exec_ctx_t &ctx) const {

parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
auto src0_mem_arg = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0);
auto src1_mem_arg = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_1);
Expand All @@ -76,9 +84,21 @@ status_t ref_binary_t::execute(const exec_ctx_t &ctx) const {
.data_type()
: data_type_t::dnnl_f32;

auto src_mem_po_1 = CTX_IN_SYCL_KERNEL_MEMORY(
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1));
auto src_mem_po_2 = CTX_IN_SYCL_KERNEL_MEMORY(
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1));
auto src_mem_po_3 = CTX_IN_SYCL_KERNEL_MEMORY(
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1));
auto src_mem_po_4 = CTX_IN_SYCL_KERNEL_MEMORY(
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1));
auto src_mem_po_5 = CTX_IN_SYCL_KERNEL_MEMORY(
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(4) | DNNL_ARG_SRC_1));

binary_kernel_vec_t binary_kernel(pd()->conf_, src0_mem_arg,
src1_mem_arg, dst_mem_arg, src0_scale_mem_arg,
src1_scale_mem_arg, scales_dt);
src1_scale_mem_arg, scales_dt, src_mem_po_1, src_mem_po_2,
src_mem_po_3, src_mem_po_4, src_mem_po_5);

const int block_size = pd()->conf_.block_size;
const int wg_size = pd()->conf_.wg_size;
Expand Down
19 changes: 7 additions & 12 deletions src/gpu/sycl/ref_binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
const memory_desc_wrapper dst_d(dst_md());

const bool ok = set_default_params() == status::success
&& attr_.set_default_formats(dst_md()) == status::success
&& check_data_types(src0_d, src1_d, dst_d)
&& check_formats(src0_d, src1_d, dst_d)
&& attr()->has_default_values(
Expand All @@ -72,18 +73,12 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
}

bool post_ops_ok() const {
for (int i = 0; i < attr()->post_ops_.len(); i++) {
const auto &e = attr()->post_ops_.entry_[i];
if (!IMPLICATION(e.is_eltwise(),
utils::one_of(e.eltwise.alg, alg_kind::eltwise_relu,
alg_kind::eltwise_linear))) {
return false;
}
}
// Binary, prelu and dw conv post-ops are not supported.
// Dw conv post-ops are not supported.
return attr()->post_ops_.len() <= sycl_post_ops_t::max_post_ops
&& attr()->post_ops_.has_default_values(
{primitive_kind::eltwise});
{primitive_kind::eltwise, primitive_kind::binary,
primitive_kind::prelu,
primitive_kind::sum});
}

static bool check_data_types(const memory_desc_wrapper &src0,
Expand All @@ -100,7 +95,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
}

return IMPLICATION(utils::one_of(bf16, src0_dt, src1_dt, dst_dt),
src0_dt == src1_dt == dst_dt);
src0_dt == dst_dt && src1_dt == dst_dt);
}

static bool check_formats(const memory_desc_wrapper &src0,
Expand All @@ -109,7 +104,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
using namespace format_tag;

for (const auto &mdw : {src0, src1, dst}) {
if (mdw.matches_one_of_tag(ab, abc, abcd, abcde) == undef) {
if (mdw.matches_one_of_tag(a, ab, abc, abcd, abcde) == undef) {
return false;
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/gpu/sycl/sycl_primitive_conf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ struct sycl_binary_conf_t {
int wg_size;
int wk_size;

xpu::sycl::md_t binary_src_arr[8];

sycl_post_ops_t post_ops;
};

Expand Down

0 comments on commit 170bfaf

Please sign in to comment.