diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index 4aca753a510b8..f1f542a12417f 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -89,6 +89,20 @@ struct SoftMaxBackwardEpilogue { const AccumT sum; }; +template +struct SoftMaxForwardWithMulEpilogue { + __device__ __forceinline__ SoftMaxForwardWithMulEpilogue(AccumT max_input, AccumT sum) + : max_input(max_input) + , sum(sum) {} + + __device__ __forceinline__ OutT operator()(T input) const { + return static_cast(__expf(input - max_input) * sum); + } + + const AccumT max_input; + const AccumT sum; +}; + @@ -387,6 +401,19 @@ struct SumExpFloat const AccumT max_k; }; +template +struct SumExpfFloat +{ + __device__ __forceinline__ SumExpfFloat(AccumT v) + : max_k(v) {} + + __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { + return sum + __expf(v - max_k); + } + + const AccumT max_k; +}; + template class Reduction, typename AccumT> __device__ __forceinline__ AccumT blockReduce(AccumT* smem, AccumT val, @@ -449,6 +476,18 @@ T blockReduceWarp(T* smem_cache, T value, const Reduction& op, T defaultVal) return smem_cache[0]; } +template class Reduction, typename T> +__device__ __forceinline__ +T blockReduceWarpInverse(T* smem_cache, T value, const Reduction& op, T defaultVal) +{ + T result = cuda_utils::BlockReduce>(value, op, defaultVal, smem_cache); + if (threadIdx.x == 0) { + smem_cache[0] = 1 / result; + } + __syncthreads(); + return smem_cache[0]; +} + template class Reduction, int ILP, typename T, typename AccumT, typename index_t=int> __device__ __forceinline__ AccumT ilpReduce(index_t shift, @@ -694,6 +733,67 @@ cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes) } } +template class EpilogueWithMul, typename index_t = int32_t> +__global__ void +cunn_SoftMaxForwardGmem(outscalar_t *output, const scalar_t *input, index_t classes) +{ + // Each thread block processes a sample in the batch + input += static_cast(blockIdx.x) * classes; + output += static_cast(blockIdx.x) * classes; + + accscalar_t threadMax = -at::numeric_limits::max(); + accscalar_t threadExp = static_cast(0); + + // The first smem segment is used to cache input values and the last + // segment is used for thread block reductions + extern __shared__ unsigned char smem[]; + auto smem_reduction_cache = reinterpret_cast(smem); + + using LoadT = at::native::memory::aligned_vector; + const LoadT* const input_vec_ptr = reinterpret_cast(input); + + // Do the first step in max calculation: + MaxFloat maxFunc; + for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) { + LoadT crnt_vec = input_vec_ptr[offset]; + #pragma unroll + for (int i = 0; i < ILP; ++i) { + threadMax = maxFunc(threadMax, crnt_vec.val[i]); + } + } + + accscalar_t max_k = blockReduceWarp(smem_reduction_cache, threadMax, + Max(), -at::numeric_limits::max()); + + // Do the second step in sum exp calculation: + SumExpfFloat sumExpFunc(max_k); + for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) { + LoadT crnt_vec = input_vec_ptr[offset]; + #pragma unroll + for (int i = 0; i < ILP; ++i) { + threadExp = sumExpFunc(threadExp, crnt_vec.val[i]); + } + } + + accscalar_t sumAll = blockReduceWarpInverse(smem_reduction_cache, threadExp, + Add(), static_cast(0)); + + EpilogueWithMul epilogue(max_k, sumAll); + + using StoreT = at::native::memory::aligned_vector; + StoreT* output_vec_ptr = reinterpret_cast(output); + for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) { + LoadT crnt_vec = input_vec_ptr[offset]; + StoreT out_vec; + #pragma unroll + for (int i = 0; i < ILP; ++i) { + out_vec.val[i] = epilogue(crnt_vec.val[i]); + } + output_vec_ptr[offset] = out_vec; + } +} + template class Epilogue, typename index_t = int32_t> __global__ void @@ -816,7 +916,8 @@ cunn_SoftMaxBackward(scalar_t *gradInput, const outscalar_t *output, const outsc } } -template class Epilogue, bool is_log_softmax> +template class Epilogue, + template class EpilogueWithMul, bool is_log_softmax, bool use_fast_softmax> Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float, const Tensor& output){ if (half_to_float) { TORCH_CHECK(input_.scalar_type() == ScalarType::Half, "conversion is supported for Half type only"); @@ -858,23 +959,30 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t } } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); - dim3 block = SoftMaxForward_getBlockSize(dim_size); - size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); - auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - - smem_reduction_sz) / sizeof(scalar_t); - - bool can_use_smem = dim_size < max_elements_per_smem; - can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); - can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); - can_use_smem &= !(dim_size % ILP); - - if (can_use_smem) { - size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; - cunn_SoftMaxForwardSmem - <<>>(output_ptr, input_ptr, dim_size); - } else { - cunn_SoftMaxForward + if (use_fast_softmax) { + dim3 block(512); + size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + cunn_SoftMaxForwardGmem <<>>(output_ptr, input_ptr, dim_size); + } else { + dim3 block = SoftMaxForward_getBlockSize(dim_size); + size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - + smem_reduction_sz) / sizeof(scalar_t); + + bool can_use_smem = (size_t) dim_size < max_elements_per_smem; + can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); + can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); + can_use_smem &= !(dim_size % ILP); + + if (can_use_smem) { + size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; + cunn_SoftMaxForwardSmem + <<>>(output_ptr, input_ptr, dim_size); + } else { + cunn_SoftMaxForward + <<>>(output_ptr, input_ptr, dim_size); + } } C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -894,23 +1002,30 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t } } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); - dim3 block = SoftMaxForward_getBlockSize(dim_size); - size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); - auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - - smem_reduction_sz) / sizeof(scalar_t); - - bool can_use_smem = dim_size < max_elements_per_smem; - can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); - can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); - can_use_smem &= !(dim_size % ILP); - - if (can_use_smem) { - size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; - cunn_SoftMaxForwardSmem - <<>>(output_ptr, input_ptr, dim_size); - } else { - cunn_SoftMaxForward + if (use_fast_softmax) { + dim3 block(512); + size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + cunn_SoftMaxForwardGmem <<>>(output_ptr, input_ptr, dim_size); + } else { + dim3 block = SoftMaxForward_getBlockSize(dim_size); + size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - + smem_reduction_sz) / sizeof(scalar_t); + + bool can_use_smem = (size_t) dim_size < max_elements_per_smem; + can_use_smem &= !(reinterpret_cast(input_ptr) % ALIGN_BYTES); + can_use_smem &= (!(reinterpret_cast(output_ptr) % ALIGN_BYTES)); + can_use_smem &= !(dim_size % ILP); + + if (can_use_smem) { + size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz; + cunn_SoftMaxForwardSmem + <<>>(output_ptr, input_ptr, dim_size); + } else { + cunn_SoftMaxForward + <<>>(output_ptr, input_ptr, dim_size); + } } C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -1069,7 +1184,7 @@ TORCH_IMPL_FUNC(log_softmax_cuda_out) ( const int64_t dim, const bool half_to_float, const Tensor &output) { - host_softmax(input, dim, half_to_float, output); + host_softmax(input, dim, half_to_float, output); } TORCH_IMPL_FUNC(log_softmax_backward_cuda_out) ( @@ -1093,7 +1208,11 @@ TORCH_IMPL_FUNC(softmax_cuda_out) ( const int64_t dim, const bool half_to_float, const Tensor &output) { - host_softmax(input, dim, half_to_float, output); +#if defined(USE_ROCM) && defined(PYTORCH_USE_FAST_SOFTMAX) + host_softmax(input, dim, half_to_float, output); +#else + host_softmax(input, dim, half_to_float, output); +#endif } TORCH_IMPL_FUNC(softmax_backward_cuda_out)