[RFC] CUTLASS CollectiveMMA for cublasDx #233

Open
osayamenja opened this issue Nov 7, 2024 · 1 comment

osayamenja commented Nov 7, 2024

Hello!

What was the motivation for using cooperative_gemm for the internals of cublasDx? I ask because, for the reasons below, CollectiveMMA seems to be a much more compelling option for general-purpose, large GEMMs, which I believe to be the goal of cublasDx.

TL;DR

✅ = winner for that category (✅ in both columns = tie)
🥹 = loser for that category
c_gemm = cooperative_gemm

| Metric | CollectiveMMA | cublasDx (c_gemm) |
| --- | --- | --- |
| Performance | ✅ | 🥹 |
| Shared Memory Demand | ✅ | 🥹 |
| Register Pressure | ✅ | ✅ |
| Ease-of-use | ✅ | ✅ |

Comparison

All experiments were run on an A100 80GB GPU.

Performance At Scale

M = 4096, N = 4096, K = 64 (small K due to cublasDx limits), precision = (tf32, tf32, float), bM = 128, bN = 64, bK = 64.

  • multiblock_gemm: 1.1510 ms
  • CUTLASS sgemm_80 with TiledMMA and bK = 64: 0.1836 ms
  • CUTLASS with TiledMMA and bK = 8: 0.094 ms

Diffs

Here is my diff for the CUTLASS code, in case you want to replicate these results.

```diff
// inside gemm_tn
 auto bM = Int<128>{};
-  auto bN = Int<128>{};
+  auto bN = Int<64>{};
 auto bK = Int<  8>{};
 auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
-  auto bP = Int<3>{};  // Pipeline
+  auto bP = Int<1>{};  // Pipeline

 // Define the smem layouts (static)
 auto sA_atom                  = make_layout(make_shape (      bM,          bK),
                                             make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
 [[maybe_unused]] auto sB_atom = make_layout(make_shape (      bN,          bK),
                                             make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
-  auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP));
-  auto sB = tile_to_shape(sA_atom, make_shape(bN, bK, bP));
+  auto sA = make_layout(make_shape(bM, bK, bP), LayoutRight{});
+  auto sB = make_layout(make_shape(bN, bK, bP), LayoutRight{});
 auto sC = make_layout(make_shape(bM, bN));                        // (m,n) -> smem_idx

 // Define the thread layouts (static)

 TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<TA>, TA>{},
-                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
+                                    Layout<Shape<_16,_8>,Stride<_8,_1>>{}, // Thr layout 16x8 k-major
                                   Layout<Shape< _1,_1>>{});              // Val layout  1x1
 TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<TB>, TB>{},
-                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
+                                    Layout<Shape<_16,_8>,Stride<_8,_1>>{}, // Thr layout 16x8 k-major
                                   Layout<Shape< _1,_1>>{});              // Val layout  1x1

-  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
-                                 Layout<Shape<_16,_16,_1>>{});  // 16x16x1 TiledMMA
+  TiledMMA mmaC = TiledMMA<
+      MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
+      Layout<Shape<_2, _2, _1>>,
+      Tile<_32, _32, _8>
+  >{};
```

Below is the diff for cublasDx. I had to make changes to enable different types.

Diff to support tf32, tf32, float:

```diff
diff --git a/MathDx/cuBLASDx/multiblock_gemm.cu b/MathDx/cuBLASDx/multiblock_gemm.cu
index ad13ce1..0c2e0ea 100644
--- a/MathDx/cuBLASDx/multiblock_gemm.cu
+++ b/MathDx/cuBLASDx/multiblock_gemm.cu
@@ -36,7 +36,7 @@ constexpr auto get_stride_from_arrangement() {
     }
 }
 
-template <int GridX, int GridY, class BlockMM, class MatrixA, class MatrixB, class MatrixC, class ValueType = typename example::uniform_value_type_t<BlockMM>>
+template <int GridX, int GridY, class BlockMM, class MatrixA, class MatrixB, class MatrixC, class ValueType>
 __launch_bounds__(BlockMM::max_threads_per_block) __global__
 void block_mm_kernel(MatrixA gA, MatrixB gB, MatrixC gC, ValueType alpha, ValueType beta) {
     using value_type = ValueType;
@@ -102,7 +102,10 @@ int benchmark_multiblock_gemm(const cudaStream_t& stream, bool verbose = false)
 
     constexpr bool set_block_size{BlockSize > 0};
     using block_mm_type = std::conditional_t<set_block_size, decltype(BlockMM() + BlockDim<BlockSize>()), BlockMM>;
-    using value_type = typename example::uniform_value_type_t<block_mm_type>;
+    //using value_type = typename example::uniform_value_type_t<block_mm_type>;
+    using typeA = typename BlockMM::a_value_type;
+    using typeB = typename BlockMM::b_value_type;
+    using typeC = typename BlockMM::c_value_type;
 
     constexpr auto M = cublasdx::size_of<GlobalMM>::m;
     constexpr auto N = cublasdx::size_of<GlobalMM>::n;
@@ -112,26 +115,27 @@ int benchmark_multiblock_gemm(const cudaStream_t& stream, bool verbose = false)
     constexpr auto c_size = GlobalMM::c_size;
 
     // Allocate device memory for A, B, C.
-    value_type* inputs;
+    cuda::std::byte* inputs;
 
-    auto inputs_size       = a_size + b_size + c_size;
-    auto inputs_size_bytes = inputs_size * sizeof(value_type);
+    //auto inputs_size       = a_size + b_size + c_size;
+    constexpr auto abSize = a_size * sizeof(typeA) + b_size * sizeof(typeB);
+    constexpr auto inputs_size_bytes = abSize + c_size*sizeof(typeC);
     CUDA_CHECK_AND_EXIT(cudaMalloc(&inputs, inputs_size_bytes));
 
-    value_type* a     = inputs;
-    value_type* b     = a + a_size;
-    value_type* c     = b + b_size;
+    auto* a     = static_cast<typeA*>(static_cast<void*>(inputs));
+    auto* b     = static_cast<typeB*>(static_cast<void*>(inputs + a_size * sizeof(typeA)));
+    auto* c     = static_cast<typeC*>(static_cast<void*>(inputs + abSize));
 
-    value_type  alpha = example::make_value<value_type>(1.f);
-    value_type  beta  = example::make_value<value_type>(0.f);
+    auto  alpha = example::make_value<typeC>(1.f);
+    auto  beta  = example::make_value<typeC>(0.f);
 
     // Fill the A, B, C matrices with random values.
-    auto host_a = example::get_random_data<value_type>(0.1, 1.0, a_size);
-    auto host_b = example::get_random_data<value_type>(0.1, 1.0, b_size);
-    auto host_c = example::get_random_data<value_type>(0.1, 1.0, c_size);
-    CUDA_CHECK_AND_EXIT(cudaMemcpy(a, host_a.data(), a_size * sizeof(value_type), cudaMemcpyHostToDevice));
-    CUDA_CHECK_AND_EXIT(cudaMemcpy(b, host_b.data(), b_size * sizeof(value_type), cudaMemcpyHostToDevice));
-    CUDA_CHECK_AND_EXIT(cudaMemcpy(c, host_c.data(), c_size * sizeof(value_type), cudaMemcpyHostToDevice));
+    auto host_a = example::get_random_data<typeA>(0.1, 1.0, a_size);
+    auto host_b = example::get_random_data<typeB>(0.1, 1.0, b_size);
+    auto host_c = example::get_random_data<typeC>(0.1, 1.0, c_size);
+    CUDA_CHECK_AND_EXIT(cudaMemcpy(a, host_a.data(), a_size * sizeof(typeA), cudaMemcpyHostToDevice));
+    CUDA_CHECK_AND_EXIT(cudaMemcpy(b, host_b.data(), b_size * sizeof(typeB), cudaMemcpyHostToDevice));
+    CUDA_CHECK_AND_EXIT(cudaMemcpy(c, host_c.data(), c_size * sizeof(typeC), cudaMemcpyHostToDevice));
     CUDA_CHECK_AND_EXIT(cudaDeviceSynchronize());
 
     constexpr auto a_arrangement = cublasdx::arrangement_of<BlockMM>::a;
@@ -149,7 +153,7 @@ int benchmark_multiblock_gemm(const cudaStream_t& stream, bool verbose = false)
     constexpr dim3 grid{M / cublasdx::size_of<block_mm_type>::m, N / cublasdx::size_of<block_mm_type>::n};
 
     // Increase max dynamic shared memory for the kernel if needed.
-    auto kernel = block_mm_kernel<grid.x, grid.y, block_mm_type, a_type, b_type, c_type>;
+    auto kernel = block_mm_kernel<grid.x, grid.y, block_mm_type, a_type, b_type, c_type, typeC>;
 
     CUDA_CHECK_AND_EXIT(
         cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, block_mm_type::shared_memory_size));
@@ -169,12 +173,12 @@ int benchmark_multiblock_gemm(const cudaStream_t& stream, bool verbose = false)
 
     double avg_time = time / kernel_repeats;
 
-    double gflops = example::gemm_flops<value_type>(M, N, K) / avg_time / 1000000.;
+    double gflops = example::gemm_flops<typeC>(M, N, K) / avg_time / 1000000.;
 
     // Copy results back to host.
-    std::vector<value_type> host_output(c_size);
+    std::vector<typeC> host_output(c_size);
     CUDA_CHECK_AND_EXIT(
-        cudaMemcpy(host_output.data(), c, c_size * sizeof(value_type), cudaMemcpyDeviceToHost));
+        cudaMemcpy(host_output.data(), c, c_size * sizeof(typeC), cudaMemcpyDeviceToHost));
     CUDA_CHECK_AND_EXIT(cudaDeviceSynchronize());
 
     // Free device memory.
@@ -185,8 +189,8 @@ int benchmark_multiblock_gemm(const cudaStream_t& stream, bool verbose = false)
                   << std::endl;
         std::cout << "Block m, n, k: " << size_of<block_mm_type>::m << ", " << size_of<block_mm_type>::n << ", " << size_of<block_mm_type>::k
                   << std::endl;
-        std::cout << "Type: " << example::type_string<value_type>() << std::endl;
-        std::cout << "Precision: " << example::precision_string<value_type>() << std::endl;
+        std::cout << "Type C: " << example::type_string<typeC>() << std::endl;
+        std::cout << "Precision C: " << example::precision_string<typeC>() << std::endl;
         std::cout << "Block size: " << block_mm_type::block_dim.x << std::endl;
         std::cout << "Grid dimensions: " << grid.x << ", " << grid.y << ", " << grid.z << std::endl;
         std::cout << "Shared memory: " << block_mm_type::shared_memory_size << std::endl;
@@ -196,7 +200,7 @@ int benchmark_multiblock_gemm(const cudaStream_t& stream, bool verbose = false)
         std::cout << "Performance [GFLOPS]: " << gflops << std::endl;
     } else {
         std::cout << "(" << size_of<GlobalMM>::m << ", " << size_of<GlobalMM>::n << ", " << size_of<GlobalMM>::k << ") "
-                  << example::precision_string<value_type>() << " precision " <<  example::type_string<value_type>()
+                  << example::precision_string<typeC>() << " precision " <<  example::type_string<typeC>()
                   << ": " << std::fixed << std::setprecision(4) << gflops << " GFLOPS, " << avg_time << " ms." << std::endl;
     }
     // Calculate reference solution.
@@ -233,12 +237,12 @@ int multiblock_gemm() {
     // Parameters M, N, K define the dimensions of matrices A, B, and C. In this example, we choose K to be small enough
     // so that a local block fits into shared memory since we don't partition A and B along the k-mode. Such "tall and
     // skinny" matrices can represent the query, key, or value matrices in an attention operation, as an example.
-    constexpr unsigned int M = 2048;
-    constexpr unsigned int N = 2048;
+    constexpr unsigned int M = 4096;
+    constexpr unsigned int N = 4096;
     constexpr unsigned int K = 64;
 
     // Parameters m, n, k define the dimensions of the local blocks of matrices A, B, and C.
-    constexpr unsigned int m = 64;
+    constexpr unsigned int m = 128;
     constexpr unsigned int n = 64;
     constexpr unsigned int k = K;
 
@@ -248,29 +252,34 @@ int multiblock_gemm() {
     constexpr unsigned int BlockSize = 0;
 
     // Choose precision (__half, float, double) and type (real or complex).
-    using precision = __half;
-    constexpr auto type = cublasdx::type::complex;
+    using precision = float;
+    constexpr auto type = cublasdx::type::real;
 
     // Choose arrangement for A and B: row-major or column-major.
     // Conjugation or other element-wise transformations can be passed to the execute function.
     constexpr auto a_arrangement = cublasdx::col_major;
     constexpr auto b_arrangement = cublasdx::row_major;
 
+    using aValueType = tfloat32_t;
+    using bValueType = aValueType;
+    using outValueType = float;
+
     // Define the local matrix multiplication operation.
     using BlockMM  = decltype(
         cublasdx::Size<m, n, k>() +
-        cublasdx::Precision<precision>() +
+        cublasdx::Precision<aValueType, bValueType, outValueType>() +
         cublasdx::Type<type>() +
         cublasdx::Function<cublasdx::function::MM>() +
         cublasdx::Arrangement<a_arrangement, b_arrangement>() +
         cublasdx::Block() +
+        cublasdx::BlockDim<128>() +
         cublasdx::SM<Arch>()
     );
     // The global matrix multiplication operation provides a convenient way to encapsulate data of interest: (M, N, K),
     // matrix sizes, leading dimensions, etc. It cannot be executed since the problem is too large to fit into shared memory.
     using GlobalMM = decltype(
         cublasdx::Size<M, N, K>() +
-        cublasdx::Precision<precision>() +
+        cublasdx::Precision<aValueType, bValueType, outValueType>() +
         cublasdx::Type<type>() +
         cublasdx::Function<cublasdx::function::MM>() +
        cublasdx::Arrangement<a_arrangement, b_arrangement>() +
```

Results

Here is the result summary from running multiblock_gemm.

cublasDx result summary:

```
Global M, N, K: 4096, 4096, 64
Block m, n, k: 128, 64, 64
Type C: real
Precision C: float
Block size: 128
Grid dimensions: 32, 64, 1
Shared memory: 81920
Avg time [ms]: 1.1510
Time (all) [ms]: 5.7549
Performance [GFLOPS]: 1865.7936
The results are verified to be correct.
```

The CUTLASS result summary is below:

```
M = 4096
N = 4096
K = 64
C = A^T B^N
CUTE_GEMM:     [22875.9]GFlop/s  (0.0939)ms
```
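
As a quick sanity check (my own arithmetic, not output from either library), both throughput figures are consistent with the usual $2MNK$ FLOP count:

$$
\frac{2 \cdot 4096 \cdot 4096 \cdot 64}{1.1510\ \text{ms}} \approx 1.87\ \text{TFLOP/s},
\qquad
\frac{2 \cdot 4096 \cdot 4096 \cdot 64}{0.0939\ \text{ms}} \approx 22.9\ \text{TFLOP/s},
$$

which matches the 1865.8 and 22875.9 GFLOPS reported above.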

Prologue

It is clear that CollectiveMMA is an order of magnitude faster here (bK = 8). Admittedly, this is not a perfectly apples-to-apples comparison: cublasDx has invocation overhead, and the underlying cooperative_gemm does extra (albeit unrequested) work, namely a register-to-shared copy of C. So the question becomes: do these alone account for the severe performance gap?

I do not think so, and I elaborate on why in the following points. Note that CollectiveMMA exposes a much richer space of architecture-specific parameters: gmem tiled-copy layouts, pipelining via register double-buffering (Volta) or shared-memory staging (Ampere and above), and much more sophisticated configurations for Hopper.

Ease-of-use

Given the vast parameter space for configuring CollectiveMMA templates, it is certainly harder to get right than the simpler cooperative_gemm API. For example, CollectiveMMA demands that you specify the CopyAtom and GmemTiledCopy, while cooperative_gemm doesn't.

However, in my opinion, this burden falls on the lay user, not on the experienced engineers developing cublasDx: they have already tackled this complexity through their suggest_* APIs, which translate a high-level GEMM description into optimal, architecture-specific template parameters.

Thus, neither implementation seems to be more complex for the expert CUDA developer.
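
To make the difference in API surface concrete, below is a minimal sketch (my own illustration, assuming CUTLASS/CuTe headers and shared-memory tiles that the caller has already staged) of the cooperative_gemm side: a single call that takes a TiledMMA, alpha/beta, and the smem tensors, with the copy strategy chosen internally. The CollectiveMma primary template, by contrast, additionally asks for a dispatch policy, tile shape, strides, and a GmemTiledCopy, SmemLayoutAtom, and SmemCopyAtom per operand, which is exactly the complexity the suggest_* APIs hide.

```cpp
// Sketch only: the simpler cooperative_gemm call surface (CuTe), assuming sA (bM,bK),
// sB (bN,bK), and sC (bM,bN) are shared-memory tensors already filled by the caller.
#include <cute/tensor.hpp>
#include <cute/algorithm/cooperative_gemm.hpp>
using namespace cute;

template <class TiledMma, class SmemA, class SmemB, class SmemC>
__device__ void block_gemm(TiledMma const& mma,
                           SmemA const& sA,   // (bM, bK) tile in shared memory
                           SmemB const& sB,   // (bN, bK) tile in shared memory
                           SmemC      & sC,   // (bM, bN) tile in shared memory
                           float alpha, float beta) {
  // Partitioning, the MMA mainloop, and the axpby epilogue all happen inside this call.
  cooperative_gemm(threadIdx.x, mma, alpha, sA, sB, beta, sC);
  __syncthreads();
}
```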

Shared Memory Demand

This is where CollectiveMMA really shines.

For example, take block dimensions bM = 128, bN = 128, bK = 8, single precision (float), and a blockDim of 128 threads.

With these parameters, CollectiveMMA with no shared-memory pipelining demands only $128 \cdot 8 \cdot 4 \cdot 2 = 8\text{K}$ of shared memory per block to store the A and B tiles. This amount scales with the number of pipeline stages, so at 2 stages (the minimum for Ampere) shared memory goes up to 16K.

cublasDx, on the other hand, additionally stores the C matrix in shared memory and thus requires $8\text{K} + 128 \cdot 128 \cdot 4 = 72\text{K}$, a 9x increase.
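
For concreteness, here is the arithmetic as a tiny standalone program (my own check of the numbers above, not code from either library):

```cpp
// Shared memory per block for the bM = 128, bN = 128, bK = 8, fp32 example.
#include <cstdio>

int main() {
    constexpr int bM = 128, bN = 128, bK = 8;
    constexpr int bytes  = 4;   // sizeof(float)
    constexpr int stages = 1;   // smem pipeline depth; CollectiveMMA's A/B demand scales with this

    constexpr int smem_ab = (bM * bK + bN * bK) * bytes * stages;  // A and B tiles only
    constexpr int smem_c  = bM * bN * bytes;                       // C tile that cublasDx also keeps in smem

    std::printf("A+B only (CollectiveMMA): %d bytes = %d KiB\n", smem_ab, smem_ab / 1024);
    std::printf("A+B+C    (cublasDx):      %d bytes = %d KiB\n", smem_ab + smem_c, (smem_ab + smem_c) / 1024);
    return 0;
}
```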

This may not seem like a lot, but the implications for block scheduling and occupancy are worth considering.

Continuing with this example on Volta, where the maximum shared memory per SM is 96K, only one cublasDx block can be scheduled per SM.

With CollectiveMMA, on the other hand, you can fit 5 blocks per Volta SM. Note that you cannot fit 96/8 = 12 blocks because of register pressure, which I explain next.

But the point is clear: more active blocks per SM translates to better performance.

I believe this is the major cause of the performance gap between cublasDx and CollectiveMMA, as both ultimately invoke cute::gemm over largely identical underlying code.

K Limitations

More importantly, this shared memory demand currently constrains cublasDx to matrix sizes where k <= 196.

Any larger k requires the user to write their own k-reduction logic, which defeats the purpose of using a library like cublasDx; more likely than not, writing high-performance GEMM CUDA code is a futile exercise for non-experts.

Registers

Both CollectiveMMA and cooperative_gemm require the same number of registers, namely those holding the A, B, and C operands of the same underlying gemm call: in CuTe's notation, tCrA, tCrB, and tCrC. For the CollectiveMMA configuration used in the performance comparison above, the register demand (size(tCrA) + size(tCrB) + size(tCrC)) is about 96 per thread (about 128 to avoid spill loads), which is why we can only fit 5 blocks per Volta SM.
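
These fragment sizes can be checked directly with CuTe. Below is a small device-side sketch of my own (it assumes the TiledMMA and the (bM,bK)/(bN,bK)/(bM,bN) shared-memory tiles from the diffs above are in scope); it partitions the register fragments and prints their per-thread element counts.

```cpp
#include <cstdio>
#include <cute/tensor.hpp>
using namespace cute;

// Reports size(tCrA) + size(tCrB) + size(tCrC): the per-thread register fragments that
// both CollectiveMMA and cooperative_gemm must hold for the same underlying cute::gemm.
template <class TiledMma, class SmemA, class SmemB, class SmemC>
__device__ void report_fragment_footprint(TiledMma const& mma,
                                          SmemA const& sA, SmemB const& sB, SmemC const& sC) {
  auto thr_mma = mma.get_slice(threadIdx.x);
  auto tCrA = thr_mma.partition_fragment_A(sA);   // register fragment of A
  auto tCrB = thr_mma.partition_fragment_B(sB);   // register fragment of B
  auto tCrC = thr_mma.partition_fragment_C(sC);   // register accumulators for C
  if (threadIdx.x == 0) {
    printf("elements per thread: A=%d B=%d C=%d total=%d\n",
           int(size(tCrA)), int(size(tCrB)), int(size(tCrC)),
           int(size(tCrA) + size(tCrB) + size(tCrC)));
  }
}
```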

Key Difference

The key difference is that after the GEMM, CollectiveMMA leaves the C matrix in registers, which allows more efficient epilogue operations since the round trip to and from shared memory is bypassed. cooperative_gemm, on the other hand, performs a single axpby and then stores the result in shared memory, so cublasDx can support only one element-wise epilogue in registers, while any subsequent epilogues must be done in shared memory by the user.

For further motivation, note from here that the shared memory for the C matrix is actually used only for the epilogue axpby and not for the core gemm. So that shared-memory allocation can be eliminated entirely in favor of operating directly on registers, or, if needed, reused for A and B, since the gemm is complete by the time the epilogue begins.
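
As a sketch of what this enables (my own illustration in the style of the CUTLASS sgemm examples; tCrC and tCgC stand for the usual thread-partitioned register accumulators and global-memory C tile), an additional element-wise epilogue can be fused in registers before a single axpby straight to global memory:

```cpp
#include <cute/tensor.hpp>
#include <cute/algorithm/axpby.hpp>
using namespace cute;

// Epilogue operating directly on the register accumulators left by the mainloop,
// writing once to global memory and never round-tripping through shared memory.
template <class Alpha, class FrgC, class Beta, class GmemC>
__device__ void epilogue_in_registers(Alpha alpha, FrgC& tCrC, Beta beta, GmemC& tCgC) {
  CUTE_UNROLL
  for (int i = 0; i < size(tCrC); ++i) {
    auto v = tCrC(i);
    tCrC(i) = (v > decltype(v)(0)) ? v : decltype(v)(0);   // e.g. an extra ReLU, still in registers
  }
  axpby(alpha, tCrC, beta, tCgC);   // tCgC = alpha * tCrC + beta * tCgC
}
```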

Next Steps

To reiterate the initial point: if cublasDx aims to be a library for executing high-performance, general-purpose, large GEMMs, then CollectiveMMA seems to be the way to go; otherwise, the existing setup is sufficient. We should also keep in mind that cublasDx is still in early access, so there may already be ongoing efforts revisiting this design choice.

Finally, I am aware that cublasDx is not necessarily open source, but as a CUDA geek :) I am more than happy to contribute to making this happen if the cublasDx team decides to go in that direction.

llukas (Contributor) commented Nov 18, 2024

Hi, we're working on optimizations that should help address the issues you raised. Feel free to drop me an e-mail at [email protected]
