Skip to content

Commit

Permalink
started adding gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 9, 2024
1 parent 0fdcf3c commit 70cddee
Show file tree
Hide file tree
Showing 4 changed files with 361 additions and 5 deletions.
23 changes: 18 additions & 5 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ def save_serial_data(wave, propagation_id):
arrays_list = [obj.dat.data[:] for obj in wave.forward_solution]
stacked_arrays = np.stack(arrays_list, axis=0)
spatialcomm = wave.comm.comm.rank
np.save(f'tmp_shot{propagation_id}_comm{spatialcomm}.npy', stacked_arrays)
np.save(f"tmp_rec{propagation_id}_comm{spatialcomm}.npy", wave.forward_solution_receivers)
id_str = wave.random_id_string
np.save(f'tmp_shot{propagation_id}_comm{spatialcomm}'+id_str+'.npy', stacked_arrays)
np.save(f"tmp_rec{propagation_id}_comm{spatialcomm}"+id_str+".npy", wave.forward_solution_receivers)


def switch_serial_shot(wave, propagation_id):
Expand All @@ -85,19 +86,23 @@ def switch_serial_shot(wave, propagation_id):
None
"""
spatialcomm = wave.comm.comm.rank
stacked_shot_arrays = np.load(f'tmp_shot{propagation_id}_comm{spatialcomm}.npy')
id_str = wave.random_id_string
stacked_shot_arrays = np.load(f'tmp_shot{propagation_id}_comm{spatialcomm}'+id_str+'.npy')
if len(wave.forward_solution) == 0:
n_dts, n_dofs = np.shape(stacked_shot_arrays)
rebuild_empty_forward_solution(wave, n_dts)
for array_i, array in enumerate(stacked_shot_arrays):
wave.forward_solution[array_i].dat.data[:] = array
wave.forward_solution_receivers = np.load(f"tmp_rec{propagation_id}_comm{spatialcomm}.npy")
wave.forward_solution_receivers = np.load(f"tmp_rec{propagation_id}_comm{spatialcomm}"+id_str+".npy")


def ensemble_gradient(func):
"""Decorator for gradient to distribute shots for ensemble parallelism"""

def wrapper(*args, **kwargs):
comm = args[0].comm
if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1:
shot_ids_per_propagation_list = args[0].shot_ids_per_propagation
comm = args[0].comm
for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list):
if is_owner(comm, propagation_id):
grad = func(*args, **kwargs)
Expand All @@ -120,6 +125,7 @@ def wrapper(*args, **kwargs):
grad_total += grad

grad_total /= num
comm.comm.barrier()

return grad_total

Expand Down Expand Up @@ -228,6 +234,13 @@ def save_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0):
return None


def rebuild_empty_forward_solution(wave, time_steps):
wave.forward_solution = []
for i in range(time_steps):
wave.forward_solution.append(fire.Function(wave.function_space))



@ensemble_save_or_load
def load_shots(Wave_obj, file_name=None, shot_ids=0):
"""Load a `pickle` to a `numpy.ndarray`.
Expand Down
2 changes: 2 additions & 0 deletions spyro/io/model_parameters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import uuid
from mpi4py import MPI
from firedrake import COMM_WORLD
import warnings
Expand Down Expand Up @@ -328,6 +329,7 @@ def __init__(self, dictionary=None, comm=None):

# Sanitize output files
self._sanitize_output()
self.random_id_string = str(uuid.uuid4())[:10]

# default_dictionary["absorving_boundary_conditions"] = {
# "status": False, # True or false
Expand Down
169 changes: 169 additions & 0 deletions temp_test_serialshots_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# from mpi4py.MPI import COMM_WORLD
# import debugpy
# debugpy.listen(3000 + COMM_WORLD.rank)
# debugpy.wait_for_client()
import numpy as np
import math
import matplotlib.pyplot as plt
from copy import deepcopy
from firedrake import File
import firedrake as fire
import spyro


def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False):
steps = [1e-3, 1e-4, 1e-5] # step length

errors = []
V_c = Wave_obj_guess.function_space
dm = fire.Function(V_c)
size, = np.shape(dm.dat.data[:])
dm_data = np.random.rand(size)
# np.save(f"dmdata{COMM_WORLD.rank}", dm_data)
# dm_data = np.load(f"dmdata{COMM_WORLD.rank}.npy")
dm.dat.data[:] = dm_data

for step in steps:

Wave_obj_guess.reset_pressure()
c_guess = fire.Constant(2.0) + step*dm
Wave_obj_guess.initial_velocity_model = c_guess
Wave_obj_guess.forward_solve()
misfit_plusdm = rec_out_exact - Wave_obj_guess.receivers_output
J_plusdm = spyro.utils.compute_functional(Wave_obj_guess, misfit_plusdm)

grad_fd = (J_plusdm - Jm) / (step)
projnorm = fire.assemble(dJ * dm * fire.dx(scheme=Wave_obj_guess.quadrature_rule))

error = 100 * ((grad_fd - projnorm) / projnorm)

errors.append(error)

errors = np.array(errors)

# Checking if error is first order in step
theory = [t for t in steps]
theory = [errors[0] * th / theory[0] for th in theory]
if plot:
plt.close()
plt.plot(steps, errors, label="Error")
plt.plot(steps, theory, "--", label="first order")
plt.legend()
plt.title(" Adjoint gradient versus finite difference gradient")
plt.xlabel("Step")
plt.ylabel("Error %")
plt.savefig("gradient_error_verification.png")
plt.close()

# Checking if every error is less than 1 percent

test1 = abs(errors[-1]) < 1
print(f"Last gradient error less than 1 percent: {test1}")

# Checking if error follows expected finite difference error convergence
test2 = math.isclose(np.log(abs(theory[-1])), np.log(abs(errors[-1])), rel_tol=1e-1)

print(f"Gradient error behaved as expected: {test2}")

assert all([test1, test2])


final_time = 1.0

dictionary = {}
dictionary["options"] = {
"cell_type": "Q", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q)
"variant": "lumped", # lumped, equispaced or DG, default is lumped
"degree": 4, # p order
"dimension": 2, # dimension
}

dictionary["parallelism"] = {
"type": "spatial", # options: automatic (same number of cores for evey processor) or spatial
"shot_ids_per_propagation": [[0], [1]],
}

dictionary["mesh"] = {
"Lz": 3.0, # depth in km - always positive # Como ver isso sem ler a malha?
"Lx": 3.0, # width in km - always positive
"Ly": 0.0, # thickness in km - always positive
"mesh_file": None,
"mesh_type": "firedrake_mesh",
}

dictionary["acquisition"] = {
"source_type": "ricker",
"source_locations": [(-1.1, 1.3), (-1.1, 1.7)],
"frequency": 5.0,
"delay": 1.5,
"delay_type": "multiples_of_minimun",
"receiver_locations": spyro.create_transect((-1.8, 1.2), (-1.8, 1.8), 10),
}

dictionary["time_axis"] = {
"initial_time": 0.0, # Initial time for event
"final_time": final_time, # Final time for event
"dt": 0.0005, # timestep size
"amplitude": 1, # the Ricker has an amplitude of 1.
"output_frequency": 100, # how frequently to output solution to pvds - Perguntar Daiane ''post_processing_frequnecy'
"gradient_sampling_frequency": 1, # how frequently to save solution to RAM - Perguntar Daiane 'gradient_sampling_frequency'
}

dictionary["visualization"] = {
"forward_output": True,
"forward_output_filename": "results/forward_true.pvd",
"fwi_velocity_model_output": False,
"velocity_model_filename": None,
"gradient_output": False,
"gradient_filename": "results/Gradient.pvd",
"adjoint_output": False,
"adjoint_filename": None,
"debug_output": False,
}


def get_forward_model(load_true=False):
if load_true is False:
Wave_obj_exact = spyro.AcousticWave(dictionary=dictionary)
Wave_obj_exact.set_mesh(mesh_parameters={"dx": 0.1})
# Wave_obj_exact.set_initial_velocity_model(constant=3.0)
cond = fire.conditional(Wave_obj_exact.mesh_z > -1.5, 1.5, 3.5)
Wave_obj_exact.set_initial_velocity_model(
conditional=cond,
# output=True
)
# spyro.plots.plot_model(Wave_obj_exact, abc_points=[(-1, 1), (-2, 1), (-2, 4), (-1, 2)])
Wave_obj_exact.forward_solve()
# forward_solution_exact = Wave_obj_exact.forward_solution
rec_out_exact = Wave_obj_exact.receivers_output
# np.save("rec_out_exact", rec_out_exact)

else:
rec_out_exact = np.load("rec_out_exact.npy")

Wave_obj_guess = spyro.AcousticWave(dictionary=dictionary)
Wave_obj_guess.set_mesh(mesh_parameters={"dx": 0.1})
Wave_obj_guess.set_initial_velocity_model(constant=2.0)
Wave_obj_guess.forward_solve()
rec_out_guess = Wave_obj_guess.receivers_output

return rec_out_exact, rec_out_guess, Wave_obj_guess


def test_gradient_supershot():
rec_out_exact, rec_out_guess, Wave_obj_guess = get_forward_model(load_true=False)

misfit = rec_out_exact - rec_out_guess

Jm = spyro.utils.compute_functional(Wave_obj_guess, misfit)
print(f"Cost functional : {Jm}", flush=True)

# compute the gradient of the control (to be verified)
dJ = Wave_obj_guess.gradient_solve()
File("gradient.pvd").write(dJ)

check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=True)


if __name__ == "__main__":
test_gradient_supershot()
Loading

0 comments on commit 70cddee

Please sign in to comment.