diff --git a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu index 4e9e757ff..380d69130 100644 --- a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu +++ b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu @@ -852,8 +852,6 @@ __global__ void Marlin_24( } } -#endif - #define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ THREAD_K_BLOCKS, GROUP_BLOCKS) \ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ @@ -1119,6 +1117,8 @@ torch::Tensor marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, return c; } +#endif + TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::marlin_24_gemm", &marlin_24_gemm); }