Skip to content

Commit

Permalink
Move to forward solver class; use ensemble parallelisl
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Sep 3, 2024
1 parent 011f23f commit 5a76631
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 265 deletions.
83 changes: 26 additions & 57 deletions demos/with_automatic_differentiation/run_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import matplotlib.pyplot as plt
import numpy as np

import spyro.solvers

# --- Basid setup to run a forward simulation with AD --- #
model = {}

Expand All @@ -17,13 +19,10 @@

model["parallelism"] = {
# options:
# `shots_parallelism` (same number of cores for every processor. Apply only
# shots parallelism, i.e., the spatial domain is not parallelised.)
# `automatic` (same number of cores for every processor. Apply shots and
# spatial parallelism.)
# `spatial` (Only spatial parallelisation).
# `None` (No parallelisation).
# `shots_parallelism`. Shots parallelism.
# None - no shots parallelism.
"type": "shots_parallelism",
"num_spacial_cores": 1, # Number of cores to use in the spatial parallelism.
}

# Define the domain size without the ABL.
Expand Down Expand Up @@ -66,16 +65,11 @@
"fspool": 1, # how frequently to save solution to RAM
}

comm, spatial_comm = spyro.utils.mpi_init(model)
if model["parallelism"]["type"] == "shots_parallelism":
# Only shots parallelism.
mesh = UnitSquareMesh(50, 50, comm=spatial_comm)
else:
mesh = UnitSquareMesh(50, 50)

# Receiver mesh.
vom = VertexOnlyMesh(mesh, model["acquisition"]["receiver_locations"])

# Use emsemble parallelism.
M = model["parallelism"]["num_spacial_cores"]
my_ensemble = Ensemble(COMM_WORLD, M)
mesh = UnitSquareMesh(50, 50, comm=my_ensemble.comm)
element = spyro.domains.space.FE_method(
mesh, model["opts"]["method"], model["opts"]["degree"]
)
Expand All @@ -98,52 +92,27 @@ def make_vp_circle(vp_guess=False, plot_vp=False):
return vp


def run_forward(source_number):
"""Execute a acoustic wave equation.
Parameters
----------
source_number: `int`, optional
The source number defined by the user.
Notes
-----
The forward solver (`forward_AD`) is implemented in spyro using firedrake's
functions that can be annotated by the algorithimic differentiation (AD).
This is because spyro is capable of executing Full Waveform Inversion (FWI),
which needs the computation of the gradient of the objective function with
respect to the velocity model through (AD).
"""
receiver_data = spyro.solvers.forward_AD(model, mesh, comm, vp_exact,
wavelet, vom, debug=True,
source_number=source_number)
# --- Plot the receiver data --- #
data = []
for _, rec in enumerate(receiver_data):
data.append(rec.dat.data_ro[:])
spyro.plots.plot_shots(model, comm, data, vmax=1e-08, vmin=-1e-08)


# Rickers wavelet
forward_solver = spyro.solvers.forward_ad.ForwardSolver(
model, mesh
)

c_true = make_vp_circle()
# Ricker wavelet
wavelet = spyro.full_ricker_wavelet(
dt=model["timeaxis"]["dt"],
tf=model["timeaxis"]["tf"],
freq=model["acquisition"]["frequency"],
)
# True acoustic velocity model
vp_exact = make_vp_circle(plot_vp=True)

# Processor number.
rank = comm.ensemble_comm.rank
# Number of processors used in the simulation.
size = comm.ensemble_comm.size
if size == 1:

if model["parallelism"]["type"] is None:
for sn in range(len(model["acquisition"]["source_pos"])):
run_forward(sn)
elif size == len(model["acquisition"]["source_pos"]):
# Only run the forward simulation for the source number that matches the
# processor number.
run_forward(rank)
rec_data, _ = forward_solver.execute(c_true, sn, wavelet)
spyro.plots.plot_shots(
model, my_ensemble.comm, rec_data, vmax=1e-08, vmin=-1e-08)
else:
raise NotImplementedError("`size` must be 1 or equal to `num_sources`."
"Different values are not supported yet.")
# source_number based on the ensemble.ensemble_comm.rank
source_number = my_ensemble.ensemble_comm.rank
rec_data, _ = forward_solver.execute(
c_true, source_number, wavelet)
spyro.plots.plot_shots(
model, my_ensemble.comm, rec_data, vmax=1e-08, vmin=-1e-08)
20 changes: 12 additions & 8 deletions demos/with_automatic_differentiation/run_fwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

model["acquisition"] = {
"source_type": "Ricker",
"source_pos": spyro.create_transect((0.3, 0.15), (0.7, 0.15), 5),
"source_pos": spyro.create_transect((0.3, 0.15), (0.7, 0.15), 1),
"frequency": 7.0,
"delay": 1.0,
"receiver_locations": spyro.create_transect((0.2, 0.8), (0.8, 0.8), 10),
Expand All @@ -67,8 +67,8 @@

model["timeaxis"] = {
"t0": 0.0, # Initial time for event
"tf": 1.0, # Final time for event (for test 7)
"dt": 0.001, # timestep size (divided by 2 in the test 4. dt for test 3 is 0.00050)
"tf": 0.8, # Final time for event (for test 7)
"dt": 0.002, # timestep size (divided by 2 in the test 4. dt for test 3 is 0.00050)
"amplitude": 1, # the Ricker has an amplitude of 1.
"nspool": 20, # (20 for dt=0.00050) how frequently to output solution to pvds
"fspool": 1, # how frequently to save solution to RAM
Expand All @@ -89,6 +89,8 @@ def make_vp_circle(vp_guess=False, plot_vp=False):
outfile = File("acoustic_cp.pvd")
outfile.write(vp)
return vp


true_receiver_data = []
iterations = 0
nt = int(model["timeaxis"]["tf"] / model["timeaxis"]["dt"]) # number of timesteps
Expand Down Expand Up @@ -123,9 +125,11 @@ def run_fwi(vp_guess_data):
continue_annotation()
tape = get_working_tape()
tape.progress_bar = ProgressBar
get_working_tape().enable_checkpointing(Revolve(nt, nt//4))
J_total += J(mesh, comm, vp_exact, wavelet, vom, sn, vp_guess)
dJ_total += compute_gradient(J_total, Control(vp_guess))
get_working_tape().enable_checkpointing(SingleMemoryStorageSchedule())
Js = J(mesh, comm, vp_exact, wavelet, vom, sn, vp_guess)
print(Js)
dJ_total += compute_gradient(Js, Control(vp_guess))
J_total += Js
get_working_tape().clear_tape()
elif size == len(model["acquisition"]["source_pos"]):
J_local = J(mesh, comm, vp_exact, wavelet, vom, rank, vp_guess)
Expand All @@ -142,9 +146,9 @@ def run_fwi(vp_guess_data):
comm, spatial_comm = spyro.utils.mpi_init(model)
if model["parallelism"]["type"] == "shots_parallelism":
# Only shots parallelism.
mesh = UnitSquareMesh(100, 100, comm=spatial_comm)
mesh = UnitSquareMesh(60, 60, comm=spatial_comm)
else:
mesh = UnitSquareMesh(100, 100)
mesh = UnitSquareMesh(60, 60)

element = spyro.domains.space.FE_method(
mesh, model["opts"]["method"], model["opts"]["degree"]
Expand Down
6 changes: 3 additions & 3 deletions spyro/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .forward import forward
from .forward_AD import forward as forward_AD
# from .forward import forward
from .forward_ad import ForwardSolver
from .gradient import gradient

__all__ = [
"forward", # forward solver adapted for discrete adjoint
"forward_AD", # forward solver adapted for Automatic Differentiation
"ForwardSolver", # forward solver adapted for Automatic Differentiation
"gradient",
]
Loading

0 comments on commit 5a76631

Please sign in to comment.