From 33f53e48ca6629d3e186f61e114b6299287fe90b Mon Sep 17 00:00:00 2001 From: olender Date: Wed, 17 Jul 2024 12:37:04 -0300 Subject: [PATCH] added fwi with serialshots --- spyro/io/__init__.py | 4 + spyro/io/basicio.py | 48 +++++++++++- spyro/solvers/inversion.py | 29 ++++++-- spyro/utils/utils.py | 8 +- temp_serialshot_fwi.py | 138 ++++++++++++++++++++++++++++++++++ test/test_serialshot_fwi.py | 143 ++++++++++++++++++++++++++++++++++++ 6 files changed, 359 insertions(+), 11 deletions(-) create mode 100644 temp_serialshot_fwi.py create mode 100644 test/test_serialshot_fwi.py diff --git a/spyro/io/__init__.py b/spyro/io/__init__.py index 42276e6f..37d58b45 100644 --- a/spyro/io/__init__.py +++ b/spyro/io/__init__.py @@ -15,6 +15,8 @@ switch_serial_shot, ensemble_save_or_load, delete_tmp_files, + ensemble_shot_record, + ensemble_functional ) from .model_parameters import Model_parameters from .backwards_compatibility_io import Dictionary_conversion @@ -44,4 +46,6 @@ "switch_serial_shot", "ensemble_save_or_load", "delete_tmp_files", + "ensemble_shot_record", + "ensemble_functional", ] diff --git a/spyro/io/basicio.py b/spyro/io/basicio.py index 5297620c..b1c928de 100644 --- a/spyro/io/basicio.py +++ b/spyro/io/basicio.py @@ -1,7 +1,7 @@ from __future__ import with_statement import pickle - +from mpi4py import MPI import firedrake as fire import h5py import numpy as np @@ -19,6 +19,21 @@ def delete_tmp_files(wave): os.remove(file) +def ensemble_shot_record(func): + """Decorator for read and write shots for ensemble parallelism""" + + def wrapper(*args, **kwargs): + if args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: + output_list = [] + for snum in range(args[0].number_of_sources): + switch_serial_shot(args[0], snum) + output_list.append(func(*args, **kwargs)) + + return output_list + + return wrapper + + def ensemble_save_or_load(func): """Decorator for read and write shots for ensemble parallelism""" @@ -106,6 +121,37 @@ def switch_serial_shot(wave, propagation_id): wave.receivers_output = wave.forward_solution_receivers +def ensemble_functional(func): + """Decorator for gradient to distribute shots for ensemble parallelism""" + + def wrapper(*args, **kwargs): + comm = args[0].comm + if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1: + J = func(*args, **kwargs) + J_total = np.zeros((1)) + J_total[0] += J + J_total = fire.COMM_WORLD.allreduce(J_total, op=MPI.SUM) + J_total[0] /= comm.comm.size + + elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: + num = args[0].number_of_sources + residual_list = args[1] + J_total = np.zeros((1)) + + for snum in range(args[0].number_of_sources): + switch_serial_shot(args[0], snum) + current_residual = residual_list[snum] + J = func(args[0], current_residual) + J_total += J + J_total[0] /= comm.comm.size + + comm.comm.barrier() + + return J_total[0] + + return wrapper + + def ensemble_gradient(func): """Decorator for gradient to distribute shots for ensemble parallelism""" diff --git a/spyro/solvers/inversion.py b/spyro/solvers/inversion.py index 0ae7bcd2..c3c38e71 100644 --- a/spyro/solvers/inversion.py +++ b/spyro/solvers/inversion.py @@ -8,6 +8,9 @@ from ..utils import compute_functional from ..utils import Gradient_mask_for_pml, Mask from ..plots import plot_model as spyro_plot_model +from ..io.basicio import ensemble_shot_record +from ..io.basicio import switch_serial_shot + try: from ROL.firedrake_vector import FiredrakeVector as FireVector @@ -181,10 +184,19 @@ def calculate_misfit(self, c=None): self.forward_solve() output = fire.File("control_" + str(self.current_iteration)+".pvd") output.write(self.c) - self.guess_shot_record = self.forward_solution_receivers - self.guess_forward_solution = self.forward_solution - - self.misfit = self.real_shot_record - self.guess_shot_record + if self.parallelism_type == "spatial" and self.number_of_sources > 1: + misfit_list = [] + guess_shot_record_list = [] + for snum in range (self.number_of_sources): + switch_serial_shot(self, snum) + guess_shot_record_list.append(self.forward_solution_receivers) + misfit_list.append(self.real_shot_record[snum] - self.forward_solution_receivers) + self.guess_shot_record = guess_shot_record_list + self.misfit = misfit_list + else: + self.guess_shot_record = self.forward_solution_receivers + self.guess_forward_solution = self.forward_solution + self.misfit = self.real_shot_record - self.guess_shot_record return self.misfit def generate_real_shot_record(self, plot_model=False, filename=None, abc_points=None): @@ -572,4 +584,11 @@ def __init__(self, dictionary=None, comm=None): def forward_solve(self): super().forward_solve() - self.real_shot_record = self.receivers_output + if self.parallelism_type == "spatial" and self.number_of_sources > 1: + real_shot_record_list = [] + for snum in range (self.number_of_sources): + switch_serial_shot(self, snum) + real_shot_record_list.append(self.receivers_output) + self.real_shot_record = real_shot_record_list + else: + self.real_shot_record = self.receivers_output diff --git a/spyro/utils/utils.py b/spyro/utils/utils.py index 6d385b4d..6a572350 100644 --- a/spyro/utils/utils.py +++ b/spyro/utils/utils.py @@ -4,6 +4,7 @@ from mpi4py import MPI from scipy.signal import butter, filtfilt import warnings +from ..io import ensemble_functional def butter_lowpass_filter(shot, cutoff, fs, order=2): @@ -37,6 +38,7 @@ def butter_lowpass_filter(shot, cutoff, fs, order=2): return filtered_shot +@ensemble_functional def compute_functional(Wave_object, residual): """Compute the functional to be optimized. Accepts the velocity optionally and uses @@ -52,11 +54,7 @@ def compute_functional(Wave_object, residual): J *= 0.5 - J_total = np.zeros((1)) - J_total[0] += J - J_total = COMM_WORLD.allreduce(J_total, op=MPI.SUM) - J_total[0] /= comm.comm.size - return J_total[0] + return J def evaluate_misfit(model, guess, exact): diff --git a/temp_serialshot_fwi.py b/temp_serialshot_fwi.py new file mode 100644 index 00000000..277aebc7 --- /dev/null +++ b/temp_serialshot_fwi.py @@ -0,0 +1,138 @@ +import numpy as np +import firedrake as fire +import spyro +import pytest +import warnings + + +warnings.filterwarnings("ignore") + + +def is_rol_installed(): + try: + import ROL + return True + except ImportError: + return False + + +final_time = 0.9 + +dictionary = {} +dictionary["options"] = { + "cell_type": "T", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q) + "variant": "lumped", # lumped, equispaced or DG, default is lumped + "degree": 4, # p order + "dimension": 2, # dimension +} +dictionary["parallelism"] = { + "type": "spatial", # options: automatic (same number of cores for evey processor) or spatial +} +dictionary["mesh"] = { + "Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha? + "Lx": 2.0, # width in km - always positive + "Ly": 0.0, # thickness in km - always positive + "mesh_file": None, + "mesh_type": "firedrake_mesh", +} +dictionary["acquisition"] = { + "source_type": "ricker", + "source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 6), + "frequency": 5.0, + "delay": 0.2, + "delay_type": "time", + "receiver_locations": spyro.create_transect((-1.45, 0.7), (-1.45, 1.3), 200), +} +dictionary["time_axis"] = { + "initial_time": 0.0, # Initial time for event + "final_time": final_time, # Final time for event + "dt": 0.001, # timestep size + "amplitude": 1, # the Ricker has an amplitude of 1. + "output_frequency": 100, # how frequently to output solution to pvds - Perguntar Daiane ''post_processing_frequnecy' + "gradient_sampling_frequency": 1, # how frequently to save solution to RAM - Perguntar Daiane 'gradient_sampling_frequency' +} +dictionary["visualization"] = { + "forward_output": False, + "forward_output_filename": "results/forward_output.pvd", + "fwi_velocity_model_output": False, + "velocity_model_filename": None, + "gradient_output": False, + "gradient_filename": "results/Gradient.pvd", + "adjoint_output": False, + "adjoint_filename": None, + "debug_output": False, +} +dictionary["inversion"] = { + "perform_fwi": True, # switch to true to make a FWI + "initial_guess_model_file": None, + "shot_record_file": None, +} + + +def test_fwi(load_real_shot=False, use_rol=False): + """ + Run the Full Waveform Inversion (FWI) test. + + Parameters + ---------- + load_real_shot (bool, optional): Whether to load a real shot record or not. Defaults to False. + """ + + # Setting up to run synthetic real problem + FWI_obj = spyro.FullWaveformInversion(dictionary=dictionary) + + FWI_obj.set_real_mesh(mesh_parameters={"dx": 0.1}) + center_z = -1.0 + center_x = 1.0 + mesh_z = FWI_obj.mesh_z + mesh_x = FWI_obj.mesh_x + cond = fire.conditional((mesh_z-center_z)**2 + (mesh_x-center_x)**2 < .2**2, 3.0, 2.5) + + FWI_obj.set_real_velocity_model(conditional=cond, output=True, dg_velocity_model=False) + FWI_obj.generate_real_shot_record( + plot_model=True, + filename="True_experiment.png", + abc_points=[(-0.5, 0.5), (-1.5, 0.5), (-1.5, 1.5), (-0.5, 1.5)] + ) + np.save("real_shot_record", FWI_obj.real_shot_record) + + # Setting up initial guess problem + FWI_obj.set_guess_mesh(mesh_parameters={"dx": 0.1}) + FWI_obj.set_guess_velocity_model(constant=2.5) + mask_boundaries = { + "z_min": -1.3, + "z_max": -0.7, + "x_min": 0.7, + "x_max": 1.3, + } + FWI_obj.set_gradient_mask(boundaries=mask_boundaries) + if use_rol: + FWI_obj.run_fwi_rol(vmin=2.5, vmax=3.0, maxiter=2) + else: + FWI_obj.run_fwi(vmin=2.5, vmax=3.0, maxiter=5) + + # simple mask test + grad_test = FWI_obj.gradient + test0 = np.isclose(grad_test.at((-0.1, 0.1)), 0.0) + print(f"PML looks masked: {test0}", flush=True) + test1 = np.abs(grad_test.at((-1.0, 1.0))) > 1e-5 + print(f"Center looks unmasked: {test1}", flush=True) + + # quick look at functional and if it reduced + test2 = FWI_obj.functional < 1e-3 + print(f"Last functional small: {test2}", flush=True) + test3 = FWI_obj.functional_history[-1]/FWI_obj.functional_history[0] < 1e-2 + print(f"Considerable functional reduction during test: {test3}", flush=True) + + print("END", flush=True) + assert all([test0, test1, test2, test3]) + + +@pytest.mark.skipif(not is_rol_installed(), reason="ROL is not installed") +def test_fwi_with_rol(load_real_shot=False, use_rol=True): + test_fwi(load_real_shot=load_real_shot, use_rol=use_rol) + + +if __name__ == "__main__": + test_fwi(load_real_shot=False) + test_fwi_with_rol() diff --git a/test/test_serialshot_fwi.py b/test/test_serialshot_fwi.py new file mode 100644 index 00000000..963d248c --- /dev/null +++ b/test/test_serialshot_fwi.py @@ -0,0 +1,143 @@ +import numpy as np +import firedrake as fire +import spyro +import pytest +import warnings + + +warnings.filterwarnings("ignore") + + +def is_rol_installed(): + try: + import ROL + return True + except ImportError: + return False + + +final_time = 0.9 + +dictionary = {} +dictionary["options"] = { + "cell_type": "T", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q) + "variant": "lumped", # lumped, equispaced or DG, default is lumped + "degree": 4, # p order + "dimension": 2, # dimension +} +dictionary["parallelism"] = { + "type": "spatial", # options: automatic (same number of cores for evey processor) or spatial +} +dictionary["mesh"] = { + "Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha? + "Lx": 2.0, # width in km - always positive + "Ly": 0.0, # thickness in km - always positive + "mesh_file": None, + "mesh_type": "firedrake_mesh", +} +dictionary["acquisition"] = { + "source_type": "ricker", + "source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 6), + "frequency": 5.0, + "delay": 0.2, + "delay_type": "time", + "receiver_locations": spyro.create_transect((-1.45, 0.7), (-1.45, 1.3), 200), +} +dictionary["time_axis"] = { + "initial_time": 0.0, # Initial time for event + "final_time": final_time, # Final time for event + "dt": 0.001, # timestep size + "amplitude": 1, # the Ricker has an amplitude of 1. + "output_frequency": 100, # how frequently to output solution to pvds - Perguntar Daiane ''post_processing_frequnecy' + "gradient_sampling_frequency": 1, # how frequently to save solution to RAM - Perguntar Daiane 'gradient_sampling_frequency' +} +dictionary["visualization"] = { + "forward_output": False, + "forward_output_filename": "results/forward_output.pvd", + "fwi_velocity_model_output": False, + "velocity_model_filename": None, + "gradient_output": False, + "gradient_filename": "results/Gradient.pvd", + "adjoint_output": False, + "adjoint_filename": None, + "debug_output": False, +} +dictionary["inversion"] = { + "perform_fwi": True, # switch to true to make a FWI + "initial_guess_model_file": None, + "shot_record_file": None, +} + + +def test_fwi(load_real_shot=False, use_rol=False): + """ + Run the Full Waveform Inversion (FWI) test. + + Parameters + ---------- + load_real_shot (bool, optional): Whether to load a real shot record or not. Defaults to False. + """ + + # Setting up to run synthetic real problem + if load_real_shot is False: + FWI_obj = spyro.FullWaveformInversion(dictionary=dictionary) + + FWI_obj.set_real_mesh(mesh_parameters={"dx": 0.1}) + center_z = -1.0 + center_x = 1.0 + mesh_z = FWI_obj.mesh_z + mesh_x = FWI_obj.mesh_x + cond = fire.conditional((mesh_z-center_z)**2 + (mesh_x-center_x)**2 < .2**2, 3.0, 2.5) + + FWI_obj.set_real_velocity_model(conditional=cond, output=True, dg_velocity_model=False) + FWI_obj.generate_real_shot_record( + plot_model=True, + filename="True_experiment.png", + abc_points=[(-0.5, 0.5), (-1.5, 0.5), (-1.5, 1.5), (-0.5, 1.5)] + ) + np.save("real_shot_record", FWI_obj.real_shot_record) + + else: + dictionary["inversion"]["shot_record_file"] = "real_shot_record.npy" + FWI_obj = spyro.FullWaveformInversion(dictionary=dictionary) + + # Setting up initial guess problem + FWI_obj.set_guess_mesh(mesh_parameters={"dx": 0.1}) + FWI_obj.set_guess_velocity_model(constant=2.5) + mask_boundaries = { + "z_min": -1.3, + "z_max": -0.7, + "x_min": 0.7, + "x_max": 1.3, + } + FWI_obj.set_gradient_mask(boundaries=mask_boundaries) + if use_rol: + FWI_obj.run_fwi_rol(vmin=2.5, vmax=3.0, maxiter=2) + else: + FWI_obj.run_fwi(vmin=2.5, vmax=3.0, maxiter=5) + + # simple mask test + grad_test = FWI_obj.gradient + test0 = np.isclose(grad_test.at((-0.1, 0.1)), 0.0) + print(f"PML looks masked: {test0}", flush=True) + test1 = np.abs(grad_test.at((-1.0, 1.0))) > 1e-5 + print(f"Center looks unmasked: {test1}", flush=True) + + # quick look at functional and if it reduced + test2 = FWI_obj.functional < 1e-3 + print(f"Last functional small: {test2}", flush=True) + test3 = FWI_obj.functional_history[-1]/FWI_obj.functional_history[0] < 1e-2 + print(f"Considerable functional reduction during test: {test3}", flush=True) + + print("END", flush=True) + assert all([test0, test1, test2, test3]) + + +@pytest.mark.skipif(not is_rol_installed(), reason="ROL is not installed") +def test_fwi_with_rol(load_real_shot=False, use_rol=True): + test_fwi(load_real_shot=load_real_shot, use_rol=use_rol) + + +if __name__ == "__main__": + test_fwi(load_real_shot=False) + test_fwi_with_rol()