[experimental][FP16] Add native __half support for sum_functor
During the cmake step, set the environment variable `PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF=1`.
This enables experimental FP16 support for `sum_functor`:
`operator+` will use `__half` types directly,
instead of applying `static_cast<float>` to its arguments.

Note: only additions are affected by these changes.
mhalk committed Nov 29, 2024
1 parent de3e990 commit 2bfb4fc
Showing 4 changed files with 35 additions and 2 deletions.
aten/src/ATen/native/cuda/ReduceSumProdKernel.cu (4 additions, 0 deletions)
@@ -172,7 +172,11 @@ template <
     typename GeneralDispatcher>
 static void reduce_dispatch(TensorIterator& iter, GeneralDispatcher op) {
   if (iter.dtype() == kHalf) {
+#ifdef PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF
+    return OpFunctor<at::Half, at::Half>{}(iter);
+#else
     return OpFunctor<at::Half, float>{}(iter);
+#endif
   } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
     // type promotion that does cast and reduction in a single kernel
     return OpFunctor<at::Half, float, float>{}(iter);
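In this dispatcher the second template argument of OpFunctor plays the role of the accumulation type, so the gated line switches the kHalf reduction from accumulating in float to accumulating in at::Half itself. A minimal, self-contained sketch of what that parameter controls (hypothetical code, not the PyTorch kernel; reduce_sum and its types are invented for illustration):

#include <cstdio>
#include <vector>

// Hypothetical sketch: acc_t decides which operator+ the reduction loop uses,
// mirroring how OpFunctor<at::Half, at::Half> and OpFunctor<at::Half, float>
// differ only in the accumulator type.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
out_t reduce_sum(const std::vector<scalar_t>& data) {
  acc_t acc = acc_t(0);
  for (const scalar_t& x : data) {
    // acc_t == scalar_t: add in the element type directly (the "native" path).
    // acc_t == float:    widen every element first (the default path).
    acc = acc + static_cast<acc_t>(x);
  }
  return static_cast<out_t>(acc);
}

int main() {
  std::vector<float> v(8, 0.5f);
  std::printf("%f\n", reduce_sum<float>(v));  // prints 4.0
  return 0;
}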
c10/util/Half-inl.h (5 additions, 0 deletions)
@@ -105,7 +105,12 @@ inline __device__ Half __ldg(const Half* ptr) {
 /// Arithmetic
 
 inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) {
+#if (defined(__CUDACC__) || defined(__HIPCC__)) && \
+    defined(PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF)
+  return __half{a} + __half{b};
+#else
   return static_cast<float>(a) + static_cast<float>(b);
+#endif
 }
 
 inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) {
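Since sum_functor chains this operator+ across a whole tensor, adding natively in 16 bits changes the numerics relative to the float fallback, which is presumably why the feature is marked experimental. A small host-side sketch of the effect (hypothetical, not part of the commit; assumes a host compiler with `_Float16` support, e.g. recent GCC or Clang):

#include <cstdio>

int main() {
  _Float16 acc_native = 0;  // accumulate in half, like the PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF path
  float acc_widened = 0.0f; // accumulate in float, like the default static_cast<float> path
  for (int i = 0; i < 4096; ++i) {
    acc_native = acc_native + (_Float16)1;  // stalls at 2048: 2048 + 1 rounds back to 2048 in FP16
    acc_widened = acc_widened + 1.0f;
  }
  std::printf("half accumulator:  %.0f\n", (float)acc_native);  // 2048
  std::printf("float accumulator: %.0f\n", acc_widened);        // 4096
  return 0;
}

The default float accumulator avoids this saturation at the cost of extra conversions; that trade-off is exactly what the new flag exposes.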
cmake/Dependencies.cmake (17 additions, 2 deletions)
@@ -1000,6 +1000,12 @@ if(USE_CUDNN)
   target_include_directories(torch::cudnn INTERFACE ${CUDNN_FRONTEND_INCLUDE_DIR})
 endif()
 
+# Note: This variable also affects CUDA.
+set(PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF
+    $ENV{PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF}
+    CACHE BOOL "Enable native support for half data type within ReduceSum." FORCE)
+
+
 # ---[ HIP
 if(USE_ROCM)
   # This prevents linking in the libtinfo from /opt/conda/lib which conflicts with ROCm libtinfo.
@@ -1042,7 +1048,11 @@ if(USE_ROCM)
   list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
   list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)
   list(APPEND HIP_CXX_FLAGS -DUSE_ROCM)
-  list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1)
+  if(NOT PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF)
+    list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1)
+  else()
+    add_definitions(-DPYTORCH_REDUCESUM_ENABLE_NATIVE_HALF)
+  endif()
   list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1)
   list(APPEND HIP_CXX_FLAGS -DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
   list(APPEND HIP_CXX_FLAGS -Wno-shift-count-negative)
@@ -1369,11 +1379,16 @@ if(NOT INTERN_BUILD_MOBILE)
 
   message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor")
   string(APPEND CMAKE_CUDA_FLAGS " -DCUDA_HAS_FP16=1"
-                                 " -D__CUDA_NO_HALF_OPERATORS__"
                                  " -D__CUDA_NO_HALF_CONVERSIONS__"
                                  " -D__CUDA_NO_HALF2_OPERATORS__"
                                  " -D__CUDA_NO_BFLOAT16_CONVERSIONS__")
 
+  if(NOT PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF)
+    string(APPEND CMAKE_CUDA_FLAGS " -D__CUDA_NO_HALF_OPERATORS__")
+  else()
+    add_definitions(-DPYTORCH_REDUCESUM_ENABLE_NATIVE_HALF)
+  endif()
+
   string(APPEND CMAKE_C_FLAGS_RELEASE " -DNDEBUG")
   string(APPEND CMAKE_CXX_FLAGS_RELEASE " -DNDEBUG")
   if(NOT GENERATOR_IS_MULTI_CONFIG)
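For background: the build normally defines __CUDA_NO_HALF_OPERATORS__ / __HIP_NO_HALF_OPERATORS__ so that the vendor fp16 headers do not declare arithmetic operators for __half, keeping half math on the float-based c10::Half operators; the hunks above keep those defines unless the new flag is set, in which case PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF is exported to the compiler instead. A device-side sketch (hypothetical kernel, not from the commit) of the kind of expression those macros gate:

#include <cuda_fp16.h>

// With __CUDA_NO_HALF_OPERATORS__ (or __HIP_NO_HALF_OPERATORS__ under ROCm)
// defined, the fp16 headers omit operator+ for __half and the addition below
// fails to compile; without the macro it lowers to native half-precision adds.
__global__ void half_add(const __half* a, const __half* b, __half* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = a[i] + b[i];  // requires the built-in __half operator+
  }
}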
setup.py (9 additions, 0 deletions)
@@ -135,6 +135,10 @@
 # USE_ROCM_KERNEL_ASSERT=1
 #   Enable kernel assert in ROCm platform
 #
+# PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF
+#   If set to '1', enables native support for FP16 data types in certain functors.
+#   Note: Currently, this is considered experimental and will only affect reductions.
+#
 # Environment variables we respect (these environment variables are
 # conventional and are often understood/set by other software.)
 #
@@ -676,6 +680,11 @@ def run(self):
         else:
             report("-- Not using ITT")
 
+        if cmake_cache_vars["PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF"]:
+            report("-- Using native FP16 support")
+        else:
+            report("-- Not using native FP16 support")
+
         # Do not use clang to compile extensions if `-fstack-clash-protection` is defined
         # in system CFLAGS
         c_flags = str(os.getenv("CFLAGS", ""))
