Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Broadcast.flatten on master #1782

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module ClimaCoreCUDAExt

import NVTX
import ClimaComms
import ClimaCore: broadcast_flatten
import ClimaCore: DataLayouts, Grids, Spaces, Fields
import ClimaCore: Geometry
import ClimaCore.Geometry: AxisTensor
Expand Down
12 changes: 8 additions & 4 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ end

function Base.copyto!(
dest::IJFH{S, Nij},
bc::Union{IJFH{S, Nij, A}, Base.Broadcast.Broadcasted{IJFHStyle{Nij, A}}},
bc::Union{IJFH{S, Nij, A}, Base.Broadcast.Broadcasted{IJFHStyle{Nij, A}}},
) where {S, Nij, A <: CUDA.CuArray}
bc = broadcast_flatten(bc′)
_, _, _, _, Nh = size(bc)
if Nh > 0
auto_launch!(
Expand Down Expand Up @@ -99,11 +100,12 @@ end

function Base.copyto!(
dest::VIJFH{S, Nv, Nij},
bc::Union{
bc::Union{
VIJFH{S, Nv, Nij, A},
Base.Broadcast.Broadcasted{VIJFHStyle{Nv, Nij, A}},
},
) where {S, Nv, Nij, A <: CUDA.CuArray}
bc = broadcast_flatten(bc′)
_, _, _, _, Nh = size(bc)
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Expand Down Expand Up @@ -140,8 +142,9 @@ end

function Base.copyto!(
dest::VF{S, Nv},
bc::Union{VF{S, Nv, A}, Base.Broadcast.Broadcasted{VFStyle{Nv, A}}},
bc::Union{VF{S, Nv, A}, Base.Broadcast.Broadcasted{VFStyle{Nv, A}}},
) where {S, Nv, A <: CUDA.CuArray}
bc = broadcast_flatten(bc′)
_, _, _, _, Nh = size(dest)
if Nv > 0 && Nh > 0
auto_launch!(
Expand Down Expand Up @@ -170,8 +173,9 @@ end

function Base.copyto!(
dest::DataF{S},
bc::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}},
bc::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}},
) where {S, A <: CUDA.CuArray}
bc = broadcast_flatten(bc′)
auto_launch!(
knl_copyto!,
(dest, bc),
Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ AbstractStencilStyle(::ClimaComms.CUDADevice) = CUDAColumnStencilStyle

function Base.copyto!(
out::Field,
bc::Union{
bc::Union{
StencilBroadcasted{CUDAColumnStencilStyle},
Broadcasted{CUDAColumnStencilStyle},
},
)
bc = broadcast_flatten(bc′)
space = axes(out)
if space isa Spaces.ExtrudedFiniteDifferenceSpace
QS = Spaces.quadrature_style(space)
Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ end

function Base.copyto!(
out::Field,
sbc::Union{
sbc::Union{
SpectralBroadcasted{CUDASpectralStyle},
Broadcasted{CUDASpectralStyle},
},
)
sbc = broadcast_flatten(sbc′)
space = axes(out)
QS = Spaces.quadrature_style(space)
Nq = Quadratures.degrees_of_freedom(QS)
Expand Down
1 change: 1 addition & 0 deletions src/ClimaCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using PkgVersion
const VERSION = PkgVersion.@Version
import ClimaComms

include("upstream.jl")
include("interface.jl")
include("devices.jl")
include("Utilities/Utilities.jl")
Expand Down
1 change: 1 addition & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import ClimaComms
import MultiBroadcastFusion as MBF
import Adapt

import ..broadcast_flatten
import ..slab, ..slab_args, ..column, ..column_args, ..level
export slab, column, level, IJFH, IJF, IFH, IF, VF, VIJFH, VIFH, DataF

Expand Down
29 changes: 19 additions & 10 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,12 @@ end
# Performance optimization for the common identity scalar case: dest .= val
function Base.copyto!(
dest::AbstractData,
bc::Base.Broadcast.Broadcasted{Style},
bc::Base.Broadcast.Broadcasted{Style},
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
}
bc = broadcast_flatten(bc′)
bc = Base.Broadcast.instantiate(
Base.Broadcast.Broadcasted{Style}(bc.f, bc.args, ()),
)
Expand All @@ -481,16 +482,18 @@ end

function Base.copyto!(
dest::DataF{S},
bc::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}},
bc::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}},
) where {S, A}
bc = broadcast_flatten(bc′)
@inbounds dest[] = convert(S, bc[])
return dest
end

function Base.copyto!(
dest::IJFH{S, Nij},
bc::Union{IJFH{S, Nij}, Base.Broadcast.Broadcasted{<:IJFHStyle{Nij}}},
bc::Union{IJFH{S, Nij}, Base.Broadcast.Broadcasted{<:IJFHStyle{Nij}}},
) where {S, Nij}
bc = broadcast_flatten(bc′)
_, _, _, _, Nh = size(bc)
@inbounds for h in 1:Nh
slab_dest = slab(dest, h)
Expand All @@ -502,8 +505,9 @@ end

function Base.copyto!(
dest::IFH{S, Ni},
bc::Union{IFH{S, Ni}, Base.Broadcast.Broadcasted{<:IFHStyle{Ni}}},
bc::Union{IFH{S, Ni}, Base.Broadcast.Broadcasted{<:IFHStyle{Ni}}},
) where {S, Ni}
bc = broadcast_flatten(bc′)
_, _, _, _, Nh = size(bc)
@inbounds for h in 1:Nh
slab_dest = slab(dest, h)
Expand All @@ -516,8 +520,9 @@ end
# inline inner slab(::DataSlab2D) copy
function Base.copyto!(
dest::IJF{S, Nij},
bc::Union{IJF{S, Nij, A}, Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}},
bc::Union{IJF{S, Nij, A}, Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}},
) where {S, Nij, A}
bc = broadcast_flatten(bc′)
@inbounds for j in 1:Nij, i in 1:Nij
idx = CartesianIndex(i, j, 1, 1, 1)
dest[idx] = convert(S, bc[idx])
Expand All @@ -528,8 +533,9 @@ end
# inline inner slab(::DataSlab1D) copy
function Base.copyto!(
dest::IF{S, Ni},
bc::Base.Broadcast.Broadcasted{IFStyle{Ni, A}},
bc::Base.Broadcast.Broadcasted{IFStyle{Ni, A}},
) where {S, Ni, A}
bc = broadcast_flatten(bc′)
@inbounds for i in 1:Ni
idx = CartesianIndex(i, 1, 1, 1, 1)
dest[idx] = convert(S, bc[idx])
Expand All @@ -540,8 +546,9 @@ end
# inline inner column(::DataColumn) copy
function Base.copyto!(
dest::VF{S, Nv},
bc::Union{VF{S, Nv, A}, Base.Broadcast.Broadcasted{VFStyle{Nv, A}}},
bc::Union{VF{S, Nv, A}, Base.Broadcast.Broadcasted{VFStyle{Nv, A}}},
) where {S, Nv, A}
bc = broadcast_flatten(bc′)
@inbounds for v in 1:Nv
idx = CartesianIndex(1, 1, 1, v, 1)
dest[idx] = convert(S, bc[idx])
Expand Down Expand Up @@ -594,8 +601,9 @@ end

function Base.copyto!(
dest::VIFH{S, Nv, Ni},
bc::Base.Broadcast.Broadcasted{VIFHStyle{Nv, Ni, A}},
bc::Base.Broadcast.Broadcasted{VIFHStyle{Nv, Ni, A}},
) where {S, Nv, Ni, A}
bc = broadcast_flatten(bc′)
return _serial_copyto!(dest, bc)
end

Expand Down Expand Up @@ -644,8 +652,9 @@ end

function Base.copyto!(
dest::VIJFH{S, Nv, Nij},
bc::Base.Broadcast.Broadcasted{VIJFHStyle{Nv, Nij, A}},
bc::Base.Broadcast.Broadcasted{VIJFHStyle{Nv, Nij, A}},
) where {S, Nv, Nij, A}
bc = broadcast_flatten(bc′)
return _serial_copyto!(dest, bc)
end

Expand Down Expand Up @@ -674,7 +683,7 @@ function Base.copyto!(
else
bc
end
Pair(pair.first, bc′)
Pair(pair.first, broadcast_flatten(bc′))
end,
)
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
Expand Down
1 change: 1 addition & 0 deletions src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Fields

import ClimaComms
import MultiBroadcastFusion as MBF
import ..broadcast_flatten
import ..slab, ..slab_args, ..column, ..column_args, ..level
import ..DataLayouts:
DataLayouts,
Expand Down
6 changes: 4 additions & 2 deletions src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@ end

@inline function Base.copyto!(
dest::Field,
bc::Base.Broadcast.Broadcasted{<:AbstractFieldStyle},
bc::Base.Broadcast.Broadcasted{<:AbstractFieldStyle},
)
bc = broadcast_flatten(bc′)
copyto!(field_values(dest), Base.Broadcast.instantiate(todata(bc)))
return dest
end
Expand All @@ -156,7 +157,8 @@ function Base.copyto!(
) where {N, T <: NTuple{N, Pair{<:Field, <:Any}}}
fmb_data = FusedMultiBroadcast(
map(fmbc.pairs) do pair
bc = Base.Broadcast.instantiate(todata(pair.second))
bc′ = Base.Broadcast.instantiate(todata(pair.second))
bc = broadcast_flatten(bc′)
Pair(field_values(pair.first), bc)
end,
)
Expand Down
6 changes: 4 additions & 2 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,9 @@ end

@inline function Base.copyto!(
dest::FieldVector,
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
)
bc = broadcast_flatten(bc′)
map(propertynames(dest)) do symb
Base.@_inline_meta
p = parent(getfield(_values(dest), symb))
Expand All @@ -304,8 +305,9 @@ end

@inline function Base.copyto!(
dest::FieldVector,
bc::Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}},
bc::Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}},
)
bc = broadcast_flatten(bc′)
map(propertynames(dest)) do symb
Base.@_inline_meta
p = parent(getfield(_values(dest), symb))
Expand Down
1 change: 1 addition & 0 deletions src/Operators/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using StaticArrays

import Base.Broadcast: Broadcasted

import ..broadcast_flatten
import ..slab, ..slab_args, ..column, ..column_args
import ClimaComms
import ..DataLayouts: DataLayouts, Data2D, DataSlab2D
Expand Down
3 changes: 2 additions & 1 deletion src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3354,11 +3354,12 @@ end

function Base.copyto!(
field_out::Field,
bc::Union{
bc::Union{
StencilBroadcasted{ColumnStencilStyle},
Broadcasted{ColumnStencilStyle},
},
)
bc = broadcast_flatten(bc′)
space = axes(bc)
local_geometry = Spaces.local_geometry_data(space)
(Ni, Nj, _, _, Nh) = size(local_geometry)
Expand Down
3 changes: 2 additions & 1 deletion src/Operators/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,12 @@ end
# Functions for SlabBlockSpectralStyle
function Base.copyto!(
out::Field,
sbc::Union{
sbc::Union{
SpectralBroadcasted{SlabBlockSpectralStyle},
Broadcasted{SlabBlockSpectralStyle},
},
)
sbc = broadcast_flatten(sbc′)
Fields.byslab(axes(out)) do slabidx
Base.@_inline_meta
@inbounds copyto_slab!(out, sbc, slabidx)
Expand Down
69 changes: 69 additions & 0 deletions src/upstream.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# https://github.com/JuliaArrays/StaticArrays.jl/pull/1186

"""
    broadcast_flatten(bc)

Return a "flat" equivalent of `bc`.

If `bc` is a `Base.Broadcast.Broadcasted` whose arguments contain further
`Broadcasted` nodes, the result is a single `Broadcasted` with the same style and
axes whose argument tuple holds the leaf arguments in depth-first order, and whose
function rebuilds the nested call from that flat tuple. A `Broadcasted` that is
already flat is returned as-is.

Any other object is returned unchanged, so callers can pass a value that may or may
not be a `Broadcasted` without branching.
"""
function broadcast_flatten end

# Non-`Broadcasted` inputs are already "flat". This method is defined on every Julia
# version so both branches below expose the same interface.
broadcast_flatten(bc) = bc

# NOTE: `Base.VERSION` is spelled out on purpose. src/ClimaCore.jl defines its own
# `const VERSION` (the package version) before including this file, so a bare
# `VERSION` here would compare the package version instead of the Julia version.
if Base.VERSION >= v"1.11.0-DEV.103"
    # Base ships the fixed implementation from this version on; forward only
    # `Broadcasted` inputs to it and keep the generic pass-through above.
    broadcast_flatten(bc::Base.Broadcast.Broadcasted) = Base.Broadcast.flatten(bc)
else
    using Base: tail
    using Base.Broadcast: isflat, Broadcasted

    # Wrap type constructors in a closure; plain functions are passed through.
    maybeconstructor(f) = f
    maybeconstructor(::Type{F}) where {F} =
        (args...; kwargs...) -> F(args...; kwargs...)

    function broadcast_flatten(bc::Broadcasted{Style}) where {Style}
        isflat(bc) && return bc
        # Leaf arguments of the whole tree, concatenated in depth-first order.
        args = cat_nested(bc)
        len = Val{length(args)}()
        # One function per top-level argument of `bc`; each takes the flat
        # argument tuple and produces that argument's value.
        makeargs = make_makeargs(bc.args, len, ntuple(_ -> true, len))
        f = maybeconstructor(bc.f)
        @inline newf(args...) = f(prepare_args(makeargs, args)...)
        return Broadcasted{Style}(newf, args, bc.axes)
    end

    # Collect the leaves of a broadcast tree into one flat tuple.
    cat_nested(bc::Broadcasted) = cat_nested_args(bc.args)
    cat_nested_args(::Tuple{}) = ()
    cat_nested_args(t::Tuple) =
        (cat_nested(t[1])..., cat_nested_args(tail(t))...)
    cat_nested(@nospecialize(a)) = (a,)

    # Build the tuple of argument-making functions for `args` and check that
    # every flat argument slot has been consumed.
    function make_makeargs(args::Tuple, len, flags)
        makeargs, r = _make_makeargs(args, len, flags)
        r isa Tuple{} || error("Internal error. Please file a bug")
        return makeargs
    end

    # `makeargs` is built by traversing the broadcast nodes recursively.
    # `len` is a `Val` holding the length of the whole flattened argument list.
    # `flags` is a tuple with one entry per flat argument not yet consumed; only
    # its length is used, to derive the index of the next flat argument.
    @inline function _make_makeargs(args::Tuple, len::Val, flags::Tuple)
        head, flags′ = _make_makeargs1(args[1], len, flags)
        rest, flags″ = _make_makeargs(tail(args), len, flags′)
        return (head, rest...), flags″
    end
    _make_makeargs(::Tuple{}, ::Val, x::Tuple) = (), x

    # Flat (leaf) node: consume one flat argument slot and return a function
    # that picks that slot out of the flat argument tuple.
    @inline function _make_makeargs1(
        @nospecialize(a),
        ::Val{N},
        flags::Tuple,
    ) where {N}
        pickargs(::Val{N}) where {N} = (@nospecialize(x::Tuple)) -> x[N]
        return pickargs(Val{N - length(flags) + 1}()), tail(flags)
    end

    # Nested node: build the child `makeargs` (consuming one slot per leaf of
    # `bc`) and return a function that applies `bc.f` to the child values.
    @inline function _make_makeargs1(bc::Broadcasted, len::Val, flags::Tuple)
        makeargs, flags′ = _make_makeargs(bc.args, len, flags)
        f = maybeconstructor(bc.f)
        @inline makeargs1(@nospecialize(args::Tuple)) =
            f(prepare_args(makeargs, args)...)
        return makeargs1, flags′
    end

    # Apply every function in `makeargs` to the flat argument tuple `x`.
    prepare_args(::Tuple{}, @nospecialize(::Tuple)) = ()
    @inline prepare_args(makeargs::Tuple, @nospecialize(x::Tuple)) =
        (makeargs[1](x), prepare_args(tail(makeargs), x)...)
end
Loading