diff --git a/spyro/io/__init__.py b/spyro/io/__init__.py index 37d58b45..a6f7f016 100644 --- a/spyro/io/__init__.py +++ b/spyro/io/__init__.py @@ -13,7 +13,8 @@ parallel_print, saving_source_and_receiver_location_in_csv, switch_serial_shot, - ensemble_save_or_load, + ensemble_save, + ensemble_load, delete_tmp_files, ensemble_shot_record, ensemble_functional @@ -44,7 +45,8 @@ "boundary_layer_io", "saving_source_and_receiver_location_in_csv", "switch_serial_shot", - "ensemble_save_or_load", + "ensemble_save", + "ensemble_load", "delete_tmp_files", "ensemble_shot_record", "ensemble_functional", diff --git a/spyro/io/basicio.py b/spyro/io/basicio.py index c55ab8b6..dd7a7467 100644 --- a/spyro/io/basicio.py +++ b/spyro/io/basicio.py @@ -34,7 +34,7 @@ def wrapper(*args, **kwargs): return wrapper -def ensemble_save_or_load(func): +def ensemble_save(func): """Decorator for read and write shots for ensemble parallelism""" def wrapper(*args, **kwargs): @@ -54,6 +54,26 @@ def wrapper(*args, **kwargs): return wrapper +def ensemble_load(func): + """Decorator for read and write shots for ensemble parallelism""" + + def wrapper(*args, **kwargs): + _comm = args[0].comm + + if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1: + shot_ids_per_propagation_list = args[0].shot_ids_per_propagation + for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list): + if is_owner(_comm, propagation_id): + func(*args, **dict(kwargs, shot_ids=shot_ids_in_propagation)) + elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: + for snum in range(args[0].number_of_sources): + switch_serial_shot(args[0], snum) + if _comm.comm.rank == 0: + func(*args, **dict(kwargs, shot_ids=[snum])) + + return wrapper + + def ensemble_propagator(func): """Decorator for forward to distribute shots for ensemble parallelism""" @@ -275,7 +295,7 @@ def create_segy(function, V, grid_spacing, filename): f.trace[tr] = velocity[:, tr] -@ensemble_save_or_load +@ensemble_save def save_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0): """Save a the shot record from last forward solve to a `pickle`. @@ -305,7 +325,7 @@ def rebuild_empty_forward_solution(wave, time_steps): wave.forward_solution.append(fire.Function(wave.function_space)) -@ensemble_save_or_load +@ensemble_load def load_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0): """Load a `pickle` to a `numpy.ndarray`. diff --git a/spyro/plots/plots.py b/spyro/plots/plots.py index b0d6c3a6..edda37c5 100644 --- a/spyro/plots/plots.py +++ b/spyro/plots/plots.py @@ -6,12 +6,12 @@ import numpy as np import firedrake import copy -from ..io import ensemble_save_or_load +from ..io import ensemble_save __all__ = ["plot_shots"] -@ensemble_save_or_load +@ensemble_save def plot_shots( Wave_object, show=False, diff --git a/spyro/solvers/inversion.py b/spyro/solvers/inversion.py index 75ad64ec..97c3353b 100644 --- a/spyro/solvers/inversion.py +++ b/spyro/solvers/inversion.py @@ -437,7 +437,7 @@ def run_fwi(self, **kwargs): vmin = parameters["vmin"] vmax = parameters["vmax"] - vp_0 = self.initial_velocity_model.vector().gather() + vp_0 = self.initial_velocity_model.vector() bounds = [(vmin, vmax) for _ in range(len(vp_0))] options = parameters["scipy_options"]