From b4e377bcac86d8ffc36c94514f5a4535b54f4141 Mon Sep 17 00:00:00 2001 From: Juncheng Date: Thu, 20 Aug 2020 12:57:24 +0800 Subject: [PATCH] Fused softmax kernel (#3496) Former-commit-id: 2ec8fc6989f78289208b679baeb03c0ef0c361e9 --- oneflow/python/test/ops/test_softmax.py | 32 ++- .../kernels/softmax_cross_entropy_kernel.h | 25 +- oneflow/user/kernels/softmax_kernel.cpp | 77 +++-- oneflow/user/kernels/softmax_kernel.cu | 272 ++++++++++++++++++ oneflow/user/kernels/softmax_kernel_util.cpp | 62 +++- oneflow/user/kernels/softmax_kernel_util.h | 11 +- .../sparse_softmax_cross_entropy_kernel.cpp | 29 +- 7 files changed, 423 insertions(+), 85 deletions(-) create mode 100644 oneflow/user/kernels/softmax_kernel.cu diff --git a/oneflow/python/test/ops/test_softmax.py b/oneflow/python/test/ops/test_softmax.py index 948dcecb15e..66b75905b3b 100644 --- a/oneflow/python/test/ops/test_softmax.py +++ b/oneflow/python/test/ops/test_softmax.py @@ -33,7 +33,6 @@ def compare_with_tensorflow(device_type, x_shape, data_type, axis): func_config = flow.FunctionConfig() if data_type == "float16": - func_config.enable_auto_mixed_precision(True) dtype = flow.float else: dtype = type_name_to_flow_type[data_type] @@ -45,10 +44,16 @@ def SoftmaxJob(): "x", shape=x_shape, dtype=dtype, - initializer=flow.random_uniform_initializer(minval=-10, maxval=10), + initializer=flow.random_uniform_initializer(minval=-0.1, maxval=0.1), trainable=True, ) - loss = flow.nn.softmax(x, axis=axis) + if data_type == "float16": + loss = flow.cast( + flow.nn.softmax(flow.cast(x, dtype=flow.float16), axis=axis), + dtype=flow.float, + ) + else: + loss = flow.nn.softmax(x, axis=axis) flow.optimizer.SGD( flow.optimizer.PiecewiseConstantScheduler([], [1e-4]), momentum=0 ).minimize(loss) @@ -71,16 +76,31 @@ def SoftmaxJob(): loss_diff = test_global_storage.Get("loss_diff") tf_x_diff = tape.gradient(tf_out, x, loss_diff) - assert np.allclose(of_out.numpy(), tf_out.numpy(), rtol=1e-5, atol=1e-5) + if data_type == "float16": + tolerance = 1e-3 + else: + tolerance = 1e-5 + assert np.allclose(of_out.numpy(), tf_out.numpy(), rtol=tolerance, atol=tolerance) assert np.allclose( - test_global_storage.Get("x_diff"), tf_x_diff.numpy(), rtol=1e-5, atol=1e-5 + test_global_storage.Get("x_diff"), + tf_x_diff.numpy(), + rtol=tolerance, + atol=tolerance, ) def test_softmax(test_case): arg_dict = OrderedDict() arg_dict["device_type"] = ["gpu", "cpu"] - arg_dict["x_shape"] = [(10, 10, 20, 30), (10, 20, 30), (10, 20)] + arg_dict["x_shape"] = [ + (10, 10, 20, 30), + (10, 20, 30), + (10, 20), + (10, 960), + (10, 4096), + (10, 8092), + (256, 1001), + ] arg_dict["data_type"] = ["float32", "double", "float16"] arg_dict["axis"] = [-1, 1, 2, 3] for arg in GenArgList(arg_dict): diff --git a/oneflow/user/kernels/softmax_cross_entropy_kernel.h b/oneflow/user/kernels/softmax_cross_entropy_kernel.h index e599a15d582..81a39c70938 100644 --- a/oneflow/user/kernels/softmax_cross_entropy_kernel.h +++ b/oneflow/user/kernels/softmax_cross_entropy_kernel.h @@ -45,8 +45,8 @@ class SoftmaxCrossEntropyKernel final : public user_op::OpKernel { const int64_t num_instances = label->shape().Count(0, num_axes - 1); const int64_t num_classes = label->shape().At(num_axes - 1); SoftmaxKernelUtil::ComputeProb( - ctx->device_ctx(), num_instances, num_classes, prediction->dptr(), out->mut_dptr(), - prob->mut_dptr(), tmp_buffer->mut_dptr(), tmp_buffer->shape().elem_cnt() * sizeof(T)); + ctx->device_ctx(), num_instances, num_classes, prediction->dptr(), prob->mut_dptr(), + tmp_buffer->mut_dptr(), 
tmp_buffer->shape().elem_cnt()); CrossEntropyKernelUtil::ComputeEntropy(ctx->device_ctx(), num_instances, num_classes, prob->dptr(), label->dptr(), out->mut_dptr()); @@ -54,15 +54,18 @@ class SoftmaxCrossEntropyKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -#define REGISTER_SOFTMAX_CROSS_ENTROPY_KERNEL(device_type_v, dtype_pair) \ - REGISTER_USER_KERNEL("softmax_cross_entropy") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceTag() == device_type_v) \ - & (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ - & (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ - .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - const Shape* prediction_shape = ctx->Shape4ArgNameAndIndex("prediction", 0); \ - return prediction_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair)); \ +#define REGISTER_SOFTMAX_CROSS_ENTROPY_KERNEL(device_type_v, dtype_pair) \ + REGISTER_USER_KERNEL("softmax_cross_entropy") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device_type_v) \ + & (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ + & (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ + const Shape* prediction_shape = ctx->Shape4ArgNameAndIndex("prediction", 0); \ + const int64_t num_classes = prediction_shape->At(prediction_shape->NumAxes() - 1); \ + const int64_t num_instances = prediction_shape->Count(0, prediction_shape->NumAxes() - 1); \ + return SoftmaxKernelUtil:: \ + GetComputeProbTempStorageSizeInBytes(num_instances, num_classes); \ }); template diff --git a/oneflow/user/kernels/softmax_kernel.cpp b/oneflow/user/kernels/softmax_kernel.cpp index b7f69bd9012..f26346d88e0 100644 --- a/oneflow/user/kernels/softmax_kernel.cpp +++ b/oneflow/user/kernels/softmax_kernel.cpp @@ -25,36 +25,31 @@ template class SoftmaxKernel final : public user_op::OpKernel { public: SoftmaxKernel() = default; - ~SoftmaxKernel() = default; + ~SoftmaxKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("in", 0); - user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0); - const int64_t num_classes = x->shape().At(x->shape().NumAxes() - 1); - const int64_t num_instances = x->shape().elem_cnt() / num_classes; - + const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + const int64_t num_classes = in->shape().At(in->shape().NumAxes() - 1); + const int64_t num_instances = in->shape().Count(0, in->shape().NumAxes() - 1); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - const size_t temp_storage_bytes = x->shape().elem_cnt() * sizeof(T); - const size_t tmp_bytes = GetCudaAlignedSize(temp_storage_bytes / num_classes); - - T* tmp_ptr = tmp_buffer->mut_dptr(); - void* temp_storage_ptr = reinterpret_cast(tmp_ptr + tmp_bytes / sizeof(T)); + const size_t temp_storage_bytes = tmp_buffer->shape().elem_cnt(); SoftmaxKernelUtil::ComputeProb(ctx->device_ctx(), num_instances, num_classes, - x->dptr(), tmp_ptr, y->mut_dptr(), - temp_storage_ptr, temp_storage_bytes); + in->dptr(), out->mut_dptr(), + tmp_buffer->mut_dptr(), temp_storage_bytes); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -template -user_op::InferTmpSizeFn GenInferTmpSizeFn(const std::string& bn) { - return 
[bn](user_op::InferContext* ctx) { - const Shape* x = ctx->Shape4ArgNameAndIndex(bn, 0); - const size_t num_classes = x->dim_vec().back(); - size_t temp_storage_bytes = GetCudaAlignedSize(x->elem_cnt() * sizeof(T)); // [i][j] - size_t tmp_or_sum_vec_bytes = GetCudaAlignedSize(temp_storage_bytes / num_classes); //[i] - return tmp_or_sum_vec_bytes + temp_storage_bytes; +template +user_op::InferTmpSizeFn GenFwInferTmpSizeFn() { + return [](user_op::InferContext* ctx) { + const Shape* in_shape = ctx->Shape4ArgNameAndIndex("in", 0); + const int64_t num_classes = in_shape->At(in_shape->NumAxes() - 1); + const int64_t num_instances = in_shape->Count(0, in_shape->NumAxes() - 1); + return SoftmaxKernelUtil::GetComputeProbTempStorageSizeInBytes(num_instances, + num_classes); }; } @@ -63,21 +58,16 @@ user_op::InferTmpSizeFn GenInferTmpSizeFn(const std::string& bn) { .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ & (user_op::HobDataType("out", 0) == GetDataType::value)) \ - .SetInferTmpSizeFn(GenInferTmpSizeFn("in")); + .SetInferTmpSizeFn(GenFwInferTmpSizeFn()); REGISTER_SOFTMAX_KERNEL(DeviceType::kCPU, float) REGISTER_SOFTMAX_KERNEL(DeviceType::kCPU, double) -#ifdef WITH_CUDA -REGISTER_SOFTMAX_KERNEL(DeviceType::kGPU, float16) -REGISTER_SOFTMAX_KERNEL(DeviceType::kGPU, float) -REGISTER_SOFTMAX_KERNEL(DeviceType::kGPU, double) -#endif template class SoftmaxGradKernel final : public user_op::OpKernel { public: SoftmaxGradKernel() = default; - ~SoftmaxGradKernel() = default; + ~SoftmaxGradKernel() override = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { @@ -89,32 +79,35 @@ class SoftmaxGradKernel final : public user_op::OpKernel { const int64_t num_instances = y->shape().elem_cnt() / num_classes; user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - const size_t temp_storage_bytes = y->shape().elem_cnt() * sizeof(T); - const size_t sum_vec_bytes = GetCudaAlignedSize(temp_storage_bytes / num_classes); - - T* sum_vec_ptr = tmp_buffer->mut_dptr(); - void* temp_storage_ptr = reinterpret_cast(sum_vec_ptr + sum_vec_bytes / sizeof(T)); - SoftmaxKernelUtil::ComputeDiff( - ctx->device_ctx(), num_instances, num_classes, dy->dptr(), y->dptr(), sum_vec_ptr, - dx->mut_dptr(), temp_storage_ptr, temp_storage_bytes); + const size_t temp_storage_bytes = tmp_buffer->shape().elem_cnt(); + + SoftmaxKernelUtil::ComputeDiff(ctx->device_ctx(), num_instances, num_classes, + dy->dptr(), y->dptr(), dx->mut_dptr(), + tmp_buffer->mut_dptr(), temp_storage_bytes); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; +template +user_op::InferTmpSizeFn GenBwInferTmpSizeFn() { + return [](user_op::InferContext* ctx) { + const Shape* dy_shape = ctx->Shape4ArgNameAndIndex("dy", 0); + const int64_t num_classes = dy_shape->At(dy_shape->NumAxes() - 1); + const int64_t num_instances = dy_shape->Count(0, dy_shape->NumAxes() - 1); + return SoftmaxKernelUtil::GetComputeDiffTempStorageSizeInBytes(num_instances, + num_classes); + }; +} + #define REGISTER_SOFTMAX_GRAD_KERNEL(device, dtype) \ REGISTER_USER_KERNEL("softmax_grad") \ .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceTag() == device) \ & (user_op::HobDataType("dx", 0) == GetDataType::value)) \ - .SetInferTmpSizeFn(GenInferTmpSizeFn("dx")); + .SetInferTmpSizeFn(GenBwInferTmpSizeFn()); REGISTER_SOFTMAX_GRAD_KERNEL(DeviceType::kCPU, float) REGISTER_SOFTMAX_GRAD_KERNEL(DeviceType::kCPU, double) -#ifdef WITH_CUDA -REGISTER_SOFTMAX_GRAD_KERNEL(DeviceType::kGPU, float16) 
-REGISTER_SOFTMAX_GRAD_KERNEL(DeviceType::kGPU, float) -REGISTER_SOFTMAX_GRAD_KERNEL(DeviceType::kGPU, double) -#endif } // namespace diff --git a/oneflow/user/kernels/softmax_kernel.cu b/oneflow/user/kernels/softmax_kernel.cu new file mode 100644 index 00000000000..7c521ecc6b7 --- /dev/null +++ b/oneflow/user/kernels/softmax_kernel.cu @@ -0,0 +1,272 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/new_kernel_util.h" +#include "oneflow/user/kernels/softmax_kernel_util.h" +#include + +namespace oneflow { + +namespace { + +constexpr int64_t kSoftmaxGpuBlockSize = 256; + +template +struct SoftmaxUtil { + using ComputeType = T; + __device__ static ComputeType ToComputeType(T v) { return v; } + __device__ static T FromComputeType(ComputeType v) { return v; } +}; + +template<> +struct SoftmaxUtil { + using ComputeType = float; + __device__ static ComputeType ToComputeType(half v) { return __half2float(v); } + __device__ static half FromComputeType(ComputeType v) { return __float2half(v); } +}; + +__device__ double Exp(double x) { return exp(x); } + +__device__ float Exp(float x) { return expf(x); } + +template +int GetForwardDynamicSharedMemorySize(const int num_classes) { + return num_classes * sizeof(typename SoftmaxUtil::ComputeType); +} + +template +int GetBackwardDynamicSharedMemorySize(const int num_classes) { + return 2 * num_classes * sizeof(typename SoftmaxUtil::ComputeType); +} + +int GetSoftmaxBlockSize() { return kSoftmaxGpuBlockSize; } + +int GetSoftmaxNumBlocks(const int num_instances) { + return std::min(static_cast(num_instances), kCudaMaxBlocksNum); +} + +template +int GetMinNumClasses() { + return 32; +} + +template +int GetMaxNumClasses() { + return 16 * 1024 / sizeof(T); +} + +template +__global__ void SoftmaxGpuForwardImpl(const int num_instances, const int num_classes, const T* in, + T* prob) { + using SU = SoftmaxUtil; + using ComputeType = typename SU::ComputeType; + extern __shared__ __align__(sizeof(ComputeType)) unsigned char fw_shared_buf[]; + auto* compute_buf = reinterpret_cast(fw_shared_buf); + __shared__ ComputeType row_reduce_result; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage cub_reduce_tmp_storage; + const int tid = threadIdx.x; + for (int row = blockIdx.x; row < num_instances; row += gridDim.x) { + const int row_offset = row * num_classes; + const T* in_row = in + row_offset; + T* prob_row = prob + row_offset; + ComputeType thread_max = GetMinVal(); + for (int col = tid; col < num_classes; col += kSoftmaxGpuBlockSize) { + const ComputeType x = SU::ToComputeType(in_row[col]); + compute_buf[col] = x; + thread_max = max(thread_max, x); + } + __syncthreads(); + ComputeType block_max = BlockReduce(cub_reduce_tmp_storage).Reduce(thread_max, cub::Max()); + if (tid == 0) { row_reduce_result = block_max; } + __syncthreads(); + const ComputeType row_max_t = row_reduce_result; + ComputeType thread_sum = 0; + 
for (int col = tid; col < num_classes; col += kSoftmaxGpuBlockSize) { + const ComputeType exp_x = Exp(compute_buf[col] - row_max_t); + compute_buf[col] = exp_x; + thread_sum += exp_x; + } + __syncthreads(); + ComputeType block_sum = BlockReduce(cub_reduce_tmp_storage).Reduce(thread_sum, cub::Sum()); + if (tid == 0) { row_reduce_result = block_sum; } + __syncthreads(); + const ComputeType row_sum_t = row_reduce_result; + for (int col = tid; col < num_classes; col += kSoftmaxGpuBlockSize) { + prob_row[col] = SU::FromComputeType(compute_buf[col] / row_sum_t); + } + } +} + +template +void SoftmaxForwardGpu(DeviceCtx* ctx, const int num_instances, const int num_classes, const T* in, + T* prob) { + SoftmaxGpuForwardImpl<<(num_classes), ctx->cuda_stream()>>>( + num_instances, num_classes, in, prob); +} + +template<> +void SoftmaxForwardGpu(DeviceCtx* ctx, const int num_instances, const int num_classes, + const float16* in, float16* prob) { + SoftmaxForwardGpu(ctx, num_instances, num_classes, reinterpret_cast(in), + reinterpret_cast(prob)); +} + +template +__global__ void SoftmaxGpuBackwardImpl(const int num_instances, const int num_classes, const T* dy, + const T* prob, T* dx) { + using SU = SoftmaxUtil; + using ComputeType = typename SU::ComputeType; + extern __shared__ __align__(sizeof(ComputeType)) unsigned char bw_shared_buf[]; + auto* dy_buf = reinterpret_cast(bw_shared_buf); + auto* prob_buf = + reinterpret_cast(bw_shared_buf + num_classes * sizeof(ComputeType)); + __shared__ ComputeType row_reduce_result; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage cub_reduce_tmp_storage; + const int tid = threadIdx.x; + for (int row = blockIdx.x; row < num_instances; row += gridDim.x) { + const int row_offset = row * num_classes; + const T* dy_row = dy + row_offset; + const T* prob_row = prob + row_offset; + T* dx_row = dx + row_offset; + ComputeType thread_sum = 0; + for (int col = tid; col < num_classes; col += kSoftmaxGpuBlockSize) { + const ComputeType dy_col = SU::ToComputeType(dy_row[col]); + dy_buf[col] = dy_col; + const ComputeType prob_col = SU::ToComputeType(prob_row[col]); + prob_buf[col] = prob_col; + thread_sum += (dy_col * prob_col); + } + __syncthreads(); + ComputeType block_sum = BlockReduce(cub_reduce_tmp_storage).Reduce(thread_sum, cub::Sum()); + if (tid == 0) { row_reduce_result = block_sum; } + __syncthreads(); + const ComputeType row_sum_t = row_reduce_result; + for (int col = tid; col < num_classes; col += kSoftmaxGpuBlockSize) { + dx_row[col] = SU::FromComputeType((dy_buf[col] - row_sum_t) * prob_buf[col]); + } + } +} + +template +void SoftmaxBackwardGpu(DeviceCtx* ctx, const int num_instances, const int num_classes, const T* in, + const T* prob, T* dx) { + SoftmaxGpuBackwardImpl<<(num_classes), + ctx->cuda_stream()>>>(num_instances, num_classes, in, prob, dx); +} + +template<> +void SoftmaxBackwardGpu(DeviceCtx* ctx, const int num_instances, const int num_classes, + const float16* in, const float16* prob, float16* dx) { + SoftmaxBackwardGpu(ctx, num_instances, num_classes, reinterpret_cast(in), + reinterpret_cast(prob), reinterpret_cast(dx)); +} + +template +class SoftmaxKernel final : public user_op::OpKernel { + public: + SoftmaxKernel() = default; + ~SoftmaxKernel() override = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + const ShapeView& in_shape = in->shape(); 
+ const int64_t num_classes = in_shape.At(in_shape.NumAxes() - 1); + const int64_t num_instances = in_shape.Count(0, in_shape.NumAxes() - 1); + if (num_classes >= GetMinNumClasses() && num_classes <= GetMaxNumClasses()) { + SoftmaxForwardGpu(ctx->device_ctx(), num_instances, num_classes, in->dptr(), + out->mut_dptr()); + } else { + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + SoftmaxKernelUtil::ComputeProb( + ctx->device_ctx(), num_instances, num_classes, in->dptr(), out->mut_dptr(), + tmp_buffer->mut_dptr(), tmp_buffer->shape().elem_cnt()); + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_SOFTMAX_GPU_KERNEL(dtype) \ + REGISTER_USER_KERNEL("softmax") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == DeviceType::kGPU) \ + & (user_op::HobDataType("out", 0) == GetDataType::value)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ + const Shape* in_shape = ctx->Shape4ArgNameAndIndex("in", 0); \ + const int64_t num_classes = in_shape->At(in_shape->NumAxes() - 1); \ + const int64_t num_instances = in_shape->Count(0, in_shape->NumAxes() - 1); \ + return SoftmaxKernelUtil::GetComputeProbTempStorageSizeInBytes( \ + num_instances, num_classes); \ + }); + +REGISTER_SOFTMAX_GPU_KERNEL(float16) +REGISTER_SOFTMAX_GPU_KERNEL(float) +REGISTER_SOFTMAX_GPU_KERNEL(double) +#undef REGISTER_SOFTMAX_GPU_KERNEL + +template +class SoftmaxGradKernel final : public user_op::OpKernel { + public: + SoftmaxGradKernel() = default; + ~SoftmaxGradKernel() override = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + + const int64_t num_classes = y->shape().At(y->shape().NumAxes() - 1); + const int64_t num_instances = y->shape().elem_cnt() / num_classes; + if (num_classes >= GetMinNumClasses() && num_classes <= GetMaxNumClasses()) { + SoftmaxBackwardGpu(ctx->device_ctx(), num_instances, num_classes, dy->dptr(), + y->dptr(), dx->mut_dptr()); + } else { + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + SoftmaxKernelUtil::ComputeDiff( + ctx->device_ctx(), num_instances, num_classes, dy->dptr(), y->dptr(), + dx->mut_dptr(), tmp_buffer->mut_dptr(), tmp_buffer->shape().elem_cnt()); + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_SOFTMAX_GRAD_KERNEL(dtype) \ + REGISTER_USER_KERNEL("softmax_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == DeviceType::kGPU) \ + & (user_op::HobDataType("dx", 0) == GetDataType::value)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ + const Shape* dy_shape = ctx->Shape4ArgNameAndIndex("dy", 0); \ + const int64_t num_classes = dy_shape->At(dy_shape->NumAxes() - 1); \ + const int64_t num_instances = dy_shape->Count(0, dy_shape->NumAxes() - 1); \ + return SoftmaxKernelUtil::GetComputeProbTempStorageSizeInBytes( \ + num_instances, num_classes); \ + }); + +REGISTER_SOFTMAX_GRAD_KERNEL(float16) +REGISTER_SOFTMAX_GRAD_KERNEL(float) +REGISTER_SOFTMAX_GRAD_KERNEL(double) +#undef REGISTER_SOFTMAX_GRAD_KERNEL + +} // namespace + +} // namespace oneflow diff --git a/oneflow/user/kernels/softmax_kernel_util.cpp b/oneflow/user/kernels/softmax_kernel_util.cpp index 4cd30884a87..d43c0880dec 100644 --- 
a/oneflow/user/kernels/softmax_kernel_util.cpp +++ b/oneflow/user/kernels/softmax_kernel_util.cpp @@ -19,17 +19,56 @@ limitations under the License. namespace oneflow { +namespace { + +template +size_t GetProbTmpSize(int64_t n, int64_t w) { + return GetCudaAlignedSize(n * sizeof(T)); +} + +template +size_t GetDiffTmpSize(int64_t n, int64_t w) { + return GetCudaAlignedSize(n * sizeof(T)); +} + +template +size_t GetReduceTempStorageSize(int64_t n, int64_t w) { + return GetCudaAlignedSize(n * w * sizeof(T)); +} + +} // namespace + +template +size_t SoftmaxKernelUtil::GetComputeProbTempStorageSizeInBytes(int64_t n, + int64_t w) { + return GetProbTmpSize(n, w) + GetReduceTempStorageSize(n, w); +} + +template +size_t SoftmaxKernelUtil::GetComputeDiffTempStorageSizeInBytes(int64_t n, + int64_t w) { + return GetDiffTmpSize(n, w) + GetReduceTempStorageSize(n, w); +} + template void SoftmaxKernelUtil::ComputeProb(DeviceCtx* ctx, const int64_t n, - const int64_t w, const T* in, T* tmp, T* prob, + const int64_t w, const T* in, T* prob, void* temp_storage, const size_t temp_storage_bytes) { auto Val = NdarrayUtil::GetValNdarrayBuilder(); auto Var = NdarrayUtil::GetVarNdarrayBuilder(); + const size_t min_temp_storage_bytes = + SoftmaxKernelUtil::GetComputeProbTempStorageSizeInBytes(n, w); + CHECK_GE(temp_storage_bytes, min_temp_storage_bytes); + const size_t reduce_temp_storage_bytes = GetReduceTempStorageSize(n, w); + T* reduce_storage = reinterpret_cast(temp_storage); + auto reduce_storage_var = + Var({static_cast(reduce_temp_storage_bytes / sizeof(T))}, reduce_storage); + T* tmp = reinterpret_cast(reinterpret_cast(temp_storage) + + reduce_temp_storage_bytes); // max | tmp[i] = Max_j(in[i][j]) NdarrayUtil::ReduceMax(ctx, Var({n, 1}, tmp), Val({n, w}, in), - Var({static_cast(temp_storage_bytes / sizeof(T))}, - reinterpret_cast(temp_storage))); + reduce_storage_var); // sub | prob[i][j] = in[i][j] - tmp[i] NdarrayUtil::BroadcastSub(ctx, Var({n, w}, prob), Val({n, w}, in), Val({n, 1}, tmp)); @@ -37,8 +76,7 @@ void SoftmaxKernelUtil::ComputeProb(DeviceCtx* ctx, const int64_ NdarrayUtil::InplaceExp(ctx, Var({n, w}, prob)); // sum | tmp[i] = Sum_j(prob[i][j]) NdarrayUtil::ReduceSum(ctx, Var({n, 1}, tmp), Val({n, w}, prob), - Var({static_cast(temp_storage_bytes / sizeof(T))}, - reinterpret_cast(temp_storage))); + reduce_storage_var); // div | prob[i][j] /= tmp[i] NdarrayUtil::InplaceBroadcastDiv(ctx, Var({n, w}, prob), Val({n, 1}, tmp)); } @@ -46,17 +84,25 @@ void SoftmaxKernelUtil::ComputeProb(DeviceCtx* ctx, const int64_ template void SoftmaxKernelUtil::ComputeDiff(DeviceCtx* ctx, const int64_t n, const int64_t w, const T* dy, const T* out, - T* sum_vec, T* dx, void* temp_storage, + T* dx, void* temp_storage, const size_t temp_storage_bytes) { auto Val = NdarrayUtil::GetValNdarrayBuilder(); auto Var = NdarrayUtil::GetVarNdarrayBuilder(); + const size_t min_temp_storage_bytes = + SoftmaxKernelUtil::GetComputeProbTempStorageSizeInBytes(n, w); + CHECK_GE(temp_storage_bytes, min_temp_storage_bytes); + const size_t reduce_temp_storage_bytes = GetReduceTempStorageSize(n, w); + T* reduce_storage = reinterpret_cast(temp_storage); + auto reduce_storage_var = + Var({static_cast(reduce_temp_storage_bytes / sizeof(T))}, reduce_storage); + T* sum_vec = reinterpret_cast(reinterpret_cast(temp_storage) + + reduce_temp_storage_bytes); // it's safe to use dx as tmp // dot product | get dot product sum_vec[i] from out[i] * dy[i] T* tmp = dx; NdarrayUtil::Mul(ctx, Var({n * w}, tmp), Val({n * w}, out), Val({n * w}, dy)); 
NdarrayUtil::ReduceSum(ctx, Var({n, 1}, sum_vec), Val({n, w}, tmp), - Var({static_cast(temp_storage_bytes / sizeof(T))}, - reinterpret_cast(temp_storage))); + reduce_storage_var); // sub | dx[i][j] = dy[i][j] - sum_vec[i] NdarrayUtil::BroadcastSub(ctx, Var({n, w}, dx), Val({n, w}, dy), Val({n, 1}, sum_vec)); diff --git a/oneflow/user/kernels/softmax_kernel_util.h b/oneflow/user/kernels/softmax_kernel_util.h index 6b85bcd52b3..a4e06813f6c 100644 --- a/oneflow/user/kernels/softmax_kernel_util.h +++ b/oneflow/user/kernels/softmax_kernel_util.h @@ -22,11 +22,12 @@ namespace oneflow { template struct SoftmaxKernelUtil { - static void ComputeProb(DeviceCtx* ctx, const int64_t n, const int64_t w, const T* in, T* tmp, - T* prob, void* temp_storage, const size_t temp_storage_bytes); - static void ComputeDiff(DeviceCtx* ctx, const int64_t n, const int64_t w, const T* dy, - const T* out, T* sum_vec, T* dx, void* temp_storage, - const size_t temp_storage_bytes); + static size_t GetComputeProbTempStorageSizeInBytes(int64_t n, int64_t w); + static size_t GetComputeDiffTempStorageSizeInBytes(int64_t n, int64_t w); + static void ComputeProb(DeviceCtx* ctx, int64_t n, int64_t w, const T* in, T* prob, + void* temp_storage, size_t temp_storage_bytes); + static void ComputeDiff(DeviceCtx* ctx, int64_t n, int64_t w, const T* dy, const T* out, T* dx, + void* temp_storage, size_t temp_storage_bytes); }; } // namespace oneflow diff --git a/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cpp b/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cpp index 5c1164966a9..71b29a2665c 100644 --- a/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cpp +++ b/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cpp @@ -40,8 +40,8 @@ class SparseSoftmaxCrossEntropyKernel final : public user_op::OpKernel { const int64_t lower_bound = 0; const int64_t depth = ctx->Attr("depth"); SoftmaxKernelUtil::ComputeProb( - ctx->device_ctx(), num_instances, num_classes, prediction->dptr(), out->mut_dptr(), - prob->mut_dptr(), tmp_buffer->mut_dptr(), tmp_buffer->shape().elem_cnt() * sizeof(T)); + ctx->device_ctx(), num_instances, num_classes, prediction->dptr(), prob->mut_dptr(), + tmp_buffer->mut_dptr(), tmp_buffer->shape().elem_cnt()); SparseCrossEntropyKernelUtil::ComputeEntropy( ctx->device_ctx(), num_instances, num_classes, depth, lower_bound, prob->dptr(), label->dptr(), out->mut_dptr()); @@ -62,17 +62,20 @@ class SparseSoftmaxCrossEntropyMsKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -#define REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL(kernel_class, kernel_name, device_type_v, \ - dtype_pair, ltype_pair) \ - REGISTER_USER_KERNEL(kernel_name) \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceTag() == device_type_v) \ - & (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ - & (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))) \ - .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - const Shape* prediction_shape = ctx->Shape4ArgNameAndIndex("prediction", 0); \ - return prediction_shape->elem_cnt() * sizeof(OF_PP_PAIR_FIRST(dtype_pair)); \ +#define REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL(kernel_class, kernel_name, device_type_v, \ + dtype_pair, ltype_pair) \ + REGISTER_USER_KERNEL(kernel_name) \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceTag() == device_type_v) \ + & (user_op::HobDataType("label", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ + & (user_op::HobDataType("out", 0) 
== OF_PP_PAIR_SECOND(dtype_pair))) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ + const Shape* prediction_shape = ctx->Shape4ArgNameAndIndex("prediction", 0); \ + const int64_t num_classes = prediction_shape->At(prediction_shape->NumAxes() - 1); \ + const int64_t num_instances = prediction_shape->Count(0, prediction_shape->NumAxes() - 1); \ + return SoftmaxKernelUtil:: \ + GetComputeProbTempStorageSizeInBytes(num_instances, num_classes); \ }); OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SPARSE_SOFTMAX_CROSS_ENTROPY_KERNEL,
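
Note on the fused forward path added in softmax_kernel.cu: SoftmaxGpuForwardImpl processes one row per thread block (grid-striding over rows), stages the row in dynamic shared memory, and calls cub::BlockReduce twice per row, first a max reduction for numerical stability and then a sum of the exponentials. This is why the fused path is only taken when num_classes lies between GetMinNumClasses() and GetMaxNumClasses() (roughly a 16 KB per-row shared-memory budget); otherwise the kernel falls back to the generic SoftmaxKernelUtil path. The standalone sketch below illustrates the same technique in simplified form. It is not the patch's code: the names kBlockSize and FusedSoftmaxRows and the sizes in main are illustrative assumptions, it handles float only (no float16/ComputeType promotion), and it omits the grid clamping and error checking the real kernel launcher has.

    // Simplified, self-contained sketch of a fused row-wise softmax (nvcc -o softmax_sketch softmax_sketch.cu)
    #include <cub/cub.cuh>
    #include <cuda_runtime.h>
    #include <cfloat>
    #include <cstdio>
    #include <vector>

    constexpr int kBlockSize = 256;  // illustrative; the patch uses kSoftmaxGpuBlockSize = 256

    __global__ void FusedSoftmaxRows(int num_instances, int num_classes, const float* in, float* prob) {
      extern __shared__ float row_buf[];                       // one row of inputs, then of exponentials
      typedef cub::BlockReduce<float, kBlockSize> BlockReduce;
      __shared__ typename BlockReduce::TempStorage reduce_storage;
      __shared__ float row_result;                             // broadcast slot for the block max / sum
      const int tid = threadIdx.x;
      for (int row = blockIdx.x; row < num_instances; row += gridDim.x) {
        const float* in_row = in + row * num_classes;
        float* prob_row = prob + row * num_classes;
        float thread_max = -FLT_MAX;
        for (int col = tid; col < num_classes; col += kBlockSize) {
          row_buf[col] = in_row[col];
          thread_max = fmaxf(thread_max, row_buf[col]);
        }
        __syncthreads();
        float block_max = BlockReduce(reduce_storage).Reduce(thread_max, cub::Max());
        if (tid == 0) { row_result = block_max; }
        __syncthreads();
        const float row_max = row_result;
        float thread_sum = 0.f;
        for (int col = tid; col < num_classes; col += kBlockSize) {
          row_buf[col] = expf(row_buf[col] - row_max);         // subtract the row max for stability
          thread_sum += row_buf[col];
        }
        __syncthreads();
        float block_sum = BlockReduce(reduce_storage).Sum(thread_sum);
        if (tid == 0) { row_result = block_sum; }
        __syncthreads();
        const float row_sum = row_result;
        for (int col = tid; col < num_classes; col += kBlockSize) {
          prob_row[col] = row_buf[col] / row_sum;
        }
      }
    }

    int main() {
      const int num_instances = 4, num_classes = 1000;         // illustrative sizes
      std::vector<float> host_in(num_instances * num_classes);
      for (size_t i = 0; i < host_in.size(); ++i) { host_in[i] = 0.01f * static_cast<float>(i % 97); }
      float *in = nullptr, *prob = nullptr;
      cudaMalloc(&in, host_in.size() * sizeof(float));
      cudaMalloc(&prob, host_in.size() * sizeof(float));
      cudaMemcpy(in, host_in.data(), host_in.size() * sizeof(float), cudaMemcpyHostToDevice);
      const int shared_bytes = num_classes * sizeof(float);    // the row must fit in shared memory
      FusedSoftmaxRows<<<num_instances, kBlockSize, shared_bytes>>>(num_instances, num_classes, in, prob);
      std::vector<float> host_prob(host_in.size());
      cudaMemcpy(host_prob.data(), prob, host_prob.size() * sizeof(float), cudaMemcpyDeviceToHost);
      float row0_sum = 0.f;
      for (int c = 0; c < num_classes; ++c) { row0_sum += host_prob[c]; }
      printf("row 0 sums to %f (expect ~1.0)\n", row0_sum);
      cudaFree(in);
      cudaFree(prob);
      return 0;
    }

The backward kernel in the patch (SoftmaxGpuBackwardImpl) follows the same block-per-row pattern with two shared-memory buffers (dy and prob) and a single dot-product reduction, computing dx[i][j] = (dy[i][j] - sum_k dy[i][k] * prob[i][k]) * prob[i][j].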
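
Note on the generic fallback path: the refactor moves the temp-buffer sizing out of the callers and into SoftmaxKernelUtil::GetComputeProbTempStorageSizeInBytes / GetComputeDiffTempStorageSizeInBytes, and ComputeProb/ComputeDiff now carve a single temp_storage pointer into a reduce-scratch region of n*w elements followed by an n-element per-row buffer (holding the row max, then the row sum, or the dot-product sum in the backward pass). The host-side sketch below shows that layout; it is an assumption-laden illustration, not the patch's API: kAssumedAlignment is a stand-in for OneFlow's GetCudaAlignedSize, and AlignedSize, ComputeProbTempStorageSize, and CarveProbTempStorage are hypothetical names.

    // Host-side sketch of the single-buffer layout used by the refactored ComputeProb
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    constexpr size_t kAssumedAlignment = 512;  // stand-in for OneFlow's CUDA alignment constant

    size_t AlignedSize(size_t bytes) {
      return (bytes + kAssumedAlignment - 1) / kAssumedAlignment * kAssumedAlignment;
    }

    template<typename T>
    size_t ComputeProbTempStorageSize(int64_t n, int64_t w) {
      const size_t row_tmp_bytes = AlignedSize(n * sizeof(T));             // per-row max, then sum
      const size_t reduce_scratch_bytes = AlignedSize(n * w * sizeof(T));  // NdarrayUtil reduce scratch
      return row_tmp_bytes + reduce_scratch_bytes;
    }

    template<typename T>
    void CarveProbTempStorage(void* temp_storage, int64_t n, int64_t w, T** reduce_scratch, T** row_tmp) {
      const size_t reduce_scratch_bytes = AlignedSize(n * w * sizeof(T));
      *reduce_scratch = reinterpret_cast<T*>(temp_storage);                // reduce scratch comes first
      *row_tmp = reinterpret_cast<T*>(reinterpret_cast<unsigned char*>(temp_storage) + reduce_scratch_bytes);
    }

    int main() {
      const int64_t n = 256, w = 1001;  // matches the (256, 1001) case added to test_softmax.py
      const size_t bytes = ComputeProbTempStorageSize<float>(n, w);
      std::vector<unsigned char> buffer(bytes);
      float* reduce_scratch = nullptr;
      float* row_tmp = nullptr;
      CarveProbTempStorage<float>(buffer.data(), n, w, &reduce_scratch, &row_tmp);
      printf("total %zu bytes: reduce scratch at +0, per-row tmp at +%td\n", bytes,
             reinterpret_cast<unsigned char*>(row_tmp) - buffer.data());
      return 0;
    }

Because the softmax, softmax_cross_entropy, and sparse_softmax_cross_entropy registrations all request their tmp_buffer via the same size helpers, the kernels can pass tmp_buffer->shape().elem_cnt() straight through as the byte count, and the CHECK_GE inside ComputeProb/ComputeDiff verifies the buffer is large enough.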