Skip to content

Commit

Permalink
debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 15, 2024
1 parent 70cddee commit d124341
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 43 deletions.
2 changes: 1 addition & 1 deletion spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def switch_serial_shot(wave, propagation_id):
for array_i, array in enumerate(stacked_shot_arrays):
wave.forward_solution[array_i].dat.data[:] = array
wave.forward_solution_receivers = np.load(f"tmp_rec{propagation_id}_comm{spatialcomm}"+id_str+".npy")

wave.receivers_output = wave.forward_solution_receivers

def ensemble_gradient(func):
"""Decorator for gradient to distribute shots for ensemble parallelism"""
Expand Down
82 changes: 47 additions & 35 deletions temp_test_serialshots_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
from firedrake import File
import firedrake as fire
import spyro
import warnings


def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False):
warnings.filterwarnings("ignore")

def check_gradient(Wave_obj_guess, dJ, rec_out_exact_list, Jm_list, plot=False):
steps = [1e-3, 1e-4, 1e-5] # step length

errors = []
Expand All @@ -25,14 +28,16 @@ def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False):

for step in steps:

Wave_obj_guess.reset_pressure()
c_guess = fire.Constant(2.0) + step*dm
Wave_obj_guess.initial_velocity_model = c_guess
Wave_obj_guess.forward_solve()
misfit_plusdm = rec_out_exact - Wave_obj_guess.receivers_output
J_plusdm = spyro.utils.compute_functional(Wave_obj_guess, misfit_plusdm)
grad_fd = 0.0
for snum in range(Wave_obj_guess.number_of_sources):
Wave_obj_guess.reset_pressure()
c_guess = fire.Constant(2.0) + step*dm
Wave_obj_guess.initial_velocity_model = c_guess
Wave_obj_guess.forward_solve()
misfit_plusdm = rec_out_exact_list[snum] - Wave_obj_guess.receivers_output
J_plusdm = spyro.utils.compute_functional(Wave_obj_guess, misfit_plusdm)

grad_fd = (J_plusdm - Jm) / (step)
grad_fd += (J_plusdm - Jm_list[snum]) / (step)
projnorm = fire.assemble(dJ * dm * fire.dx(scheme=Wave_obj_guess.quadrature_rule))

error = 100 * ((grad_fd - projnorm) / projnorm)
Expand Down Expand Up @@ -122,48 +127,55 @@ def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False):
}


def get_forward_model(load_true=False):
if load_true is False:
Wave_obj_exact = spyro.AcousticWave(dictionary=dictionary)
Wave_obj_exact.set_mesh(mesh_parameters={"dx": 0.1})
# Wave_obj_exact.set_initial_velocity_model(constant=3.0)
cond = fire.conditional(Wave_obj_exact.mesh_z > -1.5, 1.5, 3.5)
Wave_obj_exact.set_initial_velocity_model(
conditional=cond,
# output=True
)
# spyro.plots.plot_model(Wave_obj_exact, abc_points=[(-1, 1), (-2, 1), (-2, 4), (-1, 2)])
Wave_obj_exact.forward_solve()
# forward_solution_exact = Wave_obj_exact.forward_solution
rec_out_exact = Wave_obj_exact.receivers_output
# np.save("rec_out_exact", rec_out_exact)

else:
rec_out_exact = np.load("rec_out_exact.npy")
def get_forward_model():

print(f"Calculating exact", flush=True)
Wave_obj_exact = spyro.AcousticWave(dictionary=dictionary)
Wave_obj_exact.set_mesh(mesh_parameters={"dx": 0.1})

cond = fire.conditional(Wave_obj_exact.mesh_z > -1.5, 1.5, 3.5)
Wave_obj_exact.set_initial_velocity_model(
conditional=cond,
)

Wave_obj_exact.forward_solve()

print(f"Calculating guess", flush=True)
Wave_obj_guess = spyro.AcousticWave(dictionary=dictionary)
Wave_obj_guess.set_mesh(mesh_parameters={"dx": 0.1})
Wave_obj_guess.set_initial_velocity_model(constant=2.0)
Wave_obj_guess.forward_solve()
rec_out_guess = Wave_obj_guess.receivers_output

return rec_out_exact, rec_out_guess, Wave_obj_guess
rec_exact_list = []
rec_guess_list = []
print(f"Sending shot records and guess object", flush=True)
for propagation_id in range(Wave_obj_exact.number_of_sources):
spyro.io.switch_serial_shot(Wave_obj_exact, propagation_id)
rec_exact_list.append(Wave_obj_exact.receivers_output)
rec_guess_list.append(Wave_obj_guess.receivers_output)

return rec_exact_list, rec_guess_list, Wave_obj_guess

def test_gradient_supershot():
rec_out_exact, rec_out_guess, Wave_obj_guess = get_forward_model(load_true=False)

misfit = rec_out_exact - rec_out_guess
def test_gradient_serialshots():
print(f"Starting", flush=True)
rec_exact_list, rec_guess_list, Wave_obj_guess = get_forward_model()

Jm = spyro.utils.compute_functional(Wave_obj_guess, misfit)
print(f"Cost functional : {Jm}", flush=True)
Jm_list = []
print(f"Saving cost functionals", flush=True)
for propagation_id in range(Wave_obj_guess.number_of_sources):
misfit = rec_exact_list[propagation_id] - rec_guess_list[propagation_id]
Jm = spyro.utils.compute_functional(Wave_obj_guess, misfit)
print(f"Cost functional : {Jm}", flush=True)
Jm_list.append(Jm)

# compute the gradient of the control (to be verified)
print(f"Gradient calculation", flush=True)
dJ = Wave_obj_guess.gradient_solve()
File("gradient.pvd").write(dJ)

check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=True)
check_gradient(Wave_obj_guess, dJ, rec_exact_list, Jm_list, plot=True)


if __name__ == "__main__":
test_gradient_supershot()
test_gradient_serialshots()
14 changes: 7 additions & 7 deletions temp_test_supershot_grady.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from mpi4py.MPI import COMM_WORLD
import debugpy
debugpy.listen(3000 + COMM_WORLD.rank)
debugpy.wait_for_client()
# from mpi4py.MPI import COMM_WORLD
# import debugpy
# debugpy.listen(3000 + COMM_WORLD.rank)
# debugpy.wait_for_client()
import numpy as np
import math
import matplotlib.pyplot as plt
Expand All @@ -17,10 +17,10 @@ def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False):
errors = []
V_c = Wave_obj_guess.function_space
dm = fire.Function(V_c)
# size, = np.shape(dm.dat.data[:])
# dm_data = np.random.rand(size)
size, = np.shape(dm.dat.data[:])
dm_data = np.random.rand(size)
# np.save(f"dmdata{COMM_WORLD.rank}", dm_data)
dm_data = np.load(f"dmdata{COMM_WORLD.rank}.npy")
# dm_data = np.load(f"dmdata{COMM_WORLD.rank}.npy")
dm.dat.data[:] = dm_data
# dm.assign(dJ)

Expand Down

0 comments on commit d124341

Please sign in to comment.