diff --git a/demos/with_automatic_differentiation/run_fwi_ad.py b/demos/with_automatic_differentiation/run_fwi_ad.py index b247b246..e1540f9a 100644 --- a/demos/with_automatic_differentiation/run_fwi_ad.py +++ b/demos/with_automatic_differentiation/run_fwi_ad.py @@ -159,4 +159,4 @@ def forward( bounds=(1.5, 3.5), derivative_options={"riesz_representation": 'l2'}) -fire.VTKFile("c_optimised.pvd").write(c_optimised) \ No newline at end of file +fire.VTKFile("c_optimised.pvd").write(c_optimised) diff --git a/test/test_gradient_ad.py b/test/test_gradient_ad.py index 59a9882c..55b8b286 100644 --- a/test/test_gradient_ad.py +++ b/test/test_gradient_ad.py @@ -2,6 +2,7 @@ import firedrake.adjoint as fire_ad import spyro from numpy.random import rand +from checkpoint_schedules import Revolve # --- Basid setup to run a forward simulation with AD --- # @@ -52,6 +53,7 @@ } model["aut_dif"] = { "status": True, + "checkpointing": False, } model["timeaxis"] = { @@ -86,6 +88,12 @@ def forward( ): if annotate: fire_ad.continue_annotation() + if model["aut_dif"]["checkpointing"]: + total_steps = int(model["timeaxis"]["tf"] / model["timeaxis"]["dt"]) + steps_store = int(total_steps / 10) # Store 10% of the steps. + tape = fire_ad.get_working_tape() + tape.progress_bar = fire.ProgressBar + tape.enable_checkpointing(Revolve(total_steps, steps_store)) # source_number based on the ensemble.ensemble_comm.rank source_number = ensemble.ensemble_comm.rank receiver_data, J = fwd_solver.execute_acoustic( @@ -131,3 +139,10 @@ def test_taylor(): h = fire.Function(V) h.dat.data[:] = rand(V.dim()) assert fire_ad.taylor_test(J_hat, c_guess, h) > 1.9 + fire_ad.get_working_tape().clear_tape() + fire_ad.pause_annotation() + + +def test_taylor_checkpointing(): + model["aut_dif"]["checkpointing"] = True + test_taylor()