Skip to content

Commit

Permalink
Extend Remapper to FiniteDifferenceSpaces, update interpolate interface
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 25, 2024
1 parent 7297f5d commit 29616e4
Show file tree
Hide file tree
Showing 3 changed files with 255 additions and 22 deletions.
230 changes: 211 additions & 19 deletions src/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,20 +116,24 @@ function target_hcoords_pid_bitmask(target_hcoords, topology, pid)
return pid_hcoord.(target_hcoords) .== pid
end


# TODO: define an inner construct and restrict types, as was done in
# https://github.com/CliMA/RRTMGP.jl/pull/352
# to avoid potential compilation issues.
struct Remapper{
CC <: ClimaComms.AbstractCommsContext,
SPACE <: Spaces.AbstractSpace,
T1 <: AbstractArray,
T1, # <: Union{AbstractArray, Nothing},
TARG_Z <: Union{Nothing, AA1} where {AA1 <: AbstractArray},
T3 <: AbstractArray,
T4 <: Tuple,
T5 <: AbstractArray,
T3, # <: Union{AbstractArray, Nothing},
T4, # <: Union{Tuple, Nothing},
T5, # <: Union{AbstractArray, Nothing},
VERT_W <: Union{Nothing, AA2} where {AA2 <: AbstractArray},
VERT_IND <: Union{Nothing, AA3} where {AA3 <: AbstractArray},
T8 <: AbstractArray,
T9 <: AbstractArray,
T8, # <: AbstractArray,
T9, # <: AbstractArray,
T10 <: AbstractArray,
T11 <: Union{Tuple{Colon}, Tuple{Colon, Colon}, Tuple{Colon, Colon, Colon}},
T11, # <: Union{Tuple{Colon}, Tuple{Colon, Colon}, Tuple{Colon, Colon, Colon}},
}

comms_ctx::CC
Expand Down Expand Up @@ -222,11 +226,39 @@ in `interpolate`. When more fields than `buffer_length` are passed, the remapper
the work in sizes of `buffer_length`.
"""
function Remapper(
function Remapper end

Remapper(
space::Spaces.AbstractSpace;
target_hcoords::Union{AbstractArray, Nothing} = nothing,
target_zcoords::Union{AbstractArray, Nothing} = nothing,
buffer_length::Int = 1,
) = _Remapper(space; target_zcoords, buffer_length, target_hcoords)

Remapper(
space::Spaces.FiniteDifferenceSpace;
target_zcoords::AbstractArray,
buffer_length::Int = 1,
) = _Remapper(space; target_zcoords, buffer_length)

Remapper(
space::Spaces.AbstractSpace,
target_hcoords::AbstractArray,
target_zcoords::Union{AbstractArray, Nothing};
buffer_length::Int = 1,
) = _Remapper(space; target_zcoords, buffer_length, target_hcoords)

Remapper(
space::Spaces.AbstractSpace,
target_hcoords::AbstractArray;
buffer_length::Int = 1,
) = _Remapper(space; target_zcoords = nothing, target_hcoords, buffer_length)

function _Remapper(
space::Spaces.AbstractSpace;
target_zcoords::Union{AbstractArray, Nothing},
target_hcoords::AbstractArray,
buffer_length::Int = 1,
)

comms_ctx = ClimaComms.context(space)
Expand Down Expand Up @@ -367,11 +399,48 @@ function Remapper(
)
end

Remapper(
space::Spaces.AbstractSpace,
target_hcoords::AbstractArray;
function _Remapper(
space::Spaces.FiniteDifferenceSpace;
target_zcoords::AbstractArray,
buffer_length::Int = 1,
) = Remapper(space, target_hcoords, nothing; buffer_length)
)

comms_ctx = ClimaComms.context(space)
pid = ClimaComms.mypid(comms_ctx)
FT = Spaces.undertype(space)
ArrayType = ClimaComms.array_type(space)

# We represent interpolation onto an horizontal slab as an empty list of zcoords
vert_interpolation_weights =
ArrayType(vertical_interpolation_weights(space, target_zcoords))
vert_bounding_indices =
ArrayType(vertical_bounding_indices(space, target_zcoords))

# We have to add one extra dimension with respect to the bitmask/local_horiz_indices
# because we are going to store the values for the columns
local_interpolated_values =
ArrayType(zeros(FT, (length(target_zcoords), buffer_length)))
interpolated_values =
ArrayType(zeros(FT, (length(target_zcoords), buffer_length)))
colons = (:,)

return Remapper(
comms_ctx,
space,
nothing, # local_target_hcoords,
target_zcoords,
nothing, # local_target_hcoords_bitmask,
nothing, # local_horiz_interpolation_weights,
nothing, # local_horiz_indices,
vert_interpolation_weights,
vert_bounding_indices,
local_interpolated_values,
nothing, # field_values,
interpolated_values,
buffer_length,
colons,
)
end

"""
_set_interpolated_values!(remapper, field)
Expand Down Expand Up @@ -439,6 +508,37 @@ function set_interpolated_values_cpu_kernel!(
end
end

function set_interpolated_values_cpu_kernel!(
out::AbstractArray,
fields::AbstractArray{<:Fields.Field},
::Nothing,
local_horiz_indices,
vert_interpolation_weights,
vert_bounding_indices,
scratch_field_values,
)
space = axes(first(fields))
FT = Spaces.undertype(space)
for (field_index, field) in enumerate(fields)
field_values = Fields.field_values(field)

# Reading values from field_values is expensive, so we try to limit the number of reads. We can do
# this because multiple target points might be all contained in the same element.
prev_vindex = -1
@inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights)
(v_lo, v_hi) = vert_bounding_indices[vindex]
# If we are no longer in the same element, read the field values again
if prev_vindex != vindex
out[vindex, field_index] = (
A * field_values[CartesianIndex(1, 1, 1, v_lo, 1)] +
B * field_values[CartesianIndex(1, 1, 1, v_hi, 1)]
)
prev_vindex = vindex
end
end
end
end

# CPU, 2D case
function set_interpolated_values_cpu_kernel!(
out::AbstractArray,
Expand Down Expand Up @@ -778,12 +878,16 @@ function interpolate(remapper::Remapper, fields)
view(fields, index_field_begin:index_field_end),
)
# Reshape the output so that it is a nice grid.
_apply_mpi_bitmask!(remapper, num_fields)
# Finally, we have to send all the _interpolated_values to root and sum them up to
# obtain the final answer. Only the root will contain something useful. This also
# moves the data off the GPU
ret = _collect_and_return_interpolated_values!(remapper, num_fields)
return ret
if !(remapper.space isa Spaces.FiniteDifferenceSpace)
_apply_mpi_bitmask!(remapper, num_fields)
# Finally, we have to send all the _interpolated_values to root and sum them up to
# obtain the final answer. Only the root will contain something useful. This also
# moves the data off the GPU
_collect_and_return_interpolated_values!(remapper, num_fields)
else
remapper._interpolated_values .= remapper._local_interpolated_values
end
remapper._interpolated_values
end

# Non-root processes
Expand All @@ -794,7 +898,10 @@ function interpolate(remapper::Remapper, fields)
end

"""
interpolate(field::ClimaCore.Fields, target_hcoords, target_zcoords)
interpolate(field::ClimaCore.Fields;
[target_hcoords::AbstractArray],
[target_zcoords::AbstractArray]
)
Interpolate the given fields on the Cartesian product of `target_hcoords` with
`target_zcoords` (if not empty).
Expand All @@ -819,13 +926,98 @@ hcoords = [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts]
zcoords = [Geometry.ZPoint(z) for z in zpts]
interpolate(field, hcoords, zcoords)
The `hresolution`, `vresolution` keyword constructors will
interpolate to uniform grids at the given horizontal (`h`)
and vertical (`v`) resolutions, respectively.
```
"""
function interpolate(field::Fields.Field, target_hcoords, target_zcoords)
remapper = Remapper(axes(field), target_hcoords, target_zcoords)
return interpolate(remapper, field)
end

interpolate(
field::Fields.Field;
vresolution = 50,
hresolution = 100,
target_hcoords = get_target_hcoords(axes(field); hresolution),
target_zcoords = get_target_zcoords(axes(field); vresolution),
) = interpolate(field, axes(field); hresolution, vresolution)

get_target_hcoords(space::AbstractSpace; hresolution) =
get_target_hcoords(Spaces.horizontal_space(space), hresolution)

function get_target_hcoords(space::SpectralElementSpace2D; hresolution)
topology = Spaces.topology(space)
mesh = topology.mesh
domain = Meshes.domain(mesh)
PT1 = typeof(domain.interval1.coord_min)
PT2 = typeof(domain.interval2.coord_min)
x1min = Geometry.component(domain.interval1.coord_min, 1)
x2min = Geometry.component(domain.interval2.coord_min, 1)
x1max = Geometry.component(domain.interval1.coord_max, 1)
x2max = Geometry.component(domain.interval2.coord_max, 1)
x1 = map(PT1, range(x1min, x1max; length = hresolution))
x2 = map(PT2, range(x2min, x2max; length = hresolution))
return Base.Iterators.product((x1, x2))
end

function get_target_hcoords(space::SpectralElementSpace1D; hresolution)
topology = Spaces.topology(space)
mesh = topology.mesh
domain = Meshes.domain(mesh)
PT1 = typeof(domain.interval1.coord_min)
x1min = Geometry.component(domain.interval1.coord_min, 1)
x1max = Geometry.component(domain.interval1.coord_max, 1)
x1 = map(PT1, range(x1min, x1max; length = hresolution))
return x1
end

get_target_zcoords(space; vresolution = 50) = map(
Geometry.ZPoint,
range(z_min(space), z_max(space); length = vresolution),
)

function interpolate(
field::Fields.Field,
space::AbstractSpectralElementSpace;
hresolution,
vresolution,
)
remapper = Remapper(
space;
target_hcoords = get_target_hcoords(space),
target_zcoords = nothing,
)
return interpolate(remapper, field)
end

function interpolate(
field::Fields.Field,
space::AbstractSpace;
hresolution,
vresolution,
)
target_zcoords = get_target_zcoords(space; vresolution)
target_hcoords = get_target_hcoords(space; hresolution)
remapper = Remapper(space; target_hcoords, target_zcoords)
return interpolate(remapper, field)
end

function interpolate(
field::Fields.Field,
::FiniteDifferenceSpace;
hresolution,
vresolution,
)
remapper = Remapper(
axes(field);
target_zcoords = get_target_zcoords(space; vresolution),
)
return interpolate(remapper, field)
end

# dest has to be allowed to be nothing because interpolation happens only on the root
# process
function interpolate!(
Expand Down
12 changes: 9 additions & 3 deletions src/Remapping/remapping_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,13 @@ element in a column, no interpolation is performed and the value at the cell cen
returned. Effectively, this means that the interpolation is first-order accurate across the
column, but zeroth-order accurate close to the boundaries.
"""
function vertical_bounding_indices(space, zcoords) end
function vertical_bounding_indices end

function vertical_bounding_indices(
space::Spaces.FaceExtrudedFiniteDifferenceSpace,
space::Union{
Spaces.FaceExtrudedFiniteDifferenceSpace,
Spaces.FaceFiniteDifferenceSpace,
},
zcoords,
)
vert_topology = Spaces.vertical_topology(space)
Expand All @@ -64,7 +67,10 @@ function vertical_bounding_indices(
end

function vertical_bounding_indices(
space::Spaces.CenterExtrudedFiniteDifferenceSpace,
space::Union{
Spaces.CenterExtrudedFiniteDifferenceSpace,
Spaces.CenterFiniteDifferenceSpace,
},
zcoords,
)
vert_topology = Spaces.vertical_topology(space)
Expand Down
35 changes: 35 additions & 0 deletions test/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#=
julia --project
using Revise; include("test/Remapping/distributed_remapping.jl")
=#
using Logging
using Test
using IntervalSets
Expand Down Expand Up @@ -657,3 +661,34 @@ end
@test interp_sin_long dest[:, :, 3]
end
end

@testset "Purely vertical space" begin
vertdomain = Domains.IntervalDomain(
Geometry.ZPoint(0.0),
Geometry.ZPoint(1000.0);
boundary_names = (:bottom, :top),
)

vertmesh = Meshes.IntervalMesh(vertdomain, nelems = 30)
verttopo = Topologies.IntervalTopology(
ClimaComms.SingletonCommsContext(ClimaComms.device()),
vertmesh,
)
cspace = Spaces.CenterFiniteDifferenceSpace(verttopo)
space = Spaces.FaceFiniteDifferenceSpace(cspace)

zpts = range(0.0, 1000.0, 21)
zcoords = [Geometry.ZPoint(z) for z in zpts]
remapper =
Remapping.Remapper(space; target_zcoords = zcoords, buffer_length = 2)

coords = Fields.coordinate_field(space)

interp_z = Remapping.interpolate(remapper, coords.z)
expected_z = zpts
if ClimaComms.iamroot(context)
@test interp_z == collect(expected_z)
end

# TODO: anything more to exercise?
end

0 comments on commit 29616e4

Please sign in to comment.