[ROCm] Improve softmax performance. #1740

Open
wants to merge 1 commit into base: release/2.4
189 changes: 154 additions & 35 deletions aten/src/ATen/native/cuda/SoftMax.cu
@@ -89,6 +89,20 @@ struct SoftMaxBackwardEpilogue {
const AccumT sum;
};

template<typename T, typename AccumT, typename OutT>
struct SoftMaxForwardWithMulEpilogue {
__device__ __forceinline__ SoftMaxForwardWithMulEpilogue(AccumT max_input, AccumT sum)
: max_input(max_input)
, sum(sum) {}

__device__ __forceinline__ OutT operator()(T input) const {
return static_cast<OutT>(__expf(input - max_input) * sum);
}

const AccumT max_input;
const AccumT sum;
};
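SoftMaxForwardWithMulEpilogue pairs with blockReduceWarpInverse further down: its sum argument is expected to carry the reciprocal of the exponential sum, so the per-element division of the standard forward epilogue becomes a multiply, and the exponential itself uses the fast __expf intrinsic. A minimal scalar sketch of the identity involved, with hypothetical names (illustration only, not part of the diff):

#include <cmath>

// Equivalent results; the second form performs the division once per row
// instead of once per element.
float softmax_elem_div(float x, float row_max, float sum_exp) {
  return std::exp(x - row_max) / sum_exp;       // one division per element
}
float softmax_elem_mul(float x, float row_max, float inv_sum_exp) {
  return std::exp(x - row_max) * inv_sum_exp;   // inv_sum_exp == 1.f / sum_exp
}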




@@ -387,6 +401,19 @@ struct SumExpFloat
const AccumT max_k;
};

template<typename T, typename AccumT>
struct SumExpfFloat
{
__device__ __forceinline__ SumExpfFloat(AccumT v)
: max_k(v) {}

__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + __expf(v - max_k);
}

const AccumT max_k;
};
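SumExpfFloat has the same accumulator shape as the existing SumExpFloat, but it calls the __expf intrinsic, the fast, lower-accuracy device exponential available on both CUDA and HIP, instead of the standard device expf. A hypothetical micro-kernel (not part of the PR) showing the two calls side by side:

__global__ void exp_compare(const float* in, float* out_std, float* out_fast, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out_std[i]  = expf(in[i]);    // standard device exponential
    out_fast[i] = __expf(in[i]);  // fast, less accurate intrinsic used by SumExpfFloat
  }
}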

template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
@@ -449,6 +476,18 @@ T blockReduceWarp(T* smem_cache, T value, const Reduction<T>& op, T defaultVal)
return smem_cache[0];
}

template <template<typename> class Reduction, typename T>
__device__ __forceinline__
T blockReduceWarpInverse(T* smem_cache, T value, const Reduction<T>& op, T defaultVal)
{
T result = cuda_utils::BlockReduce<T, Reduction<T>>(value, op, defaultVal, smem_cache);
if (threadIdx.x == 0) {
smem_cache[0] = 1 / result;
}
__syncthreads();
return smem_cache[0];
}
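blockReduceWarpInverse differs from blockReduceWarp only in what thread 0 publishes: it stores the reciprocal of the reduced value, so after the barrier every thread in the block reads 1 / sum from shared memory and downstream code can multiply rather than divide. A standalone sketch (illustrative assumption, hypothetical kernel) of that publish-and-broadcast step:

__global__ void invert_and_broadcast(float* out, float value) {
  // out must hold at least blockDim.x floats.
  __shared__ float cache;
  if (threadIdx.x == 0) {
    cache = 1.0f / value;    // a single division for the whole block
  }
  __syncthreads();           // make the reciprocal visible to every thread
  out[threadIdx.x] = cache;  // each thread consumes the broadcast value
}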

template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT, typename index_t=int>
__device__ __forceinline__ AccumT
ilpReduce(index_t shift,
@@ -694,6 +733,67 @@ cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes)
}
}

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
template <typename, typename, typename> class EpilogueWithMul, typename index_t = int32_t>
__global__ void
cunn_SoftMaxForwardGmem(outscalar_t *output, const scalar_t *input, index_t classes)
{
// Each thread block processes a sample in the batch
input += static_cast<int64_t>(blockIdx.x) * classes;
output += static_cast<int64_t>(blockIdx.x) * classes;

accscalar_t threadMax = -at::numeric_limits<accscalar_t>::max();
accscalar_t threadExp = static_cast<accscalar_t>(0);

// Shared memory here holds only the thread block reduction cache; unlike the
// smem variant, the inputs are re-read from global memory on each pass.
extern __shared__ unsigned char smem[];
auto smem_reduction_cache = reinterpret_cast<accscalar_t*>(smem);

using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>;
const LoadT* const input_vec_ptr = reinterpret_cast<const LoadT*>(input);

// Do the first step in max calculation:
MaxFloat<scalar_t, accscalar_t> maxFunc;
for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
LoadT crnt_vec = input_vec_ptr[offset];
#pragma unroll
for (int i = 0; i < ILP; ++i) {
threadMax = maxFunc(threadMax, crnt_vec.val[i]);
}
}

accscalar_t max_k = blockReduceWarp<Max, accscalar_t>(smem_reduction_cache, threadMax,
Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());

// Do the second step in sum exp calculation:
SumExpfFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
LoadT crnt_vec = input_vec_ptr[offset];
#pragma unroll
for (int i = 0; i < ILP; ++i) {
threadExp = sumExpFunc(threadExp, crnt_vec.val[i]);
}
}

// blockReduceWarpInverse returns 1 / sum, so the epilogue below multiplies by
// the reciprocal instead of dividing every element by the sum.
accscalar_t sumAll = blockReduceWarpInverse<Add, accscalar_t>(smem_reduction_cache, threadExp,
Add<accscalar_t>(), static_cast<accscalar_t>(0));

EpilogueWithMul<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>;
StoreT* output_vec_ptr = reinterpret_cast<StoreT*>(output);
for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
LoadT crnt_vec = input_vec_ptr[offset];
StoreT out_vec;
#pragma unroll
for (int i = 0; i < ILP; ++i) {
out_vec.val[i] = epilogue(crnt_vec.val[i]);
}
output_vec_ptr[offset] = out_vec;
}
}
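Taken together, cunn_SoftMaxForwardGmem makes three passes over each row in global memory: a max reduction, a sum of shifted exponentials, and a vectorized write of exp(x - max) scaled by the reciprocal of that sum. A rough sequential reference of the same computation (written for illustration only, not part of the PR):

#include <algorithm>
#include <cmath>
#include <vector>

// Reference softmax for one non-empty row, mirroring the kernel's three passes.
std::vector<float> softmax_row_reference(const std::vector<float>& row) {
  // Pass 1: row maximum (MaxFloat + blockReduceWarp<Max>).
  const float max_k = *std::max_element(row.begin(), row.end());
  // Pass 2: sum of shifted exponentials (SumExpfFloat + the Add reduction).
  float sum = 0.f;
  for (float v : row) sum += std::exp(v - max_k);
  // Pass 3: scale by the reciprocal (SoftMaxForwardWithMulEpilogue).
  const float inv_sum = 1.f / sum;
  std::vector<float> out(row.size());
  for (size_t i = 0; i < row.size(); ++i) {
    out[i] = std::exp(row[i] - max_k) * inv_sum;
  }
  return out;
}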

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
__global__ void
@@ -816,7 +916,8 @@ cunn_SoftMaxBackward(scalar_t *gradInput, const outscalar_t *output, const outsc
}
}

template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
template<template<typename, typename, typename> class Epilogue,
template<typename, typename, typename> class EpilogueWithMul, bool is_log_softmax, bool use_fast_softmax>
Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float, const Tensor& output){
if (half_to_float) {
TORCH_CHECK(input_.scalar_type() == ScalarType::Half, "conversion is supported for Half type only");
@@ -858,23 +959,30 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
}
} else {
constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
dim3 block = SoftMaxForward_getBlockSize(dim_size);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
smem_reduction_sz) / sizeof(scalar_t);

bool can_use_smem = dim_size < max_elements_per_smem;
can_use_smem &= !(reinterpret_cast<const uintptr_t>(input_ptr) % ALIGN_BYTES);
can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
can_use_smem &= !(dim_size % ILP);

if (can_use_smem) {
size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
<<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
if (use_fast_softmax) {
dim3 block(512);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, scalar_t, EpilogueWithMul>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
dim3 block = SoftMaxForward_getBlockSize(dim_size);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
smem_reduction_sz) / sizeof(scalar_t);

bool can_use_smem = (size_t) dim_size < max_elements_per_smem;
can_use_smem &= !(reinterpret_cast<uintptr_t>(input_ptr) % ALIGN_BYTES);
can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
can_use_smem &= !(dim_size % ILP);

if (can_use_smem) {
size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
<<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
}
}

C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -894,23 +1002,30 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
}
} else {
constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
dim3 block = SoftMaxForward_getBlockSize(dim_size);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
smem_reduction_sz) / sizeof(scalar_t);

bool can_use_smem = dim_size < max_elements_per_smem;
can_use_smem &= !(reinterpret_cast<const uintptr_t>(input_ptr) % ALIGN_BYTES);
can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
can_use_smem &= !(dim_size % ILP);

if (can_use_smem) {
size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
if (use_fast_softmax) {
dim3 block(512);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, accscalar_t, EpilogueWithMul>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
dim3 block = SoftMaxForward_getBlockSize(dim_size);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
smem_reduction_sz) / sizeof(scalar_t);

bool can_use_smem = (size_t) dim_size < max_elements_per_smem;
can_use_smem &= !(reinterpret_cast<uintptr_t>(input_ptr) % ALIGN_BYTES);
can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
can_use_smem &= !(dim_size % ILP);

if (can_use_smem) {
size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
}
}

C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1069,7 +1184,7 @@ TORCH_IMPL_FUNC(log_softmax_cuda_out) (
const int64_t dim,
const bool half_to_float,
const Tensor &output) {
host_softmax<LogSoftMaxForwardEpilogue,true>(input, dim, half_to_float, output);
host_softmax<LogSoftMaxForwardEpilogue, LogSoftMaxForwardEpilogue, true, false>(input, dim, half_to_float, output);
}

TORCH_IMPL_FUNC(log_softmax_backward_cuda_out) (
@@ -1093,7 +1208,11 @@ TORCH_IMPL_FUNC(softmax_cuda_out) (
const int64_t dim,
const bool half_to_float,
const Tensor &output) {
host_softmax<SoftMaxForwardEpilogue,false>(input, dim, half_to_float, output);
#if defined(USE_ROCM) && defined(PYTORCH_USE_FAST_SOFTMAX)
host_softmax<SoftMaxForwardEpilogue, SoftMaxForwardWithMulEpilogue, false, true>(input, dim, half_to_float, output);
#else
host_softmax<SoftMaxForwardEpilogue, SoftMaxForwardWithMulEpilogue, false, false>(input, dim, half_to_float, output);
#endif
}
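The log_softmax entry point above keeps use_fast_softmax=false on every build, while the softmax entry point enables the new kernel only when both USE_ROCM and PYTORCH_USE_FAST_SOFTMAX are defined at compile time; how the macro gets defined (for example via an extra -D compiler flag) is a build-system detail and an assumption here. A standalone sketch of the same opt-in switch:

#include <cstdio>

#if defined(USE_ROCM) && defined(PYTORCH_USE_FAST_SOFTMAX)
constexpr bool kUseFastSoftmax = true;   // dispatch routes to cunn_SoftMaxForwardGmem
#else
constexpr bool kUseFastSoftmax = false;  // existing smem/gmem kernel selection is kept
#endif

int main() {
  std::printf("fast softmax path enabled: %s\n", kUseFastSoftmax ? "yes" : "no");
  return 0;
}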

TORCH_IMPL_FUNC(softmax_backward_cuda_out)