Skip to content

Commit

Permalink
test checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Sep 4, 2024
1 parent cd35e51 commit be3bdad
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 1 addition & 1 deletion demos/with_automatic_differentiation/run_fwi_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,4 @@ def forward(
bounds=(1.5, 3.5),
derivative_options={"riesz_representation": 'l2'})

fire.VTKFile("c_optimised.pvd").write(c_optimised)
fire.VTKFile("c_optimised.pvd").write(c_optimised)
15 changes: 15 additions & 0 deletions test/test_gradient_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import firedrake.adjoint as fire_ad
import spyro
from numpy.random import rand
from checkpoint_schedules import Revolve


# --- Basid setup to run a forward simulation with AD --- #
Expand Down Expand Up @@ -52,6 +53,7 @@
}
model["aut_dif"] = {
"status": True,
"checkpointing": False,
}

model["timeaxis"] = {
Expand Down Expand Up @@ -86,6 +88,12 @@ def forward(
):
if annotate:
fire_ad.continue_annotation()
if model["aut_dif"]["checkpointing"]:
total_steps = int(model["timeaxis"]["tf"] / model["timeaxis"]["dt"])
steps_store = int(total_steps / 10) # Store 10% of the steps.
tape = fire_ad.get_working_tape()
tape.progress_bar = fire.ProgressBar
tape.enable_checkpointing(Revolve(total_steps, steps_store))
# source_number based on the ensemble.ensemble_comm.rank
source_number = ensemble.ensemble_comm.rank
receiver_data, J = fwd_solver.execute_acoustic(
Expand Down Expand Up @@ -131,3 +139,10 @@ def test_taylor():
h = fire.Function(V)
h.dat.data[:] = rand(V.dim())
assert fire_ad.taylor_test(J_hat, c_guess, h) > 1.9
fire_ad.get_working_tape().clear_tape()
fire_ad.pause_annotation()


def test_taylor_checkpointing():
model["aut_dif"]["checkpointing"] = True
test_taylor()

0 comments on commit be3bdad

Please sign in to comment.