Skip to content

Commit

Permalink
work in progress
Browse files Browse the repository at this point in the history
[skip ci]
  • Loading branch information
Sbozzolo committed Oct 29, 2024
1 parent 1504d19 commit 368ac55
Show file tree
Hide file tree
Showing 5 changed files with 914 additions and 731 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ main

- Fixed world-age issue on Julia 1.11 issue [Julia#54780](https://github.com/JuliaLang/julia/issues/54780), PR [#2034](https://github.com/CliMA/ClimaCore.jl/pull/2034).

### ![][badge-✨feature/enhancement] Various improvements to `Remapper`



v0.14.19
-------

Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ withenv("GKSwstype" => "nul") do
"Installation and How-to Guides" => "installation_instructions.md",
"Geometry" => "geometry.md",
"Operators" => "operators.md",
"Remapping" => "remapping.md",
"MatrixFields" => "matrix_fields.md",
"API" => "api.md",
"Developer docs" => ["Performance tips" => "performance_tips.md"],
Expand Down
119 changes: 119 additions & 0 deletions docs/src/remapping.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Remapping to regular grids

`ClimaCore` horizontal domains are spectral elements.


### Non-conservative remapping

# Remapper

## Constructors

```julia
Remapper(space, target_hcoords, target_zcoords, buffer_length = 1)
Remapper(space; target_hcoords, target_zcoords, buffer_length = 1)
Remapper(space, target_hcoords; buffer_length = 1)
Remapper(space, target_zcoords; buffer_length = 1)
```

Return a `Remapper` responsible for interpolating any `Field` defined on the given `space` to the Cartesian product of `target_hcoords` with `target_zcoords`.

`target_zcoords` can be `nothing` for interpolation on horizontal spaces. Similarly, `target_hcoords` can be `nothing` for interpolation on vertical spaces.

The `Remapper` is designed to not be tied to any particular `Field`. You can use the same `Remapper` for any `Field` as long as they are all defined on the same `topology`.

`Remapper` is the main argument to the `interpolate` function.

### Keyword arguments

`buffer_length` is size of the internal buffer in the Remapper to store intermediate values for interpolation. Effectively, this controls how many fields can be remapped simultaneously in `interpolate`. When more fields than `buffer_length` are passed, the remapper will batch the work in sizes of `buffer_length`.

## Interpolation

```julia
interpolate(remapper::Remapper, fields)
interpolate!(dest, remapper::Remapper, fields)
```

Interpolate the given `field`(s) as prescribed by `remapper`.

The optimal number of fields passed is the `buffer_length` of the `remapper`. If more fields are passed, the `remapper` will batch work with size up to its `buffer_length`.

This call mutates the internal (private) state of the `remapper`.

Horizontally, interpolation is performed with the barycentric formula in [Berrut2004], equation (3.2). Vertical interpolation is linear except in the boundary elements where it is 0th order.

`interpolate!` writes the output to the given `dest`ination. `dest` is expected to be defined on the root process and to be `nothing` for the other processes.

**Note:** `interpolate` allocates new arrays and has some internal type-instability, `interpolate!` is non-allocating and type-stable.

When using `interpolate!`, the `dest`ination has to be the same array type as the device in use (e.g., `CuArray` for CUDA runs).

### Example

Given `field1`,`field2`, two `Field` defined on a cubed sphere.

```julia
longpts = range(-180.0, 180.0, 21)
latpts = range(-80.0, 80.0, 21)
zpts = range(0.0, 1000.0, 21)

hcoords = [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts]
zcoords = [Geometry.ZPoint(z) for z in zpts]

space = axes(field1)

remapper = Remapper(space, hcoords, zcoords)

int1 = interpolate(remapper, field1)
int2 = interpolate(remapper, field2)

# Or
int12 = interpolate(remapper, [field1, field2])
# With int1 = int12[1, :, :, :]
```

## Convenience Interpolation

```julia
interpolate(field::ClimaCore.Fields;
hresolution = 180,
resolution = 50,
target_hcoords = default_target_hcoords(space; hresolution),
target_zcoords = default_target_vcoords(space; vresolution)
)
```

Interpolate the given fields on the Cartesian product of `target_hcoords` with `target_zcoords` (if not empty).

Coordinates have to be `ClimaCore.Geometry.Points`.

**Note:** do not use this method when performance is important. Instead, define a `Remapper` and call `interpolate(remapper, fields)`. Different `Field`s defined on the same `Space` can share a `Remapper`, so that interpolation can be optimized.

### Example

Given `field`, a `Field` defined on a cubed sphere.

By default, a target uniform grid is chosen (with resolution `hresolution` and `vresolution`), so remapping is simply

```julia
julia> interpolate(field, hcoords, zcoords)
```

Coordinates can be specified:

```julia
julia> longpts = range(-180.0, 180.0, 21)
julia> latpts = range(-80.0, 80.0, 21)
julia> zpts = range(0.0, 1000.0, 21)

julia> hcoords = [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts]
julia> zcoords = [Geometry.ZPoint(z) for z in zpts]

julia> interpolate(field, hcoords, zcoords)
```
```
### Conservative remapping with `TempestRemap`
This section hasn't been written yet. You can help by writing it.
168 changes: 87 additions & 81 deletions src/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,9 @@ function interpolate(remapper::Remapper, fields)
if !isa_vertical_space
# For spaces with an horizontal component, reshape the output so that it is a nice grid.
_apply_mpi_bitmask!(remapper, num_fields)
else
# For purely vertical spaces, just move to _interpolated_values
remapper._interpolated_values .= remapper._local_interpolated_values
end

# Finally, we have to send all the _interpolated_values to root and sum them up to
Expand All @@ -899,12 +902,83 @@ function interpolate(remapper::Remapper, fields)
interpolated_values
end

# dest has to be allowed to be nothing because interpolation happens only on the root
# process
function interpolate!(
dest::Union{Nothing, <:AbstractArray},
remapper::Remapper,
fields,
)
only_one_field = fields isa Fields.Field
if only_one_field
fields = [fields]
end
isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace

if !isnothing(dest)
# !isnothing(dest) means that this is the root process, in this case, the size have
# to match (ignoring the buffer_length)
dest_size = only_one_field ? size(dest) : size(dest)[1:(end - 1)]

dest_size == size(remapper._interpolated_values)[1:(end - 1)] || error(
"Destination array is not compatible with remapper (size mismatch)",
)

expected_array_type =
ClimaComms.array_type(ClimaComms.device(remapper.comms_ctx))

found_type = nameof(typeof(dest))

dest isa expected_array_type ||
error("dest is a $found_type, expected $expected_array_type")
end
index_field_begin, index_field_end =
1, min(length(fields), remapper.buffer_length)

while true
num_fields = 1 + index_field_end - index_field_begin

# Reset interpolated_values. This is needed because we collect distributed results
# with a + reduction.
_reset_interpolated_values!(remapper)
# Perform the interpolations (horizontal and vertical)
_set_interpolated_values!(
remapper,
view(fields, index_field_begin:index_field_end),
)

if !isa_vertical_space
# For spaces with an horizontal component, reshape the output so that it is a nice grid.
_apply_mpi_bitmask!(remapper, num_fields)
else
# For purely vertical spaces, just move to _interpolated_values
remapper._interpolated_values .= remapper._local_interpolated_values
end

# Finally, we have to send all the _interpolated_values to root and sum them up to
# obtain the final answer.
_collect_interpolated_values!(
dest,
remapper,
index_field_begin,
index_field_end;
only_one_field,
)

index_field_end != length(fields) || break
index_field_begin = index_field_begin + remapper.buffer_length
index_field_end =
min(length(fields), index_field_end + remapper.buffer_length)
end
return nothing
end

"""
interpolate(field::ClimaCore.Fields;
hresolution = 180,
resolution = 50,
target_hcoords = get_target_hcoords(space; hresolution),
target_zcoords = get_target_cords(space; vresolution)
target_hcoords = default_target_hcoords(space; hresolution),
target_zcoords = default_target_vcoords(space; vresolution)
)
Interpolate the given fields on the Cartesian product of `target_hcoords` with
Expand Down Expand Up @@ -943,8 +1017,8 @@ function interpolate(
field::Fields.Field;
vresolution = 50,
hresolution = 100,
target_hcoords = get_target_hcoords(axes(field); hresolution),
target_zcoords = get_target_zcoords(axes(field); vresolution),
target_hcoords = default_target_hcoords(axes(field); hresolution),
target_zcoords = default_target_zcoords(axes(field); vresolution),
)
return interpolate(field, axes(field); hresolution, vresolution)
end
Expand All @@ -954,29 +1028,29 @@ function interpolate(field::Fields.Field, target_hcoords, target_zcoords)
return interpolate(remapper, field)
end

function get_target_hcoords(space::Spaces.AbstractSpace; hresolution)
return get_target_hcoords(Spaces.horizontal_space(space); hresolution)
function default_target_hcoords(space::Spaces.AbstractSpace; hresolution)
return default_target_hcoords(Spaces.horizontal_space(space); hresolution)
end

function get_target_hcoords(
function default_target_hcoords(
space::Spaces.SpectralElementSpace2D;
hresolution = 180,
)
topology = Spaces.topology(space)
mesh = topology.mesh
domain = Meshes.domain(mesh)
PT1 = typeof(domain.interval1.coord_min)
PT2 = typeof(domain.interval2.coord_min)
PointType1 = typeof(domain.interval1.coord_min)
PointType2 = typeof(domain.interval2.coord_min)
x1min = Geometry.component(domain.interval1.coord_min, 1)
x2min = Geometry.component(domain.interval2.coord_min, 1)
x1max = Geometry.component(domain.interval1.coord_max, 1)
x2max = Geometry.component(domain.interval2.coord_max, 1)
x1 = map(PT1, range(x1min, x1max; length = hresolution))
x2 = map(PT2, range(x2min, x2max; length = hresolution))
x1 = map(PointType1, range(x1min, x1max; length = hresolution))
x2 = map(PointType2, range(x2min, x2max; length = hresolution))
return Base.Iterators.product((x1, x2))
end

function get_target_hcoords(space::Spaces.SpectralElementSpace1D; hresolution = 180)
function default_target_hcoords(space::Spaces.SpectralElementSpace1D; hresolution = 180)
topology = Spaces.topology(space)
mesh = topology.mesh
domain = Meshes.domain(mesh)
Expand All @@ -986,76 +1060,8 @@ function get_target_hcoords(space::Spaces.SpectralElementSpace1D; hresolution =
return PointType.(range(x1min, x1max; length = hresolution))
end

function get_target_zcoords(space; vresolution = 50)
function default_target_zcoords(space; vresolution = 50)
return Geometry.ZPoint.(
range(z_min(space), z_max(space); length = vresolution)
)
end

# dest has to be allowed to be nothing because interpolation happens only on the root
# process
function interpolate!(
dest::Union{Nothing, <:AbstractArray},
remapper::Remapper,
fields,
)
only_one_field = fields isa Fields.Field
if only_one_field
fields = [fields]
end
isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace

if !isnothing(dest)
# !isnothing(dest) means that this is the root process, in this case, the size have
# to match (ignoring the buffer_length)
dest_size = only_one_field ? size(dest) : size(dest)[1:(end - 1)]

dest_size == size(remapper._interpolated_values)[1:(end - 1)] || error(
"Destination array is not compatible with remapper (size mismatch)",
)

expected_array_type =
ClimaComms.array_type(ClimaComms.device(remapper.comms_ctx))

found_type = nameof(typeof(dest))

dest isa expected_array_type ||
error("dest is a $found_type, expected $expected_array_type")
end
index_field_begin, index_field_end =
1, min(length(fields), remapper.buffer_length)

while true
num_fields = 1 + index_field_end - index_field_begin

# Reset interpolated_values. This is needed because we collect distributed results
# with a + reduction.
_reset_interpolated_values!(remapper)
# Perform the interpolations (horizontal and vertical)
_set_interpolated_values!(
remapper,
view(fields, index_field_begin:index_field_end),
)

if !isa_vertical_space
# For spaces with an horizontal component, reshape the output so that it is a nice grid.
_apply_mpi_bitmask!(remapper, num_fields)
end

# Finally, we have to send all the _interpolated_values to root and sum them up to
# obtain the final answer.
_collect_interpolated_values!(
dest,
remapper,
index_field_begin,
index_field_end;
only_one_field,
)

index_field_end != length(fields) || break
index_field_begin = index_field_begin + remapper.buffer_length
index_field_end =
min(length(fields), index_field_end + remapper.buffer_length)
end
return nothing
end
Loading

0 comments on commit 368ac55

Please sign in to comment.