diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl index 6ff4967855..01a9099f79 100644 --- a/ext/cuda/data_layouts_threadblock.jl +++ b/ext/cuda/data_layouts_threadblock.jl @@ -55,7 +55,7 @@ function is_valid_index end Nv_thread = min(Int(fld(n_max_threads, Nij * Nij)), Nv) Nv_blocks = cld(Nv, Nv_thread) @assert prod((Nv_thread, Nij, Nij)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Nij, Nij))),$n_max_threads)" - return (; threads = (Nv_thread, Nij, Nij), blocks = (Nv_blocks, Nh)) + return (; threads = (Nv_thread, Nij, Nij), blocks = (Nh, Nv_blocks)) end @inline function universal_index(::Union{DataLayouts.VIJFH, DataLayouts.VIJHF}) (tv, i, j) = CUDA.threadIdx() @@ -152,7 +152,7 @@ end Nv_thread = min(Int(fld(n_max_threads, Ni)), Nv) Nv_blocks = cld(Nv, Nv_thread) @assert prod((Nv_thread, Ni)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Ni))),$n_max_threads)" - return (; threads = (Nv_thread, Ni), blocks = (Nv_blocks, Nh)) + return (; threads = (Nv_thread, Ni), blocks = (Nh, Nv_blocks)) end @inline function universal_index(::Union{DataLayouts.VIFH, DataLayouts.VIHF}) (tv, i) = CUDA.threadIdx()