From ad84ee760a7eff7c4aac80ea7e7d99d5d54ff4c2 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 25 Oct 2024 14:55:25 -0400 Subject: [PATCH] Extend Remapper to FiniteDifferenceSpaces, update interpolate interface --- src/Remapping/distributed_remapping.jl | 237 ++++++++++++++++++++++-- src/Remapping/remapping_utils.jl | 12 +- test/Remapping/distributed_remapping.jl | 35 ++++ 3 files changed, 262 insertions(+), 22 deletions(-) diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index 5f22c9636d..efe96b36c0 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -116,20 +116,24 @@ function target_hcoords_pid_bitmask(target_hcoords, topology, pid) return pid_hcoord.(target_hcoords) .== pid end + +# TODO: define an inner construct and restrict types, as was done in +# https://github.com/CliMA/RRTMGP.jl/pull/352 +# to avoid potential compilation issues. struct Remapper{ CC <: ClimaComms.AbstractCommsContext, SPACE <: Spaces.AbstractSpace, - T1 <: AbstractArray, + T1, # <: Union{AbstractArray, Nothing}, TARG_Z <: Union{Nothing, AA1} where {AA1 <: AbstractArray}, - T3 <: AbstractArray, - T4 <: Tuple, - T5 <: AbstractArray, + T3, # <: Union{AbstractArray, Nothing}, + T4, # <: Union{Tuple, Nothing}, + T5, # <: Union{AbstractArray, Nothing}, VERT_W <: Union{Nothing, AA2} where {AA2 <: AbstractArray}, VERT_IND <: Union{Nothing, AA3} where {AA3 <: AbstractArray}, - T8 <: AbstractArray, - T9 <: AbstractArray, + T8, # <: AbstractArray, + T9, # <: AbstractArray, T10 <: AbstractArray, - T11 <: Union{Tuple{Colon}, Tuple{Colon, Colon}, Tuple{Colon, Colon, Colon}}, + T11, # <: Union{Tuple{Colon}, Tuple{Colon, Colon}, Tuple{Colon, Colon, Colon}}, } comms_ctx::CC @@ -222,11 +226,46 @@ in `interpolate`. When more fields than `buffer_length` are passed, the remapper the work in sizes of `buffer_length`. """ -function Remapper( +function Remapper end + +Remapper( + space::Spaces.AbstractSpace; + target_hcoords::Union{AbstractArray, Nothing} = nothing, + target_zcoords::Union{AbstractArray, Nothing} = nothing, + buffer_length::Int = 1, +) = _Remapper(space; target_zcoords, buffer_length, target_hcoords) + +Remapper( + space::Spaces.FiniteDifferenceSpace; + target_zcoords::AbstractArray, + buffer_length::Int = 1, +) = _Remapper(space; target_zcoords, buffer_length) + +Remapper( space::Spaces.AbstractSpace, target_hcoords::AbstractArray, target_zcoords::Union{AbstractArray, Nothing}; buffer_length::Int = 1, +) = _Remapper(space; target_zcoords, buffer_length, target_hcoords) + +Remapper( + space::Spaces.AbstractSpace, + target_hcoords::AbstractArray; + buffer_length::Int = 1, +) = _Remapper(space; target_zcoords = nothing, target_hcoords, buffer_length) + +# function _Remapper( +# space::Spaces.FiniteDifferenceSpace; +# target_zcoords::AbstractArray, +# target_hcoords::AbstractArray = [], +# buffer_length::Int = 1, +# ) + +function _Remapper( + space::Spaces.AbstractSpace; + target_zcoords::Union{AbstractArray, Nothing}, + target_hcoords::AbstractArray, + buffer_length::Int = 1, ) comms_ctx = ClimaComms.context(space) @@ -367,11 +406,48 @@ function Remapper( ) end -Remapper( - space::Spaces.AbstractSpace, - target_hcoords::AbstractArray; +function _Remapper( + space::Spaces.FiniteDifferenceSpace; + target_zcoords::AbstractArray, buffer_length::Int = 1, -) = Remapper(space, target_hcoords, nothing; buffer_length) +) + + comms_ctx = ClimaComms.context(space) + pid = ClimaComms.mypid(comms_ctx) + FT = Spaces.undertype(space) + ArrayType = ClimaComms.array_type(space) + + # We represent interpolation onto an horizontal slab as an empty list of zcoords + vert_interpolation_weights = + ArrayType(vertical_interpolation_weights(space, target_zcoords)) + vert_bounding_indices = + ArrayType(vertical_bounding_indices(space, target_zcoords)) + + # We have to add one extra dimension with respect to the bitmask/local_horiz_indices + # because we are going to store the values for the columns + local_interpolated_values = + ArrayType(zeros(FT, (length(target_zcoords), buffer_length))) + interpolated_values = + ArrayType(zeros(FT, (length(target_zcoords), buffer_length))) + colons = (:,) + + return Remapper( + comms_ctx, + space, + nothing, # local_target_hcoords, + target_zcoords, + nothing, # local_target_hcoords_bitmask, + nothing, # local_horiz_interpolation_weights, + nothing, # local_horiz_indices, + vert_interpolation_weights, + vert_bounding_indices, + local_interpolated_values, + nothing, # field_values, + interpolated_values, + buffer_length, + colons, + ) +end """ _set_interpolated_values!(remapper, field) @@ -439,6 +515,37 @@ function set_interpolated_values_cpu_kernel!( end end +function set_interpolated_values_cpu_kernel!( + out::AbstractArray, + fields::AbstractArray{<:Fields.Field}, + ::Nothing, + local_horiz_indices, + vert_interpolation_weights, + vert_bounding_indices, + scratch_field_values, +) + space = axes(first(fields)) + FT = Spaces.undertype(space) + for (field_index, field) in enumerate(fields) + field_values = Fields.field_values(field) + + # Reading values from field_values is expensive, so we try to limit the number of reads. We can do + # this because multiple target points might be all contained in the same element. + prev_vindex = -1 + @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) + (v_lo, v_hi) = vert_bounding_indices[vindex] + # If we are no longer in the same element, read the field values again + if prev_vindex != vindex + out[vindex, field_index] = ( + A * field_values[CartesianIndex(1, 1, 1, v_lo, 1)] + + B * field_values[CartesianIndex(1, 1, 1, v_hi, 1)] + ) + prev_vindex = vindex + end + end + end +end + # CPU, 2D case function set_interpolated_values_cpu_kernel!( out::AbstractArray, @@ -778,12 +885,16 @@ function interpolate(remapper::Remapper, fields) view(fields, index_field_begin:index_field_end), ) # Reshape the output so that it is a nice grid. - _apply_mpi_bitmask!(remapper, num_fields) - # Finally, we have to send all the _interpolated_values to root and sum them up to - # obtain the final answer. Only the root will contain something useful. This also - # moves the data off the GPU - ret = _collect_and_return_interpolated_values!(remapper, num_fields) - return ret + if !(remapper.space isa Spaces.FiniteDifferenceSpace) + _apply_mpi_bitmask!(remapper, num_fields) + # Finally, we have to send all the _interpolated_values to root and sum them up to + # obtain the final answer. Only the root will contain something useful. This also + # moves the data off the GPU + _collect_and_return_interpolated_values!(remapper, num_fields) + else + remapper._interpolated_values .= remapper._local_interpolated_values + end + remapper._interpolated_values end # Non-root processes @@ -794,7 +905,10 @@ function interpolate(remapper::Remapper, fields) end """ - interpolate(field::ClimaCore.Fields, target_hcoords, target_zcoords) + interpolate(field::ClimaCore.Fields; + [target_hcoords::AbstractArray], + [target_zcoords::AbstractArray] + ) Interpolate the given fields on the Cartesian product of `target_hcoords` with `target_zcoords` (if not empty). @@ -819,6 +933,10 @@ hcoords = [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts] zcoords = [Geometry.ZPoint(z) for z in zpts] interpolate(field, hcoords, zcoords) + +The `hresolution`, `vresolution` keyword constructors will +interpolate to uniform grids at the given horizontal (`h`) +and vertical (`v`) resolutions, respectively. ``` """ function interpolate(field::Fields.Field, target_hcoords, target_zcoords) @@ -826,6 +944,87 @@ function interpolate(field::Fields.Field, target_hcoords, target_zcoords) return interpolate(remapper, field) end +interpolate( + field::Fields.Field; + vresolution = 50, + hresolution = 100, + target_hcoords = get_target_hcoords(axes(field), hresolution), + target_zcoords = get_target_zcoords(axes(field), vresolution), +) = interpolate(field, axes(field); hresolution, vresolution) + +get_target_hcoords(space::AbstractSpace; hresolution) = + get_target_hcoords(Spaces.horizontal_space(space), hresolution) + +function get_target_hcoords(space::SpectralElementSpace2D; hresolution) + topology = Spaces.topology(space) + mesh = topology.mesh + domain = Meshes.domain(mesh) + PT1 = typeof(domain.interval1.coord_min) + PT2 = typeof(domain.interval2.coord_min) + x1min = Geometry.component(domain.interval1.coord_min, 1) + x2min = Geometry.component(domain.interval2.coord_min, 1) + x1max = Geometry.component(domain.interval1.coord_max, 1) + x2max = Geometry.component(domain.interval2.coord_max, 1) + x1 = map(PT1, range(x1min, x1max; length = hresolution)) + x2 = map(PT2, range(x2min, x2max; length = hresolution)) + return Base.Iterators.product((x1, x2)) +end + +function get_target_hcoords(space::SpectralElementSpace1D; hresolution) + topology = Spaces.topology(space) + mesh = topology.mesh + domain = Meshes.domain(mesh) + PT1 = typeof(domain.interval1.coord_min) + x1min = Geometry.component(domain.interval1.coord_min, 1) + x1max = Geometry.component(domain.interval1.coord_max, 1) + x1 = map(PT1, range(x1min, x1max; length = hresolution)) + return x1 +end + +get_target_zcoords(space; vresolution = 50) = map( + Geometry.ZPoint, + range(z_min(space), z_max(space); length = vresolution), +) + +function interpolate( + field::Fields.Field, + space::AbstractSpectralElementSpace; + hresolution, + vresolution, +) + remapper = Remapper( + space; + target_hcoords = get_target_hcoords(space), + target_zcoords = nothing, + ) + return interpolate(remapper, field) +end + +function interpolate( + field::Fields.Field, + space::AbstractSpace; + hresolution, + vresolution, +) + target_zcoords = get_target_zcoords(space; vresolution) + target_hcoords = get_target_hcoords(space; hresolution) + remapper = Remapper(space; target_hcoords, target_zcoords) + return interpolate(remapper, field) +end + +function interpolate( + field::Fields.Field, + ::FiniteDifferenceSpace; + hresolution, + vresolution, +) + remapper = Remapper( + axes(field); + target_zcoords = get_target_zcoords(space; vresolution), + ) + return interpolate(remapper, field) +end + # dest has to be allowed to be nothing because interpolation happens only on the root # process function interpolate!( diff --git a/src/Remapping/remapping_utils.jl b/src/Remapping/remapping_utils.jl index 84122836d5..ec88c3019f 100644 --- a/src/Remapping/remapping_utils.jl +++ b/src/Remapping/remapping_utils.jl @@ -51,10 +51,13 @@ element in a column, no interpolation is performed and the value at the cell cen returned. Effectively, this means that the interpolation is first-order accurate across the column, but zeroth-order accurate close to the boundaries. """ -function vertical_bounding_indices(space, zcoords) end +function vertical_bounding_indices end function vertical_bounding_indices( - space::Spaces.FaceExtrudedFiniteDifferenceSpace, + space::Union{ + Spaces.FaceExtrudedFiniteDifferenceSpace, + Spaces.FaceFiniteDifferenceSpace, + }, zcoords, ) vert_topology = Spaces.vertical_topology(space) @@ -64,7 +67,10 @@ function vertical_bounding_indices( end function vertical_bounding_indices( - space::Spaces.CenterExtrudedFiniteDifferenceSpace, + space::Union{ + Spaces.CenterExtrudedFiniteDifferenceSpace, + Spaces.CenterFiniteDifferenceSpace, + }, zcoords, ) vert_topology = Spaces.vertical_topology(space) diff --git a/test/Remapping/distributed_remapping.jl b/test/Remapping/distributed_remapping.jl index 0748934dc4..cf8a58809b 100644 --- a/test/Remapping/distributed_remapping.jl +++ b/test/Remapping/distributed_remapping.jl @@ -1,3 +1,7 @@ +#= +julia --project +using Revise; include("test/Remapping/distributed_remapping.jl") +=# using Logging using Test using IntervalSets @@ -657,3 +661,34 @@ end @test interp_sin_long ≈ dest[:, :, 3] end end + +@testset "Purely vertical space" begin + vertdomain = Domains.IntervalDomain( + Geometry.ZPoint(0.0), + Geometry.ZPoint(1000.0); + boundary_names = (:bottom, :top), + ) + + vertmesh = Meshes.IntervalMesh(vertdomain, nelems = 30) + verttopo = Topologies.IntervalTopology( + ClimaComms.SingletonCommsContext(ClimaComms.device()), + vertmesh, + ) + cspace = Spaces.CenterFiniteDifferenceSpace(verttopo) + space = Spaces.FaceFiniteDifferenceSpace(cspace) + + zpts = range(0.0, 1000.0, 21) + zcoords = [Geometry.ZPoint(z) for z in zpts] + remapper = + Remapping.Remapper(space; target_zcoords = zcoords, buffer_length = 2) + + coords = Fields.coordinate_field(space) + + interp_z = Remapping.interpolate(remapper, coords.z) + expected_z = zpts + if ClimaComms.iamroot(context) + @test interp_z == collect(expected_z) + end + + # TODO: anything more to exercise? +end