diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index f7978bab..13045e30 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -25,6 +25,11 @@ jobs: mpiexec -n 6 pytest test_3d/test_hexahedral_convergence.py mpiexec -n 6 pytest test_parallel/test_forward.py mpiexec -n 6 pytest test_parallel/test_fwi.py + mpiexec -n 6 pytest test_parallel/test_forward_supershot.py + mpiexec -n 2 pytest test_parallel/test_parallel_io.py + mpiexec -n 3 pytest test_parallel/test_supershot_grad.py + mpiexec -n 2 pytest test_parallel/test_forward_multiple_serial_shots.py + mpiexec -n 2 pytest test_parallel/test_gradient_serialshots.py - name: Covering parallel 3D forward test continue-on-error: true run: | @@ -40,6 +45,31 @@ jobs: run: | source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate mpiexec -n 6 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_fwi.py + - name: Covering parallel supershot test + continue-on-error: true + run: | + source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate + mpiexec -n 6 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_forward_supershot.py + - name: Covering parallel io test + continue-on-error: true + run: | + source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate + mpiexec -n 2 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_parallel_io.py + - name: Covering parallel supershot grad test + continue-on-error: true + run: | + source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate + mpiexec -n 3 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_supershot_grad.py + - name: Covering spatially parallelized shots in serial + continue-on-error: true + run: | + source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate + mpiexec -n 2 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_forward_multiple_serial_shots.py + - name: Covering spatially parallelized gradient test in serial + continue-on-error: true + run: | + source /home/olender/firedrakes/2024_07_19/firedrake/bin/activate + mpiexec -n 2 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_gradient_serialshots.py - name: Uploading coverage to Codecov run: export CODECOV_TOKEN="057ec853-d7ea-4277-819b-0c5ea2f9ff57" && bash <(curl -s https://codecov.io/bash) diff --git a/.vscode/launch.json b/.vscode/launch.json index 99d93959..376f4a95 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,20 +4,20 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ - // { - // "name": "Python Attach 0", - // "type": "python", - // "request": "attach", - // "port": 3000, - // "host": "localhost", - // }, - // { - // "name": "Python Attach 1", - // "type": "python", - // "request": "attach", - // "port": 3001, - // "host": "localhost" - // }, + { + "name": "Python Attach 0", + "type": "python", + "request": "attach", + "port": 3000, + "host": "localhost", + }, + { + "name": "Python Attach 1", + "type": "python", + "request": "attach", + "port": 3001, + "host": "localhost" + }, { "name": "Python Debugger: Current File", "type": "debugpy", diff --git a/cleanup.sh b/cleanup.sh index 88b83f1b..0ef3f65d 100755 --- a/cleanup.sh +++ b/cleanup.sh @@ -5,9 +5,13 @@ rm *.png rm *.vtu rm *.pvtu rm *.pvd +rm *.npy +rm *.pdf +rm *.dat rm results/*.vtu rm results/*.pvd rm results/*.pvtu +rm shots/*.dat rm -rf results/shot* rm -rf results/gradient rm -rf 
results/adjoint_shot diff --git a/spyro/io/__init__.py b/spyro/io/__init__.py index 58116832..37d58b45 100644 --- a/spyro/io/__init__.py +++ b/spyro/io/__init__.py @@ -6,14 +6,17 @@ load_shots, read_mesh, interpolate, - # ensemble_forward, # ensemble_forward_ad, # ensemble_forward_elastic_waves, ensemble_gradient, # ensemble_gradient_elastic_waves, - ensemble_plot, parallel_print, saving_source_and_receiver_location_in_csv, + switch_serial_shot, + ensemble_save_or_load, + delete_tmp_files, + ensemble_shot_record, + ensemble_functional ) from .model_parameters import Model_parameters from .backwards_compatibility_io import Dictionary_conversion @@ -28,7 +31,6 @@ "load_shots", "read_mesh", "interpolate", - # "ensemble_forward", # "ensemble_forward_ad", # "ensemble_forward_elastic_waves", "ensemble_gradient", @@ -41,4 +43,9 @@ "dictionaryio", "boundary_layer_io", "saving_source_and_receiver_location_in_csv", + "switch_serial_shot", + "ensemble_save_or_load", + "delete_tmp_files", + "ensemble_shot_record", + "ensemble_functional", ] diff --git a/spyro/io/basicio.py b/spyro/io/basicio.py index 04df9c31..6fec78d2 100644 --- a/spyro/io/basicio.py +++ b/spyro/io/basicio.py @@ -1,164 +1,199 @@ from __future__ import with_statement import pickle - +from mpi4py import MPI import firedrake as fire import h5py import numpy as np from scipy.interpolate import RegularGridInterpolator from scipy.interpolate import griddata import segyio +import glob +import os -def ensemble_save_or_load(func): +def delete_tmp_files(wave): + str_id = "*" + wave.random_id_string + ".npy" + temp_files = glob.glob(str_id) + for file in temp_files: + os.remove(file) + + +def ensemble_shot_record(func): """Decorator for read and write shots for ensemble parallelism""" def wrapper(*args, **kwargs): - num = args[0].number_of_sources - comm = args[0].comm - custom_file_name = kwargs.get("file_name") - for snum in range(num): - if is_owner(comm, snum) and comm.comm.rank == 0: - if custom_file_name is None: - func( - *args, - **dict( - kwargs, - source_id=snum, - file_name="shots/shot_record_" - + str(snum + 1) - + ".dat", - ) - ) - else: - func( - *args, - **dict( - kwargs, - source_id=snum, - file_name="shots/" - + custom_file_name - + str(snum + 1) - + ".dat", - ) - ) + if args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: + output_list = [] + for snum in range(args[0].number_of_sources): + switch_serial_shot(args[0], snum) + output_list.append(func(*args, **kwargs)) + + return output_list return wrapper -def ensemble_plot(func): - """Decorator for `plot_shots` to distribute shots for - ensemble parallelism""" +def ensemble_save_or_load(func): + """Decorator for read and write shots for ensemble parallelism""" def wrapper(*args, **kwargs): - num = args[0].number_of_sources _comm = args[0].comm - for snum in range(num): - if is_owner(_comm, snum) and _comm.comm.rank == 0: - func(*args, **dict(kwargs, file_name=str(snum + 1))) - - return wrapper - - -# def ensemble_forward(func): -# """Decorator for forward to distribute shots for ensemble parallelism""" -# def wrapper(*args, **kwargs): -# acq = args[0].get("acquisition") -# num = len(acq["source_pos"]) -# _comm = args[2] -# for snum in range(num): -# if is_owner(_comm, snum): -# u, u_r = func(*args, **dict(kwargs, source_num=snum)) -# return u, u_r + if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1: + shot_ids_per_propagation_list = args[0].shot_ids_per_propagation + for propagation_id, shot_ids_in_propagation in 
enumerate(shot_ids_per_propagation_list): + if is_owner(_comm, propagation_id) and _comm.comm.rank == 0: + func(*args, **dict(kwargs, shot_ids=shot_ids_in_propagation)) + elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: + for snum in range(args[0].number_of_sources): + switch_serial_shot(args[0], snum) + if _comm.comm.rank == 0: + func(*args, **dict(kwargs, shot_ids=[snum])) -# return wrapper + return wrapper def ensemble_propagator(func): """Decorator for forward to distribute shots for ensemble parallelism""" def wrapper(*args, **kwargs): - num = args[0].number_of_sources - _comm = args[0].comm - for snum in range(num): - if is_owner(_comm, snum): - u, u_r = func(*args, **dict(kwargs, source_num=snum)) - return u, u_r + if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1: + shot_ids_per_propagation_list = args[0].shot_ids_per_propagation + _comm = args[0].comm + for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list): + if is_owner(_comm, propagation_id): + u, u_r = func(*args, **dict(kwargs, source_nums=shot_ids_in_propagation)) + return u, u_r + elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: + num = args[0].number_of_sources + starting_time = args[0].current_time + for snum in range(num): + args[0].reset_pressure() + args[0].current_time = starting_time + u, u_r = func(*args, **dict(kwargs, source_nums=[snum])) + save_serial_data(args[0], snum) + + return u, u_r return wrapper -# def ensemble_forward_ad(func): -# """Decorator for forward to distribute shots for ensemble parallelism""" - -# def wrapper(*args, **kwargs): -# acq = args[0].get("acquisition") -# num = len(acq["source_pos"]) -# fwi = kwargs.get("fwi") -# _comm = args[2] -# for snum in range(num): -# if is_owner(_comm, snum): -# if fwi: -# u_r, J = func(*args, **dict(kwargs, source_num=snum)) -# return u_r, J -# else: -# u_r = func(*args, **dict(kwargs, source_num=snum)) - -# return wrapper +def save_serial_data(wave, propagation_id): + """ + Save serial data to numpy files. + Args: + wave (Wave): The wave object containing the forward solution. + propagation_id (int): The propagation ID. -# def ensemble_forward_elastic_waves(func): -# """Decorator for forward elastic waves to distribute shots for -# ensemble parallelism""" + Returns: + None + """ + arrays_list = [obj.dat.data[:] for obj in wave.forward_solution] + stacked_arrays = np.stack(arrays_list, axis=0) + spatialcomm = wave.comm.comm.rank + id_str = wave.random_id_string + np.save(f'tmp_shot{propagation_id}_comm{spatialcomm}'+id_str+'.npy', stacked_arrays) + np.save(f"tmp_rec{propagation_id}_comm{spatialcomm}"+id_str+".npy", wave.forward_solution_receivers) -# def wrapper(*args, **kwargs): -# acq = args[0].get("acquisition") -# num = len(acq["source_pos"]) -# _comm = args[2] -# for snum in range(num): -# if is_owner(_comm, snum): -# u, uz_r, ux_r, uy_r = func( -# *args, **dict(kwargs, source_num=snum) -# ) -# return u, uz_r, ux_r, uy_r -# return wrapper +def switch_serial_shot(wave, propagation_id): + """ + Switches the current serial shot for a given wave to shot identified with propagation ID. + Args: + wave (Wave): The wave object. + propagation_id (int): The propagation ID. 
-def ensemble_gradient(func): + Returns: + None + """ + spatialcomm = wave.comm.comm.rank + id_str = wave.random_id_string + stacked_shot_arrays = np.load(f'tmp_shot{propagation_id}_comm{spatialcomm}'+id_str+'.npy') + if len(wave.forward_solution) == 0: + n_dts, n_dofs = np.shape(stacked_shot_arrays) + rebuild_empty_forward_solution(wave, n_dts) + for array_i, array in enumerate(stacked_shot_arrays): + wave.forward_solution[array_i].dat.data[:] = array + wave.forward_solution_receivers = np.load(f"tmp_rec{propagation_id}_comm{spatialcomm}"+id_str+".npy") + wave.receivers_output = wave.forward_solution_receivers + + +def ensemble_functional(func): """Decorator for gradient to distribute shots for ensemble parallelism""" def wrapper(*args, **kwargs): - num = args[0].number_of_sources - _comm = args[0].comm - for snum in range(num): - if is_owner(_comm, snum): - grad = func(*args, **kwargs) - return grad + comm = args[0].comm + if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1: + J = func(*args, **kwargs) + J_total = np.zeros((1)) + J_total[0] += J + J_total = fire.COMM_WORLD.allreduce(J_total, op=MPI.SUM) + J_total[0] /= comm.comm.size + + elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: + num = args[0].number_of_sources + residual_list = args[1] + J_total = np.zeros((1)) + + for snum in range(args[0].number_of_sources): + switch_serial_shot(args[0], snum) + current_residual = residual_list[snum] + J = func(args[0], current_residual) + J_total += J + J_total[0] /= comm.comm.size + + comm.comm.barrier() + + return J_total[0] return wrapper -# def ensemble_gradient_elastic_waves(func): -# """Decorator for gradient (elastic waves) to distribute shots -# for ensemble parallelism""" +def ensemble_gradient(func): + """Decorator for gradient to distribute shots for ensemble parallelism""" -# def wrapper(*args, **kwargs): -# acq = args[0].get("acquisition") -# save_adjoint = kwargs.get("save_adjoint") -# num = len(acq["source_pos"]) -# _comm = args[2] -# for snum in range(num): -# if is_owner(_comm, snum): -# if save_adjoint: -# grad_lambda, grad_mu, u_adj = func(*args, **kwargs) -# return grad_lambda, grad_mu, u_adj -# else: -# grad_lambda, grad_mu = func(*args, **kwargs) -# return grad_lambda, grad_mu + def wrapper(*args, **kwargs): + comm = args[0].comm + if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1: + shot_ids_per_propagation_list = args[0].shot_ids_per_propagation + for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list): + if is_owner(comm, propagation_id): + grad = func(*args, **kwargs) + grad_total = fire.Function(args[0].function_space) + + comm.comm.barrier() + grad_total = comm.allreduce(grad, grad_total) + grad_total /= comm.ensemble_comm.size + + return grad_total + elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: + num = args[0].number_of_sources + starting_time = args[0].current_time + grad_total = fire.Function(args[0].function_space) + misfit_list = kwargs.get("misfit") + + for snum in range(num): + switch_serial_shot(args[0], snum) + current_misfit = misfit_list[snum] + args[0].reset_pressure() + args[0].current_time = starting_time + grad = func(*args, + **dict( + kwargs, + misfit=current_misfit, + ) + ) + grad_total += grad -# return wrapper + grad_total /= num + comm.comm.barrier() + + return grad_total + + return wrapper def write_function_to_grid(function, V, grid_spacing): @@ -241,7 +276,7 @@ def create_segy(function, V, 
grid_spacing, filename): @ensemble_save_or_load -def save_shots(Wave_obj, source_id=0, file_name=None): +def save_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0): """Save a the shot record from last forward solve to a `pickle`. Parameters @@ -258,13 +293,21 @@ def save_shots(Wave_obj, source_id=0, file_name=None): None """ + file_name = file_name + str(shot_ids) + ".dat" with open(file_name, "wb") as f: - pickle.dump(Wave_obj.forward_solution_receivers[:, source_id], f) + pickle.dump(Wave_obj.forward_solution_receivers, f) return None +def rebuild_empty_forward_solution(wave, time_steps): + wave.forward_solution = [] + for i in range(time_steps): + wave.forward_solution.append(fire.Function(wave.function_space)) + + + @ensemble_save_or_load -def load_shots(Wave_obj, source_id=0, file_name="shots/shot_record_"): +def load_shots(Wave_obj, file_name=None, shot_ids=0): """Load a `pickle` to a `numpy.ndarray`. Parameters @@ -283,6 +326,7 @@ def load_shots(Wave_obj, source_id=0, file_name="shots/shot_record_"): """ array = np.zeros(()) + file_name = file_name + str(shot_ids) + ".dat" with open(file_name, "rb") as f: array = np.asarray(pickle.load(f), dtype=float) @@ -306,7 +350,8 @@ def is_owner(ens_comm, rank): `True` if `rank` owns this shot """ - return ens_comm.ensemble_comm.rank == (rank % ens_comm.ensemble_comm.size) + owner = ens_comm.ensemble_comm.rank == (rank % ens_comm.ensemble_comm.size) + return owner def _check_units(c): diff --git a/spyro/io/model_parameters.py b/spyro/io/model_parameters.py index 1904dc55..e2bd5708 100644 --- a/spyro/io/model_parameters.py +++ b/spyro/io/model_parameters.py @@ -1,4 +1,7 @@ import numpy as np +import uuid +from mpi4py import MPI +from firedrake import COMM_WORLD import warnings from .. import io from .. import utils @@ -326,6 +329,7 @@ def __init__(self, dictionary=None, comm=None): # Sanitize output files self._sanitize_output() + self.random_id_string = str(uuid.uuid4())[:10] # default_dictionary["absorving_boundary_conditions"] = { # "status": False, # True or false @@ -543,8 +547,13 @@ def _sanitize_comm(self, comm): warnings.warn("No paralellism type listed. Assuming automatic") self.parallelism_type = "automatic" - if self.source_type == "MMS": - self.parallelism_type = "spatial" + if self.parallelism_type == "custom": + self.shot_ids_per_propagation = dictionary["parallelism"]["shot_ids_per_propagation"] + elif self.parallelism_type == "automatic": + available_cores = COMM_WORLD.size + self.shot_ids_per_propagation = [[i] for i in range(0, available_cores)] + elif self.parallelism_type == "spatial": + self.shot_ids_per_propagation = [[i] for i in range(0, self.number_of_sources)] if comm is None: self.comm = utils.mpi_init(self) diff --git a/spyro/plots/plots.py b/spyro/plots/plots.py index 68742439..b0d6c3a6 100644 --- a/spyro/plots/plots.py +++ b/spyro/plots/plots.py @@ -6,16 +6,17 @@ import numpy as np import firedrake import copy -from ..io import ensemble_plot +from ..io import ensemble_save_or_load __all__ = ["plot_shots"] -@ensemble_plot +@ensemble_save_or_load def plot_shots( Wave_object, show=False, - file_name="1", + file_name="plot_of_shot", + shot_ids=[0], vmin=-1e-5, vmax=1e-5, contour_lines=700, @@ -52,7 +53,7 @@ def plot_shots( ------- None """ - + file_name = file_name + str(shot_ids) + "." 
+ file_format num_recvs = Wave_object.number_of_receivers dt = Wave_object.dt @@ -79,7 +80,7 @@ def plot_shots( plt.xlim(start_index, end_index) plt.ylim(tf, 0) plt.subplots_adjust(left=0.18, right=0.95, bottom=0.14, top=0.95) - plt.savefig(file_name + "." + file_format, format=file_format) + plt.savefig(file_name, format=file_format) # plt.axis("image") if show: plt.show() diff --git a/spyro/solvers/acoustic_wave.py b/spyro/solvers/acoustic_wave.py index 3ec51d44..e5b6d6bd 100644 --- a/spyro/solvers/acoustic_wave.py +++ b/spyro/solvers/acoustic_wave.py @@ -44,6 +44,7 @@ def forward_solve(self): self.c = self.initial_velocity_model self.matrix_building() self.wave_propagator() + self.comm.comm.barrier() def force_rebuild_function_space(self): if self.mesh is None: @@ -84,7 +85,7 @@ def matrix_building(self): construct_solver_or_matrix_with_pml(self) @ensemble_propagator - def wave_propagator(self, dt=None, final_time=None, source_num=0): + def wave_propagator(self, dt=None, final_time=None, source_nums=[0]): """Propagates the wave forward in time. Currently uses central differences. @@ -109,8 +110,8 @@ def wave_propagator(self, dt=None, final_time=None, source_num=0): if dt is not None: self.dt = dt - self.current_source = source_num - usol, usol_recv = time_integrator(self, source_id=source_num) + self.current_sources = source_nums + usol, usol_recv = time_integrator(self, source_ids=source_nums) return usol, usol_recv @@ -150,4 +151,3 @@ def reset_pressure(self): self.X_nm1.assign(0.0) except: warnings.warn("No mixed space pressure to reset") - diff --git a/spyro/solvers/inversion.py b/spyro/solvers/inversion.py index cd5589e5..d505ef0d 100644 --- a/spyro/solvers/inversion.py +++ b/spyro/solvers/inversion.py @@ -8,6 +8,9 @@ from ..utils import compute_functional from ..utils import Gradient_mask_for_pml, Mask from ..plots import plot_model as spyro_plot_model +from ..io.basicio import ensemble_shot_record +from ..io.basicio import switch_serial_shot + try: from ROL.firedrake_vector import FiredrakeVector as FireVector @@ -181,10 +184,19 @@ def calculate_misfit(self, c=None): self.forward_solve() output = fire.File("control_" + str(self.current_iteration)+".pvd") output.write(self.c) - self.guess_shot_record = self.forward_solution_receivers - self.guess_forward_solution = self.forward_solution - - self.misfit = self.real_shot_record - self.guess_shot_record + if self.parallelism_type == "spatial" and self.number_of_sources > 1: + misfit_list = [] + guess_shot_record_list = [] + for snum in range (self.number_of_sources): + switch_serial_shot(self, snum) + guess_shot_record_list.append(self.forward_solution_receivers) + misfit_list.append(self.real_shot_record[snum] - self.forward_solution_receivers) + self.guess_shot_record = guess_shot_record_list + self.misfit = misfit_list + else: + self.guess_shot_record = self.forward_solution_receivers + self.guess_forward_solution = self.forward_solution + self.misfit = self.real_shot_record - self.guess_shot_record return self.misfit def generate_real_shot_record(self, plot_model=False, filename="model.png", abc_points=None): @@ -382,20 +394,13 @@ def get_gradient(self, c=None, save=True, calculate_functional=True): if calculate_functional: self.get_functional(c=c) comm.comm.barrier() - dJ = self.gradient_solve(misfit=self.misfit, forward_solution=self.guess_forward_solution) - dJ_total = fire.Function(self.function_space) - comm.comm.barrier() - dJ_total = comm.allreduce(dJ, dJ_total) - dJ_total /= comm.ensemble_comm.size - if comm.comm.size > 
1: - dJ_total /= comm.comm.size - self.gradient = dJ_total + self.gradient = self.gradient_solve(misfit=self.misfit, forward_solution=self.guess_forward_solution) self._apply_gradient_mask() if save and comm.comm.rank == 0: # self.gradient_out.write(dJ_total) output = fire.File("gradient_" + str(self.current_iteration)+".pvd") - output.write(dJ_total) - print("DEBUG") + output.write(self.gradient) + self.current_iteration += 1 comm.comm.barrier() @@ -579,4 +584,11 @@ def __init__(self, dictionary=None, comm=None): def forward_solve(self): super().forward_solve() - self.real_shot_record = self.receivers_output + if self.parallelism_type == "spatial" and self.number_of_sources > 1: + real_shot_record_list = [] + for snum in range (self.number_of_sources): + switch_serial_shot(self, snum) + real_shot_record_list.append(self.receivers_output) + self.real_shot_record = real_shot_record_list + else: + self.real_shot_record = self.receivers_output diff --git a/spyro/solvers/time_integration.py b/spyro/solvers/time_integration.py index 889fd834..ab94b93d 100644 --- a/spyro/solvers/time_integration.py +++ b/spyro/solvers/time_integration.py @@ -3,24 +3,24 @@ from .time_integration_central_difference import central_difference_MMS -def time_integrator(Wave_object, source_id=0): +def time_integrator(Wave_object, source_ids=[0]): if Wave_object.source_type == "ricker": - return time_integrator_ricker(Wave_object, source_id=source_id) + return time_integrator_ricker(Wave_object, source_ids=source_ids) elif Wave_object.source_type == "MMS": - return time_integrator_mms(Wave_object, source_id=source_id) + return time_integrator_mms(Wave_object, source_ids=source_ids) -def time_integrator_ricker(Wave_object, source_id=0): +def time_integrator_ricker(Wave_object, source_ids=[0]): if Wave_object.time_integrator == "central_difference": - return central_difference(Wave_object, source_id=source_id) + return central_difference(Wave_object, source_ids=source_ids) elif Wave_object.time_integrator == "mixed_space_central_difference": - return mixed_space_central_difference(Wave_object, source_id=source_id) + return mixed_space_central_difference(Wave_object, source_ids=source_ids) else: raise ValueError("The time integrator specified is not implemented yet") -def time_integrator_mms(Wave_object, source_id=0): +def time_integrator_mms(Wave_object, source_ids=[0]): if Wave_object.time_integrator == "central_difference": - return central_difference_MMS(Wave_object, source_id=source_id) + return central_difference_MMS(Wave_object, source_ids=source_ids) else: raise ValueError("The time integrator specified is not implemented yet") diff --git a/spyro/solvers/time_integration_central_difference.py b/spyro/solvers/time_integration_central_difference.py index 2c6bf696..aab006d9 100644 --- a/spyro/solvers/time_integration_central_difference.py +++ b/spyro/solvers/time_integration_central_difference.py @@ -7,7 +7,7 @@ from .. import utils -def central_difference(Wave_object, source_id=0): +def central_difference(Wave_object, source_ids=[0]): """ Perform central difference time integration for wave propagation. @@ -15,8 +15,8 @@ def central_difference(Wave_object, source_id=0): ----------- Wave_object: Spyro object The Wave object containing the necessary data and parameters. - source_id: int (optional) - The ID of the source being propagated. Defaults to 0. + source_ids: list of ints (optional) + The ID of the sources being propagated. Defaults to [0]. 
Returns: -------- @@ -24,13 +24,13 @@ def central_difference(Wave_object, source_id=0): A tuple containing the forward solution and the receiver output. """ excitations = Wave_object.sources - excitations.current_source = source_id + excitations.current_sources = source_ids receivers = Wave_object.receivers comm = Wave_object.comm temp_filename = Wave_object.forward_output_file filename, file_extension = temp_filename.split(".") - output_filename = filename + "sn" + str(source_id) + "." + file_extension + output_filename = filename + "sn" + str(source_ids) + "." + file_extension if Wave_object.forward_output: parallel_print(f"Saving output in: {output_filename}", Wave_object.comm) @@ -107,7 +107,7 @@ def central_difference(Wave_object, source_id=0): return usol, usol_recv -def mixed_space_central_difference(Wave_object, source_id=0): +def mixed_space_central_difference(Wave_object, source_ids=[0]): """ Performs central difference time integration for wave propagation. Solves for a mixed space formulation, for function X. For correctly @@ -118,8 +118,8 @@ def mixed_space_central_difference(Wave_object, source_id=0): ----------- Wave_object: Spyro object The Wave object containing the necessary data and parameters. - source_id: int (optional) - The ID of the source being propagated. Defaults to 0. + source_ids: list of int (optional) + The ID of the source being propagated. Defaults to [0]. Returns: -------- @@ -127,12 +127,12 @@ def mixed_space_central_difference(Wave_object, source_id=0): A tuple containing the forward solution and the receiver output. """ excitations = Wave_object.sources - excitations.current_source = source_id + excitations.current_sources = source_ids receivers = Wave_object.receivers comm = Wave_object.comm temp_filename = Wave_object.forward_output_file filename, file_extension = temp_filename.split(".") - output_filename = filename + "sn" + str(source_id) + "." + file_extension + output_filename = filename + "sn" + str(source_ids) + "." + file_extension if Wave_object.forward_output: parallel_print(f"Saving output in: {output_filename}", Wave_object.comm) @@ -209,7 +209,7 @@ def mixed_space_central_difference(Wave_object, source_id=0): return usol, usol_recv -def central_difference_MMS(Wave_object, source_id=0): +def central_difference_MMS(Wave_object, source_ids=[0]): """Propagates the wave forward in time. Currently uses central differences. 
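Note on the hunks above: they switch the time integrators from a single source_id to a list of source_ids, and the next hunk (Sources.py) makes the excitation accept any id contained in current_sources. A minimal sketch of that dispatch, outside spyro — names such as apply_sources_sketch and point_forcings are illustrative, not part of the patch:

def apply_sources_sketch(rhs, point_forcings, is_local, current_sources, value):
    # point_forcings[source_id] -> list of (dof_index, weight) pairs (assumed layout).
    # Every locally owned source whose id is in current_sources adds its forcing,
    # so current_sources=[0, 1] drives both shots of a supershot in one propagation,
    # while current_sources=[snum] reproduces the old single-shot behaviour.
    for source_id, forcing in enumerate(point_forcings):
        if is_local[source_id] and source_id in current_sources:
            for dof, weight in forcing:
                rhs[dof] += value * weight
    return rhs
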
diff --git a/spyro/sources/Sources.py b/spyro/sources/Sources.py index 5955e376..bbbf8309 100644 --- a/spyro/sources/Sources.py +++ b/spyro/sources/Sources.py @@ -66,7 +66,7 @@ def __init__(self, wave_object): self.point_locations = wave_object.source_locations self.number_of_points = wave_object.number_of_sources self.is_local = [0] * self.number_of_points - self.current_source = None + self.current_sources = None self.build_maps() @@ -86,7 +86,7 @@ def apply_source(self, rhs_forcing, value): The right hand side of the wave equation with the source applied """ for source_id in range(self.number_of_points): - if self.is_local[source_id] and source_id == self.current_source: + if self.is_local[source_id] and source_id in self.current_sources: for i in range(len(self.cellNodeMaps[source_id])): rhs_forcing.dat.data_with_halos[ int(self.cellNodeMaps[source_id][i]) diff --git a/spyro/utils/utils.py b/spyro/utils/utils.py index 1e6a8a2f..6a572350 100644 --- a/spyro/utils/utils.py +++ b/spyro/utils/utils.py @@ -4,6 +4,7 @@ from mpi4py import MPI from scipy.signal import butter, filtfilt import warnings +from ..io import ensemble_functional def butter_lowpass_filter(shot, cutoff, fs, order=2): @@ -37,6 +38,7 @@ def butter_lowpass_filter(shot, cutoff, fs, order=2): return filtered_shot +@ensemble_functional def compute_functional(Wave_object, residual): """Compute the functional to be optimized. Accepts the velocity optionally and uses @@ -52,11 +54,7 @@ def compute_functional(Wave_object, residual): J *= 0.5 - J_total = np.zeros((1)) - J_total[0] += J - J_total = COMM_WORLD.allreduce(J_total, op=MPI.SUM) - J_total[0] /= comm.comm.size - return J_total[0] + return J def evaluate_misfit(model, guess, exact): @@ -88,17 +86,20 @@ def mpi_init(model): available_cores = COMM_WORLD.size # noqa: F405 print(f"Parallelism type: {model.parallelism_type}", flush=True) if model.parallelism_type == "automatic": - num_cores_per_shot = available_cores / model.number_of_sources + num_cores_per_propagation = available_cores / model.number_of_sources if available_cores % model.number_of_sources != 0: raise ValueError( "Available cores cannot be divided between sources equally." 
) elif model.parallelism_type == "spatial": - num_cores_per_shot = available_cores + num_cores_per_propagation = available_cores elif model.parallelism_type == "custom": - raise ValueError("Custom parallelism not yet implemented") + shot_ids_per_propagation = model.shot_ids_per_propagation + num_max_shots_per_core = max(len(sublist) for sublist in shot_ids_per_propagation) + num_propagations = len(shot_ids_per_propagation) + num_cores_per_propagation = available_cores / num_propagations - comm_ens = Ensemble(COMM_WORLD, num_cores_per_shot) # noqa: F405 + comm_ens = Ensemble(COMM_WORLD, num_cores_per_propagation) # noqa: F405 return comm_ens diff --git a/test/test_serialshot_fwi.py b/test/test_serialshot_fwi.py new file mode 100644 index 00000000..963d248c --- /dev/null +++ b/test/test_serialshot_fwi.py @@ -0,0 +1,143 @@ +import numpy as np +import firedrake as fire +import spyro +import pytest +import warnings + + +warnings.filterwarnings("ignore") + + +def is_rol_installed(): + try: + import ROL + return True + except ImportError: + return False + + +final_time = 0.9 + +dictionary = {} +dictionary["options"] = { + "cell_type": "T", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q) + "variant": "lumped", # lumped, equispaced or DG, default is lumped + "degree": 4, # p order + "dimension": 2, # dimension +} +dictionary["parallelism"] = { + "type": "spatial", # options: automatic (same number of cores for evey processor) or spatial +} +dictionary["mesh"] = { + "Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha? + "Lx": 2.0, # width in km - always positive + "Ly": 0.0, # thickness in km - always positive + "mesh_file": None, + "mesh_type": "firedrake_mesh", +} +dictionary["acquisition"] = { + "source_type": "ricker", + "source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 6), + "frequency": 5.0, + "delay": 0.2, + "delay_type": "time", + "receiver_locations": spyro.create_transect((-1.45, 0.7), (-1.45, 1.3), 200), +} +dictionary["time_axis"] = { + "initial_time": 0.0, # Initial time for event + "final_time": final_time, # Final time for event + "dt": 0.001, # timestep size + "amplitude": 1, # the Ricker has an amplitude of 1. + "output_frequency": 100, # how frequently to output solution to pvds - Perguntar Daiane ''post_processing_frequnecy' + "gradient_sampling_frequency": 1, # how frequently to save solution to RAM - Perguntar Daiane 'gradient_sampling_frequency' +} +dictionary["visualization"] = { + "forward_output": False, + "forward_output_filename": "results/forward_output.pvd", + "fwi_velocity_model_output": False, + "velocity_model_filename": None, + "gradient_output": False, + "gradient_filename": "results/Gradient.pvd", + "adjoint_output": False, + "adjoint_filename": None, + "debug_output": False, +} +dictionary["inversion"] = { + "perform_fwi": True, # switch to true to make a FWI + "initial_guess_model_file": None, + "shot_record_file": None, +} + + +def test_fwi(load_real_shot=False, use_rol=False): + """ + Run the Full Waveform Inversion (FWI) test. + + Parameters + ---------- + load_real_shot (bool, optional): Whether to load a real shot record or not. Defaults to False. 
+ """ + + # Setting up to run synthetic real problem + if load_real_shot is False: + FWI_obj = spyro.FullWaveformInversion(dictionary=dictionary) + + FWI_obj.set_real_mesh(mesh_parameters={"dx": 0.1}) + center_z = -1.0 + center_x = 1.0 + mesh_z = FWI_obj.mesh_z + mesh_x = FWI_obj.mesh_x + cond = fire.conditional((mesh_z-center_z)**2 + (mesh_x-center_x)**2 < .2**2, 3.0, 2.5) + + FWI_obj.set_real_velocity_model(conditional=cond, output=True, dg_velocity_model=False) + FWI_obj.generate_real_shot_record( + plot_model=True, + filename="True_experiment.png", + abc_points=[(-0.5, 0.5), (-1.5, 0.5), (-1.5, 1.5), (-0.5, 1.5)] + ) + np.save("real_shot_record", FWI_obj.real_shot_record) + + else: + dictionary["inversion"]["shot_record_file"] = "real_shot_record.npy" + FWI_obj = spyro.FullWaveformInversion(dictionary=dictionary) + + # Setting up initial guess problem + FWI_obj.set_guess_mesh(mesh_parameters={"dx": 0.1}) + FWI_obj.set_guess_velocity_model(constant=2.5) + mask_boundaries = { + "z_min": -1.3, + "z_max": -0.7, + "x_min": 0.7, + "x_max": 1.3, + } + FWI_obj.set_gradient_mask(boundaries=mask_boundaries) + if use_rol: + FWI_obj.run_fwi_rol(vmin=2.5, vmax=3.0, maxiter=2) + else: + FWI_obj.run_fwi(vmin=2.5, vmax=3.0, maxiter=5) + + # simple mask test + grad_test = FWI_obj.gradient + test0 = np.isclose(grad_test.at((-0.1, 0.1)), 0.0) + print(f"PML looks masked: {test0}", flush=True) + test1 = np.abs(grad_test.at((-1.0, 1.0))) > 1e-5 + print(f"Center looks unmasked: {test1}", flush=True) + + # quick look at functional and if it reduced + test2 = FWI_obj.functional < 1e-3 + print(f"Last functional small: {test2}", flush=True) + test3 = FWI_obj.functional_history[-1]/FWI_obj.functional_history[0] < 1e-2 + print(f"Considerable functional reduction during test: {test3}", flush=True) + + print("END", flush=True) + assert all([test0, test1, test2, test3]) + + +@pytest.mark.skipif(not is_rol_installed(), reason="ROL is not installed") +def test_fwi_with_rol(load_real_shot=False, use_rol=True): + test_fwi(load_real_shot=load_real_shot, use_rol=use_rol) + + +if __name__ == "__main__": + test_fwi(load_real_shot=False) + test_fwi_with_rol() diff --git a/test_parallel/test_forward.py b/test_parallel/test_forward.py index 91287542..1760dd52 100644 --- a/test_parallel/test_forward.py +++ b/test_parallel/test_forward.py @@ -91,7 +91,7 @@ def test_forward_3_shots(): error = error_calc(arr0[:430], analytical_p[:430], 430) if comm.comm.rank == 0: - print(f"Error for shot {Wave_obj.current_source} is {error} and test has passed equals {np.abs(error) < 0.01}", flush=True) + print(f"Error for shot {Wave_obj.current_sources} is {error} and test has passed equals {np.abs(error) < 0.01}", flush=True) error_all = COMM_WORLD.allreduce(error, op=MPI.SUM) error_all /= 3 diff --git a/test_parallel/test_forward_multiple_serial_shots.py b/test_parallel/test_forward_multiple_serial_shots.py new file mode 100644 index 00000000..1746ab66 --- /dev/null +++ b/test_parallel/test_forward_multiple_serial_shots.py @@ -0,0 +1,118 @@ +from mpi4py.MPI import COMM_WORLD +from mpi4py import MPI +import numpy as np +import firedrake as fire +import spyro +import matplotlib.pyplot as plt + + +def error_calc(p_numerical, p_analytical, nt): + norm = np.linalg.norm(p_numerical, 2) / np.sqrt(nt) + error_time = np.linalg.norm(p_analytical - p_numerical, 2) / np.sqrt(nt) + div_error_time = error_time / norm + return div_error_time + + +def test_forward_3_shots(): + final_time = 1.0 + + dictionary = {} + dictionary["options"] = { + "cell_type": "Q", 
# simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q) + "variant": "lumped", # lumped, equispaced or DG, default is lumped + "degree": 4, # p order + "dimension": 2, # dimension + } + dictionary["parallelism"] = { + "type": "spatial", # options: automatic (same number of cores for evey processor) or spatial + "shot_ids_per_propagation": [[0], [1]], + } + dictionary["mesh"] = { + "Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha? + "Lx": 2.0, # width in km - always positive + "Ly": 0.0, # thickness in km - always positive + "mesh_file": None, + "mesh_type": "firedrake_mesh", + } + dictionary["acquisition"] = { + "source_type": "ricker", + "source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 3), + "frequency": 5.0, + "delay": 0.2, + "delay_type": "time", + "receiver_locations": spyro.create_transect((-0.75, 0.7), (-0.75, 1.3), 200), + } + dictionary["time_axis"] = { + "initial_time": 0.0, # Initial time for event + "final_time": final_time, # Final time for event + "dt": 0.0005, # timestep size + "amplitude": 1, # the Ricker has an amplitude of 1. + "output_frequency": 100, # how frequently to output solution to pvds - Perguntar Daiane ''post_processing_frequnecy' + "gradient_sampling_frequency": 1, + } + dictionary["visualization"] = { + "forward_output": False, + "forward_output_filename": "results/forward_output.pvd", + "fwi_velocity_model_output": False, + "velocity_model_filename": None, + "gradient_output": False, + "gradient_filename": None, + } + + Wave_obj = spyro.AcousticWave(dictionary=dictionary) + Wave_obj.set_mesh(mesh_parameters={"dx": 0.1}) + + mesh_z = Wave_obj.mesh_z + cond = fire.conditional(mesh_z < -1.5, 3.5, 1.5) + Wave_obj.set_initial_velocity_model(conditional=cond, output=True) + + Wave_obj.forward_solve() + + comm = Wave_obj.comm + + if comm.comm.rank == 0: + analytical_p = spyro.utils.nodal_homogeneous_analytical( + Wave_obj, 0.2, 1.5, n_extra=100 + ) + else: + analytical_p = None + analytical_p = comm.comm.bcast(analytical_p, root=0) + + time_vector = np.linspace(0.0, 1.0, 2001) + cutoff = 830 + errors = [] + + for i in range(Wave_obj.number_of_sources): + plt.close() + plt.plot(time_vector[:cutoff], analytical_p[:cutoff], "--", label="analyt") + spyro.io.switch_serial_shot(Wave_obj, i) + rec_out = Wave_obj.forward_solution_receivers + if i == 0: + rec0 = rec_out[:, 0].flatten() + elif i == 1: + rec0 = rec_out[:, 99].flatten() + elif i == 2: + rec0 = rec_out[:, 199].flatten() + plt.plot(time_vector[:cutoff], rec0[:cutoff], label="numerical") + plt.title(f"Source {i}") + plt.legend() + plt.savefig(f"test{i}.png") + error_core = error_calc(rec0[:cutoff], analytical_p[:cutoff], cutoff) + error = COMM_WORLD.allreduce(error_core, op=MPI.SUM) + error /= comm.comm.size + errors.append(error) + print(f"Shot {i} produced error of {error}", flush=True) + + error_all = (errors[0] + errors[1] + errors[2]) / 3 + comm.comm.barrier() + + if comm.comm.rank == 0: + print(f"Combined error for all shots is {error_all} and test has passed equals {np.abs(error_all) < 0.01}", flush=True) + + test = np.abs(error_all) < 0.01 + + assert test + + +if __name__ == "__main__": + test_forward_3_shots() diff --git a/test_parallel/test_forward_supershot.py b/test_parallel/test_forward_supershot.py new file mode 100644 index 00000000..334325fa --- /dev/null +++ b/test_parallel/test_forward_supershot.py @@ -0,0 +1,112 @@ +from mpi4py.MPI import COMM_WORLD +from mpi4py import MPI +# import debugpy +# debugpy.listen(3000 + COMM_WORLD.rank) +# 
debugpy.wait_for_client() +import spyro +import numpy as np +import math + + +def error_calc(p_numerical, p_analytical, nt): + norm = np.linalg.norm(p_numerical, 2) / np.sqrt(nt) + error_time = np.linalg.norm(p_analytical - p_numerical, 2) / np.sqrt(nt) + div_error_time = error_time / norm + return div_error_time + + +def test_forward_supershot(): + dt = 0.0005 + + final_time = 1.0 + + dictionary = {} + dictionary["options"] = { + "cell_type": "Q", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q) + "variant": "lumped", # lumped, equispaced or DG, default is lumped "method":"MLT", # (MLT/spectral_quadrilateral/DG_triangle/DG_quadrilateral) You can either specify a cell_type+variant or a method + "degree": 4, # p order + "dimension": 2, # dimension + } + + # Number of cores for the shot. For simplicity, we keep things serial. + # spyro however supports both spatial parallelism and "shot" parallelism. + dictionary["parallelism"] = { + "type": "custom", # options: automatic (same number of cores for evey processor) or spatial + "shot_ids_per_propagation": [[0, 1]], + } + + # Define the domain size without the PML. Here we'll assume a 1.00 x 1.00 km + # domain and reserve the remaining 250 m for the Perfectly Matched Layer (PML) to absorb + # outgoing waves on three sides (eg., -z, +-x sides) of the domain. + dictionary["mesh"] = { + "Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha? + "Lx": 2.0, # width in km - always positive + "Ly": 0.0, # thickness in km - always positive + "mesh_file": None, + "mesh_type": "firedrake_mesh", + } + dictionary["acquisition"] = { + "source_type": "ricker", + "source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 2), + "frequency": 5.0, + "delay": 0.2, + "delay_type": "time", + "receiver_locations": spyro.create_transect((-0.55, 0.5), (-0.55, 1.5), 200), + } + + # Simulate for 2.0 seconds. + dictionary["time_axis"] = { + "initial_time": 0.0, # Initial time for event + "final_time": final_time, # Final time for event + "dt": dt, # timestep size + "amplitude": 1, # the Ricker has an amplitude of 1. 
+ "output_frequency": 100, # how frequently to output solution to pvds + "gradient_sampling_frequency": 100, # how frequently to save solution to RAM + } + + dictionary["visualization"] = { + "forward_output": True, + "forward_output_filename": "results/forward_output.pvd", + "fwi_velocity_model_output": False, + "velocity_model_filename": None, + "gradient_output": False, + "gradient_filename": None, + } + + Wave_obj = spyro.AcousticWave(dictionary=dictionary) + Wave_obj.set_mesh(mesh_parameters={"dx": 0.02, "periodic": True}) + + Wave_obj.set_initial_velocity_model(constant=1.5) + Wave_obj.forward_solve() + comm = Wave_obj.comm + + rec_out = Wave_obj.receivers_output + if comm.comm.rank == 0: + analytical_p = spyro.utils.nodal_homogeneous_analytical(Wave_obj, 0.2, 1.5, n_extra=100) + else: + analytical_p = None + + analytical_p = comm.comm.bcast(analytical_p, root=0) + + arr0 = rec_out[:, 0] + arr0 = arr0.flatten() + arr199 = rec_out[:, 199] + arr199 = arr199.flatten() + + error0 = error_calc(arr0[:430], analytical_p[:430], 430) + error199 = error_calc(arr199[:430], analytical_p[:430], 430) + error = error0 + error199 + error_all = COMM_WORLD.allreduce(error, op=MPI.SUM) + error_all /= 2 + comm.comm.barrier() + + if comm.comm.rank == 0: + print(f"Combined error for shots {Wave_obj.current_sources} is {error_all} and test has passed equals {np.abs(error_all) < 0.01}", flush=True) + + test = np.abs(error_all) < 0.01 + + assert test + + +if __name__ == "__main__": + test_forward_supershot() diff --git a/test_parallel/test_gradient_serialshots.py b/test_parallel/test_gradient_serialshots.py new file mode 100644 index 00000000..3a8775db --- /dev/null +++ b/test_parallel/test_gradient_serialshots.py @@ -0,0 +1,127 @@ +from mpi4py.MPI import COMM_WORLD +import numpy as np +import firedrake as fire +import random +import spyro +import warnings + + +warnings.filterwarnings("ignore") + + +final_time = 1.0 + +dictionary = {} +dictionary["options"] = { + "cell_type": "Q", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q) + "variant": "lumped", # lumped, equispaced or DG, default is lumped + "degree": 4, # p order + "dimension": 2, # dimension +} + +dictionary["parallelism"] = { + "type": "spatial", # options: automatic (same number of cores for evey processor) or spatial + "shot_ids_per_propagation": [[0], [1]], +} + +dictionary["mesh"] = { + "Lz": 3.0, # depth in km - always positive # Como ver isso sem ler a malha? + "Lx": 3.0, # width in km - always positive + "Ly": 0.0, # thickness in km - always positive + "mesh_file": None, + "mesh_type": "firedrake_mesh", +} + +dictionary["acquisition"] = { + "source_type": "ricker", + "source_locations": [(-1.1, 1.3), (-1.1, 1.7)], + "frequency": 5.0, + "delay": 1.5, + "delay_type": "multiples_of_minimun", + "receiver_locations": spyro.create_transect((-1.8, 1.2), (-1.8, 1.8), 10), +} + +dictionary["time_axis"] = { + "initial_time": 0.0, # Initial time for event + "final_time": final_time, # Final time for event + "dt": 0.0005, # timestep size + "amplitude": 1, # the Ricker has an amplitude of 1. 
+ "output_frequency": 100, # how frequently to output solution to pvds - Perguntar Daiane ''post_processing_frequnecy' + "gradient_sampling_frequency": 1, # how frequently to save solution to RAM - Perguntar Daiane 'gradient_sampling_frequency' +} + +dictionary["visualization"] = { + "forward_output": True, + "forward_output_filename": "results/forward_true.pvd", + "fwi_velocity_model_output": False, + "velocity_model_filename": None, + "gradient_output": False, + "gradient_filename": "results/Gradient.pvd", + "adjoint_output": False, + "adjoint_filename": None, + "debug_output": False, +} + + +def get_gradient(parallelism_type, points): + + dictionary["parallelism"]["type"] = parallelism_type + print(f"Calculating exact", flush=True) + Wave_obj_exact = spyro.AcousticWave(dictionary=dictionary) + Wave_obj_exact.set_mesh(mesh_parameters={"dx": 0.1}) + + cond = fire.conditional(Wave_obj_exact.mesh_z > -1.5, 1.5, 3.5) + Wave_obj_exact.set_initial_velocity_model( + conditional=cond, + ) + + Wave_obj_exact.forward_solve() + + print(f"Calculating guess", flush=True) + Wave_obj_guess = spyro.AcousticWave(dictionary=dictionary) + Wave_obj_guess.set_mesh(mesh_parameters={"dx": 0.1}) + Wave_obj_guess.set_initial_velocity_model(constant=2.0) + Wave_obj_guess.forward_solve() + + if parallelism_type == "automatic": + misfit = Wave_obj_exact.forward_solution_receivers - Wave_obj_guess.forward_solution_receivers + elif parallelism_type == "spatial": + misfit_list = [] + for source_id in range(len(dictionary["acquisition"]["source_locations"])): + spyro.io.switch_serial_shot(Wave_obj_exact, source_id) + spyro.io.switch_serial_shot(Wave_obj_guess, source_id) + misfit_list.append(Wave_obj_exact.forward_solution_receivers - Wave_obj_guess.forward_solution_receivers) + misfit= misfit_list + + gradient = Wave_obj_guess.gradient_solve(misfit=misfit) + Wave_obj_guess.comm.comm.barrier() + spyro.io.delete_tmp_files(Wave_obj_guess) + spyro.io.delete_tmp_files(Wave_obj_exact) + + gradient_point_values = [] + for point in points: + gradient_point_values.append(gradient.at(point)) + + return gradient_point_values + + +def test_gradient_serialshots(): + comm = COMM_WORLD + rank = comm.Get_rank() + if rank == 0: + points = [(random.uniform(-3, 0), random.uniform(0, 3)) for _ in range(20)] + else: + points = None + points = comm.bcast(points, root=0) + gradient_ensemble_parallelism = get_gradient("automatic", points) + gradient_serial_shot = get_gradient("spatial", points) + + # Check if the gradients are equal within a tolerance + tolerance = 1e-8 + test = all(np.isclose(a, b, atol=tolerance) for a, b in zip(gradient_ensemble_parallelism, gradient_serial_shot)) + + print(f"Gradient is equal: {test}", flush=True) + + +if __name__ == "__main__": + test_gradient_serialshots() diff --git a/test_parallel/test_parallel_io.py b/test_parallel/test_parallel_io.py new file mode 100644 index 00000000..6482e962 --- /dev/null +++ b/test_parallel/test_parallel_io.py @@ -0,0 +1,29 @@ +import spyro + + +def test_saving_and_loading_shot_record(): + from test.inputfiles.model import dictionary + + dictionary["parallelism"]["type"] = "custom" + dictionary["parallelism"]["shot_ids_per_propagation"] = [[0, 1]] + dictionary["time_axis"]["final_time"] = 0.5 + dictionary["acquisition"]["source_locations"] = [(-0.5, 0.4), (-0.5, 0.6)] + dictionary["acquisition"]["receiver_locations"] = spyro.create_transect((-0.55, 0.1), (-0.55, 0.9), 200) + + wave = spyro.AcousticWave(dictionary=dictionary) + wave.set_mesh(mesh_parameters={"dx": 0.02}) + 
wave.set_initial_velocity_model(constant=1.5) + wave.forward_solve() + spyro.io.save_shots(wave, file_name="test_shot_record") + shots1 = wave.forward_solution_receivers + + wave2 = spyro.AcousticWave(dictionary=dictionary) + wave2.set_mesh(mesh_parameters={"dx": 0.02}) + spyro.io.load_shots(wave2, file_name="test_shot_record") + shots2 = wave2.forward_solution_receivers + + assert (shots1 == shots2).all() + + +if __name__ == "__main__": + test_saving_and_loading_shot_record() diff --git a/test_parallel/test_supershot_grad.py b/test_parallel/test_supershot_grad.py new file mode 100644 index 00000000..a2a8b07b --- /dev/null +++ b/test_parallel/test_supershot_grad.py @@ -0,0 +1,166 @@ +import numpy as np +import math +import matplotlib.pyplot as plt +from copy import deepcopy +from firedrake import File +import firedrake as fire +import spyro + + +def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False): + steps = [1e-3, 1e-4, 1e-5] # step length + + errors = [] + V_c = Wave_obj_guess.function_space + dm = fire.Function(V_c) + size, = np.shape(dm.dat.data[:]) + dm_data = np.random.rand(size) + dm.dat.data[:] = dm_data + # dm.assign(dJ) + + for step in steps: + + Wave_obj_guess.reset_pressure() + c_guess = fire.Constant(2.0) + step*dm + Wave_obj_guess.initial_velocity_model = c_guess + Wave_obj_guess.forward_solve() + misfit_plusdm = rec_out_exact - Wave_obj_guess.receivers_output + J_plusdm = spyro.utils.compute_functional(Wave_obj_guess, misfit_plusdm) + + grad_fd = (J_plusdm - Jm) / (step) + projnorm = fire.assemble(dJ * dm * fire.dx(scheme=Wave_obj_guess.quadrature_rule)) + + error = 100 * ((grad_fd - projnorm) / projnorm) + + errors.append(error) + + errors = np.array(errors) + + # Checking if error is first order in step + theory = [t for t in steps] + theory = [errors[0] * th / theory[0] for th in theory] + if plot: + plt.close() + plt.plot(steps, errors, label="Error") + plt.plot(steps, theory, "--", label="first order") + plt.legend() + plt.title(" Adjoint gradient versus finite difference gradient") + plt.xlabel("Step") + plt.ylabel("Error %") + plt.savefig("gradient_error_verification.png") + plt.close() + + # Checking if the final error is less than 1 percent + + test1 = abs(errors[-1]) < 1 + print(f"Last gradient error less than 1 percent: {test1}") + + # Checking if error follows expected finite difference error convergence + test2 = math.isclose(np.log(abs(theory[-1])), np.log(abs(errors[-1])), rel_tol=1e-1) + + print(f"Gradient error behaved as expected: {test2}") + + assert all([test1, test2]) + + +final_time = 1.0 + +dictionary = {} +dictionary["options"] = { + "cell_type": "Q", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q) + "variant": "lumped", # lumped, equispaced or DG, default is lumped + "degree": 4, # p order + "dimension": 2, # dimension +} + +dictionary["parallelism"] = { + "type": "custom", # options: automatic (same number of cores for evey processor) or spatial + "shot_ids_per_propagation": [[0, 1]], +} + +dictionary["mesh"] = { + "Lz": 3.0, # depth in km - always positive # Como ver isso sem ler a malha? 
+ "Lx": 3.0, # width in km - always positive + "Ly": 0.0, # thickness in km - always positive + "mesh_file": None, + "mesh_type": "firedrake_mesh", +} + +dictionary["acquisition"] = { + "source_type": "ricker", + "source_locations": [(-1.1, 1.3), (-1.1, 1.7)], + "frequency": 5.0, + "delay": 1.5, + "delay_type": "multiples_of_minimun", + "receiver_locations": spyro.create_transect((-1.8, 1.2), (-1.8, 1.8), 10), +} + +dictionary["time_axis"] = { + "initial_time": 0.0, # Initial time for event + "final_time": final_time, # Final time for event + "dt": 0.0005, # timestep size + "amplitude": 1, # the Ricker has an amplitude of 1. + "output_frequency": 100, # how frequently to output solution to pvds - Perguntar Daiane ''post_processing_frequnecy' + "gradient_sampling_frequency": 1, # how frequently to save solution to RAM - Perguntar Daiane 'gradient_sampling_frequency' +} + +dictionary["visualization"] = { + "forward_output": False, + "forward_output_filename": "results/forward_output.pvd", + "fwi_velocity_model_output": False, + "velocity_model_filename": None, + "gradient_output": False, + "gradient_filename": "results/Gradient.pvd", + "adjoint_output": False, + "adjoint_filename": None, + "debug_output": False, +} + + +def get_forward_model(load_true=False): + if load_true is False: + Wave_obj_exact = spyro.AcousticWave(dictionary=dictionary) + Wave_obj_exact.set_mesh(mesh_parameters={"dx": 0.1}) + # Wave_obj_exact.set_initial_velocity_model(constant=3.0) + cond = fire.conditional(Wave_obj_exact.mesh_z > -1.5, 1.5, 3.5) + Wave_obj_exact.set_initial_velocity_model( + conditional=cond, + # output=True + ) + # spyro.plots.plot_model(Wave_obj_exact, abc_points=[(-1, 1), (-2, 1), (-2, 4), (-1, 2)]) + Wave_obj_exact.forward_solve() + # forward_solution_exact = Wave_obj_exact.forward_solution + rec_out_exact = Wave_obj_exact.receivers_output + # np.save("rec_out_exact", rec_out_exact) + + else: + rec_out_exact = np.load("rec_out_exact.npy") + + Wave_obj_guess = spyro.AcousticWave(dictionary=dictionary) + Wave_obj_guess.set_mesh(mesh_parameters={"dx": 0.1}) + Wave_obj_guess.set_initial_velocity_model(constant=2.0) + Wave_obj_guess.forward_solve() + rec_out_guess = Wave_obj_guess.receivers_output + + return rec_out_exact, rec_out_guess, Wave_obj_guess + + +def test_gradient_supershot(): + rec_out_exact, rec_out_guess, Wave_obj_guess = get_forward_model(load_true=False) + forward_solution = Wave_obj_guess.forward_solution + forward_solution_guess = deepcopy(forward_solution) + + misfit = rec_out_exact - rec_out_guess + + Jm = spyro.utils.compute_functional(Wave_obj_guess, misfit) + print(f"Cost functional : {Jm}") + + # compute the gradient of the control (to be verified) + dJ = Wave_obj_guess.gradient_solve(misfit=misfit, forward_solution=forward_solution_guess) + File("gradient.pvd").write(dJ) + + check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=True) + + +if __name__ == "__main__": + test_gradient_supershot()