diff --git a/oneflow/core/kernel/kernel_util.cpp b/oneflow/core/kernel/kernel_util.cpp
index ad82c448ebc..88d04d74548 100644
--- a/oneflow/core/kernel/kernel_util.cpp
+++ b/oneflow/core/kernel/kernel_util.cpp
@@ -557,6 +557,10 @@ KU_INTEGRAL_METHOD InitializeWithConf(DeviceCtx* ctx, const InitializerConf& ini
   }
 }
 
+KU_INTEGRAL_METHOD Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) {
+  for (int64_t i = 0; i < n; ++i) { z[i] = x[i] * y[i]; }
+}
+
 #define INSTANTIATE_KERNEL_UTIL(type_cpp, type_proto)                                \
   template struct CpuKernelUtilIf<type_cpp, KernelUtil<DeviceType::kCPU, type_cpp>>; \
   template struct KernelUtil<DeviceType::kCPU, type_cpp>;
diff --git a/oneflow/core/kernel/kernel_util.cu b/oneflow/core/kernel/kernel_util.cu
index c5046a1c7dd..7e42c1620b8 100644
--- a/oneflow/core/kernel/kernel_util.cu
+++ b/oneflow/core/kernel/kernel_util.cu
@@ -649,6 +649,11 @@ KU_INTEGRAL_METHOD Axpy(DeviceCtx* ctx, const int n, const T alpha, const T* x,
           n, alpha, x, incx, y, incy);
 }
 
+KU_INTEGRAL_METHOD Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z) {
+  MulGpu<T>
+      <<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(n, x, y, z);
+}
+
 #define INSTANTIATE_KERNEL_UTIL(type_cpp, type_proto)                                \
   template struct GpuKernelUtilIf<type_cpp, KernelUtil<DeviceType::kGPU, type_cpp>>; \
   template struct KernelUtil<DeviceType::kGPU, type_cpp>;
diff --git a/oneflow/core/kernel/kernel_util.h b/oneflow/core/kernel/kernel_util.h
index c83ee39b7ab..541311bd568 100644
--- a/oneflow/core/kernel/kernel_util.h
+++ b/oneflow/core/kernel/kernel_util.h
@@ -203,6 +203,7 @@ struct KernelUtil<DeviceType::kCPU, T, typename std::enable_if<IsIntegral<T>::va
                    const int incy);
   static void InitializeWithConf(DeviceCtx* ctx, const InitializerConf& initializer_conf,
                                  uint32_t random_seed, Blob* blob);
+  static void Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z);
 };
 
 // GPU, Integral, Floating
@@ -305,6 +306,7 @@ struct KernelUtil<DeviceType::kGPU, T, typename std::enable_if<IsIntegral<T>::va
     : public GpuKernelUtilIf<T, KernelUtil<DeviceType::kGPU, T>> {
   static void Axpy(DeviceCtx* ctx, const int n, const T alpha, const T* x, const int incx, T* y,
                    const int incy);
+  static void Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z);
 };
 
 using CopyBlobFieldMthd = void (Blob::*)(DeviceCtx*, const Blob*);
diff --git a/oneflow/core/kernel/multiply_kernel.cpp b/oneflow/core/kernel/multiply_kernel.cpp
index 8b1ccb6187f..9426e702e39 100644
--- a/oneflow/core/kernel/multiply_kernel.cpp
+++ b/oneflow/core/kernel/multiply_kernel.cpp
@@ -33,6 +33,6 @@ const PbMessage& MultiplyKernel<device_type, T>::GetCustomizedOpConf() const {
   return this->op_conf().multiply_conf();
 }
 
-ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMultiplyConf, MultiplyKernel, FLOATING_DATA_TYPE_SEQ);
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMultiplyConf, MultiplyKernel, ARITHMETIC_DATA_TYPE_SEQ);
 
 }  // namespace oneflow
diff --git a/oneflow/customized/kernels/multiply_kernel.cpp b/oneflow/customized/kernels/multiply_kernel.cpp
index c7092db5cde..e134766ff04 100644
--- a/oneflow/customized/kernels/multiply_kernel.cpp
+++ b/oneflow/customized/kernels/multiply_kernel.cpp
@@ -39,7 +39,8 @@ class MultiplyKernel final : public user_op::OpKernel {
         return Maybe<void>::Ok();                                               \
       });
 
-OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MULTIPLY_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ)
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_MULTIPLY_KERNEL, DEVICE_TYPE_SEQ,
+                                 ARITHMETIC_DATA_TYPE_SEQ)
 #undef REGISTER_MULTIPLY_KERNEL
 
 }  // namespace oneflow
diff --git a/oneflow/python/test/ops/test_multiply.py b/oneflow/python/test/ops/test_multiply.py
index 76023249d4e..501e3a4b790 100644
--- a/oneflow/python/test/ops/test_multiply.py
+++ b/oneflow/python/test/ops/test_multiply.py
@@ -45,23 +45,26 @@ def _test_element_wise_mul_fw_bw(test_case, device, shape, type_name):
 
     @flow.global_function(func_config)
     def test_element_wise_mul_job(
-        x=flow.FixedTensorDef(shape, dtype=flow_type),
-        y=flow.FixedTensorDef(shape, dtype=flow_type),
+        x=flow.FixedTensorDef(shape, dtype=flow.float),
+        y=flow.FixedTensorDef(shape, dtype=flow.float),
     ):
         with flow.fixed_placement(device, "0:0"):
             x += flow.get_variable(
                 name="vx",
                 shape=(1,),
-                dtype=flow_type,
+                dtype=flow.float,
                 initializer=flow.zeros_initializer(),
             )
             y += flow.get_variable(
                 name="vy",
                 shape=(1,),
-                dtype=flow_type,
+                dtype=flow.float,
                 initializer=flow.zeros_initializer(),
             )
+            x = flow.cast(x, dtype=flow_type)
+            y = flow.cast(y, dtype=flow_type)
             out = flow.math.multiply(x, y)
+            out = flow.cast(out, dtype=flow.float)
             flow.losses.add_loss(out)
 
             flow.watch(x, test_global_storage.Setter("x"))
@@ -74,9 +77,9 @@ def test_element_wise_mul_job(
 
     check_point = flow.train.CheckPoint()
     check_point.init()
-    test_element_wise_mul_job(
-        np.random.rand(*shape).astype(np_type), np.random.rand(*shape).astype(np_type)
-    ).get()
+    x = np.random.randint(low=0, high=10, size=shape).astype(np.float32)
+    y = np.random.randint(low=0, high=10, size=shape).astype(np.float32)
+    test_element_wise_mul_job(x, y).get()
     test_case.assertTrue(
         np.allclose(
             test_global_storage.Get("x") * test_global_storage.Get("y"),
@@ -101,6 +104,6 @@ def test_element_wise_mul_fw_bw(test_case):
     arg_dict = OrderedDict()
     arg_dict["device"] = ["gpu", "cpu"]
     arg_dict["shape"] = [(96, 96)]
-    arg_dict["type_name"] = ["float32", "double"]
+    arg_dict["type_name"] = ["float32", "double", "int8", "int32", "int64"]
     for arg in GenArgDict(arg_dict):
        _test_element_wise_mul_fw_bw(test_case, **arg)
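
Note on the test changes: the job keeps flow.float at its boundary, casts to the integral type under test just before flow.math.multiply, and casts the product back to flow.float for the loss, presumably because flow.losses.add_loss needs a differentiable floating-point blob. Feeding small random integers (0-9) as np.float32 keeps every cast exact, so the np.allclose check on forward and backward values is reliable for the new integral types.

For illustration, a minimal sketch of how a kernel body could call the new KernelUtil<device_type, T>::Mul declared above. This helper is not part of the diff; ElemWiseMulForward and the blob arguments are hypothetical names, while Blob::dptr/mut_dptr and shape().elem_cnt() are the repo's existing blob accessors:

    #include "oneflow/core/kernel/kernel_util.h"

    namespace oneflow {

    // Hypothetical helper (sketch only): computes z = x * y element-wise via the
    // Mul primitive added above, which now also covers integral T on CPU and GPU.
    template<DeviceType device_type, typename T>
    void ElemWiseMulForward(DeviceCtx* ctx, const Blob* in_0, const Blob* in_1, Blob* out) {
      const int64_t n = out->shape().elem_cnt();  // total number of elements
      KernelUtil<device_type, T>::Mul(ctx, n, in_0->dptr<T>(), in_1->dptr<T>(),
                                      out->mut_dptr<T>());
    }

    }  // namespace oneflow

Routing through KernelUtil rather than the kernel itself is what lets one registration macro (ARITHMETIC_DATA_TYPE_SEQ) cover both floating and integral types: the CPU specialization falls back to a plain loop, while the GPU specialization launches MulGpu on the device stream.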