Skip to content

Commit

Permalink
Swap CUDA grid dimensions for some partitions
Browse files Browse the repository at this point in the history
  • Loading branch information
sriharshakandala committed Dec 13, 2024
1 parent 340603b commit b50891f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function is_valid_index end
Nv_thread = min(Int(fld(n_max_threads, Nij * Nij)), Nv)
Nv_blocks = cld(Nv, Nv_thread)
@assert prod((Nv_thread, Nij, Nij)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Nij, Nij))),$n_max_threads)"
- return (; threads = (Nv_thread, Nij, Nij), blocks = (Nv_blocks, Nh))
+ return (; threads = (Nv_thread, Nij, Nij), blocks = (Nh, Nv_blocks))
end
@inline function universal_index(::Union{DataLayouts.VIJFH, DataLayouts.VIJHF})
(tv, i, j) = CUDA.threadIdx()
Expand Down Expand Up @@ -152,7 +152,7 @@ end
Nv_thread = min(Int(fld(n_max_threads, Ni)), Nv)
Nv_blocks = cld(Nv, Nv_thread)
@assert prod((Nv_thread, Ni)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Ni))),$n_max_threads)"
- return (; threads = (Nv_thread, Ni), blocks = (Nv_blocks, Nh))
+ return (; threads = (Nv_thread, Ni), blocks = (Nh, Nv_blocks))
end
@inline function universal_index(::Union{DataLayouts.VIFH, DataLayouts.VIHF})
(tv, i) = CUDA.threadIdx()
Expand Down

0 comments on commit b50891f

Please sign in to comment.