Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward solver adapted for automatic differentiation. #124

Merged
merged 13 commits into from
Sep 15, 2024
10 changes: 5 additions & 5 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,28 @@ jobs:
- uses: actions/checkout@v3
- name: Running serial tests
run: |
source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate
source /home/olender/firedrakes/2024_09_11/firedrake/bin/activate
pytest --cov-report=xml --cov=spyro test/
- name: Running parallel 3D forward test
run: |
source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate
source /home/olender/firedrakes/2024_09_11/firedrake/bin/activate
mpiexec -n 6 pytest test_3d/test_hexahedral_convergence.py
mpiexec -n 6 pytest test_parallel/test_forward.py
mpiexec -n 6 pytest test_parallel/test_fwi.py
- name: Covering parallel 3D forward test
continue-on-error: true
run: |
source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate
source /home/olender/firedrakes/2024_09_11/firedrake/bin/activate
mpiexec -n 6 pytest --cov-report=xml --cov-append --cov=spyro test_3d/test_hexahedral_convergence.py
- name: Covering parallel forward test
continue-on-error: true
run: |
source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate
source /home/olender/firedrakes/2024_09_11/firedrake/bin/activate
mpiexec -n 6 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_forward.py
- name: Covering parallel fwi test
continue-on-error: true
run: |
source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate
source /home/olender/firedrakes/2024_09_11/firedrake/bin/activate
mpiexec -n 6 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_fwi.py
- name: Uploading coverage to Codecov
run: export CODECOV_TOKEN="6cd21147-54f7-4b77-94ad-4b138053401d" && bash <(curl -s https://codecov.io/bash)
Expand Down
File renamed without changes.
46 changes: 46 additions & 0 deletions demos/with_automatic_differentiation/run_forward_ad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import firedrake as fire
import spyro
from demos.with_automatic_differentiation.utils import \
model_settings, make_c_camembert
import os
os.environ["OMP_NUM_THREADS"] = "1"

# --- Basid setup to run a forward simulation with AD --- #

model = model_settings()

# Use emsemble parallelism.
M = model["parallelism"]["num_spacial_cores"]
my_ensemble = fire.Ensemble(fire.COMM_WORLD, M)
mesh = fire.UnitSquareMesh(50, 50, comm=my_ensemble.comm)
element = fire.FiniteElement(
model["opts"]["method"], mesh.ufl_cell(), degree=model["opts"]["degree"],
variant=model["opts"]["quadrature"]
)
V = fire.FunctionSpace(mesh, element)


forward_solver = spyro.solvers.forward_ad.ForwardSolver(model, mesh, V)

c_true = make_c_camembert(mesh, V)
# Ricker wavelet
wavelet = spyro.full_ricker_wavelet(
model["timeaxis"]["dt"], model["timeaxis"]["tf"],
model["acquisition"]["frequency"],
)

if model["parallelism"]["type"] is None:
outfile = fire.VTKFile("solution.pvd")
for sn in range(len(model["acquisition"]["source_pos"])):
rec_data, _ = forward_solver.execute(c_true, sn, wavelet)
sol = forward_solver.solution
outfile.write(sol)
else:
# source_number based on the ensemble.ensemble_comm.rank
source_number = my_ensemble.ensemble_comm.rank
rec_data, _ = forward_solver.execute_acoustic(
c_true, source_number, wavelet)
sol = forward_solver.solution
fire.VTKFile(
"solution_" + str(source_number) + ".pvd", comm=my_ensemble.comm
).write(sol)
112 changes: 112 additions & 0 deletions demos/with_automatic_differentiation/run_fwi_ad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import firedrake as fire
import firedrake.adjoint as fire_ad
from checkpoint_schedules import Revolve
import spyro
from demos.with_automatic_differentiation import utils
import os
os.environ["OMP_NUM_THREADS"] = "1"

# --- Basid setup to run a FWI --- #
model = utils.model_settings()


def forward(
c, compute_functional=False, true_data_receivers=None, annotate=False
):
"""Time-stepping acoustic forward solver.

The time integration is done using a central difference scheme.

Parameters
----------
c : firedrake.Function
Velocity field.
compute_functional : bool, optional
Whether to compute the functional. If True, the true receiver
data must be provided.
true_data_receivers : list, optional
True receiver data. This is used to compute the functional.
annotate : bool, optional
If True, the forward model is annotated for automatic differentiation.

Returns
-------
(receiver_data : list, J_val : float)
Receiver data and functional value.
"""
if annotate:
fire_ad.continue_annotation()
if model["aut_dif"]["checkpointing"]:
total_steps = int(model["timeaxis"]["tf"] / model["timeaxis"]["dt"])
steps_store = int(total_steps / 10) # Store 10% of the steps.
tape = fire_ad.get_working_tape()
tape.progress_bar = fire.ProgressBar
tape.enable_checkpointing(Revolve(total_steps, steps_store))

if model["parallelism"]["type"] is None:
outfile = fire.VTKFile("solution.pvd")
receiver_data = []
J = 0.0
for sn in range(len(model["acquisition"]["source_pos"])):
rec_data, J_val = forward_solver.execute_acoustic(c, sn, wavelet)
receiver_data.append(rec_data)
J += J_val
sol = forward_solver.solution
outfile.write(sol)

else:
# source_number based on the ensemble.ensemble_comm.rank
source_number = my_ensemble.ensemble_comm.rank
receiver_data, J = forward_solver.execute_acoustic(
c, source_number, wavelet,
compute_functional=compute_functional,
true_data_receivers=true_data_receivers
)
sol = forward_solver.solution
fire.VTKFile(
"solution_" + str(source_number) + ".pvd", comm=my_ensemble.comm
).write(sol)

return receiver_data, J


# Use emsemble parallelism.
M = model["parallelism"]["num_spacial_cores"]
my_ensemble = fire.Ensemble(fire.COMM_WORLD, M)
mesh = fire.UnitSquareMesh(50, 50, comm=my_ensemble.comm)
element = fire.FiniteElement(
model["opts"]["method"], mesh.ufl_cell(), degree=model["opts"]["degree"],
variant=model["opts"]["quadrature"]
)
V = fire.FunctionSpace(mesh, element)


forward_solver = spyro.solvers.forward_ad.ForwardSolver(model, mesh, V)
# Camembert model.
c_true = utils.make_c_camembert(mesh, V)
# Ricker wavelet
wavelet = spyro.full_ricker_wavelet(
model["timeaxis"]["dt"], model["timeaxis"]["tf"],
model["acquisition"]["frequency"],
)

true_rec, _ = forward(c_true)

# --- FWI with AD --- #
c_guess = utils.make_c_camembert(mesh, V, c_guess=True)
guess_rec, J = forward(
c_guess, compute_functional=True, true_data_receivers=true_rec,
annotate=True
)

# :class:`~.EnsembleReducedFunctional` is employed to recompute in
# parallel the functional and its gradient associated with the multiple sources
# (3 in this case).
J_hat = fire_ad.EnsembleReducedFunctional(
J, fire_ad.Control(c_guess), my_ensemble)
c_optimised = fire_ad.minimize(J_hat, method="L-BFGS-B",
options={"disp": True, "maxiter": 10},
bounds=(1.5, 3.5),
derivative_options={"riesz_representation": 'l2'})

fire.VTKFile("c_optimised.pvd").write(c_optimised)
104 changes: 104 additions & 0 deletions demos/with_automatic_differentiation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# --- Basid setup to run a forward simulation with AD --- #
import firedrake as fire
import spyro

def model_settings():
"""Model settings for forward and Full Waveform Inversion (FWI)
simulations.

Returns
-------
model : dict
Dictionary containing the model settings.
"""

model = {}

model["opts"] = {
"method": "KMV", # either CG or mass_lumped_triangle
"quadrature": "KMV", # Equi or mass_lumped_triangle
"degree": 1, # p order
"dimension": 2, # dimension
"regularization": False, # regularization is on?
"gamma": 1e-5, # regularization parameter
}

model["parallelism"] = {
# options:
# `shots_parallelism`. Shots parallelism.
# None - no shots parallelism.
"type": "shots_parallelism",
"num_spacial_cores": 1, # Number of cores to use in the spatial
# parallelism.
}

# Define the domain size without the ABL.
model["mesh"] = {
"Lz": 1.0, # depth in km - always positive
"Lx": 1.0, # width in km - always positive
"Ly": 0.0, # thickness in km - always positive
"meshfile": "not_used.msh",
"initmodel": "not_used.hdf5",
"truemodel": "not_used.hdf5",
}

# Specify a 250-m Absorbing Boundary Layer (ABL) on the three sides of the domain to damp outgoing waves.
model["BCs"] = {
"status": False, # True or False, used to turn on any type of BC
"outer_bc": "non-reflective", # none or non-reflective (outer boundary condition)
"abl_bc": "none", # none, gaussian-taper, or alid
"lz": 0.0, # thickness of the ABL in the z-direction (km) - always positive
"lx": 0.0, # thickness of the ABL in the x-direction (km) - always positive
"ly": 0.0, # thickness of the ABL in the y-direction (km) - always positive
}

model["acquisition"] = {
"source_type": "Ricker",
"source_pos": spyro.create_transect((0.2, 0.15), (0.8, 0.15), 3),
"frequency": 7.0,
"delay": 1.0,
"receiver_locations": spyro.create_transect((0.2, 0.2), (0.8, 0.2), 10),
}
model["aut_dif"] = {
"status": True,
"checkpointing": True,
}

model["timeaxis"] = {
"t0": 0.0, # Initial time for event
"tf": 0.8, # Final time for event (for test 7)
"dt": 0.001, # timestep size (divided by 2 in the test 4. dt for test 3 is 0.00050)
"amplitude": 1, # the Ricker has an amplitude of 1.
"nspool": 20, # (20 for dt=0.00050) how frequently to output solution to pvds
"fspool": 1, # how frequently to save solution to RAM
}

return model


def make_c_camembert(mesh, function_space, c_guess=False, plot_c=False):
"""Acoustic velocity model.

Parameters
----------
mesh : firedrake.Mesh
Mesh.
function_space : firedrake.FunctionSpace
Function space.
c_guess : bool, optional
If True, the initial guess for the velocity field is returned.
plot_c : bool, optional
If True, the velocity field is saved to a VTK file.
"""
x, z = fire.SpatialCoordinate(mesh)
if c_guess:
c = fire.Function(function_space).interpolate(1.5 + 0.0 * x)
else:
c = fire.Function(function_space).interpolate(
2.5
+ 1 * fire.tanh(100 * (0.125 - fire.sqrt((x - 0.5) ** 2 + (z - 0.5) ** 2)))
)
if plot_c:
outfile = fire.VTKFile("acoustic_cp.pvd")
outfile.write(c)
return c
2 changes: 2 additions & 0 deletions spyro/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from .acoustic_wave import AcousticWave
from .mms_acoustic import AcousticWaveMMS
from .inversion import FullWaveformInversion
from .forward_ad import ForwardSolver

__all__ = [
"Wave",
"AcousticWave",
"AcousticWaveMMS",
"FullWaveformInversion",
"ForwardSolver",
]
Loading
Loading