diff --git a/spyro/io/__init__.py b/spyro/io/__init__.py index f829cd72..167e5b8b 100644 --- a/spyro/io/__init__.py +++ b/spyro/io/__init__.py @@ -6,15 +6,14 @@ load_shots, read_mesh, interpolate, - # ensemble_forward, # ensemble_forward_ad, # ensemble_forward_elastic_waves, ensemble_gradient, # ensemble_gradient_elastic_waves, - ensemble_plot, parallel_print, saving_source_and_receiver_location_in_csv, switch_serial_shot, + ensemble_save_or_load, ) from .model_parameters import Model_parameters from .backwards_compatibility_io import Dictionary_conversion @@ -29,7 +28,6 @@ "load_shots", "read_mesh", "interpolate", - # "ensemble_forward", # "ensemble_forward_ad", # "ensemble_forward_elastic_waves", "ensemble_gradient", @@ -43,4 +41,5 @@ "boundary_layer_io", "saving_source_and_receiver_location_in_csv", "switch_serial_shot", + "ensemble_save_or_load", ] diff --git a/spyro/io/basicio.py b/spyro/io/basicio.py index 1f629f57..6ca87df9 100644 --- a/spyro/io/basicio.py +++ b/spyro/io/basicio.py @@ -14,36 +14,18 @@ def ensemble_save_or_load(func): """Decorator for read and write shots for ensemble parallelism""" def wrapper(*args, **kwargs): - shot_ids_per_propagation_list = args[0].shot_ids_per_propagation _comm = args[0].comm - file_name = kwargs.get("file_name") - for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list): - if is_owner(_comm, propagation_id) and _comm.comm.rank == 0: - func( - *args, - **dict( - kwargs, - propagation_id=propagation_id, - file_name="shots/" - + file_name - + str(shot_ids_in_propagation) - + ".dat", - ) - ) - return wrapper - - -def ensemble_plot(func): - """Decorator for `plot_shots` to distribute shots for - ensemble parallelism""" - - def wrapper(*args, **kwargs): - num = args[0].number_of_sources - _comm = args[0].comm - for snum in range(num): - if is_owner(_comm, snum) and _comm.comm.rank == 0: - func(*args, **dict(kwargs, file_name=str(snum + 1))) + if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1: + shot_ids_per_propagation_list = args[0].shot_ids_per_propagation + for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list): + if is_owner(_comm, propagation_id) and _comm.comm.rank == 0: + func(*args, **dict(kwargs, shot_ids=shot_ids_in_propagation)) + elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: + for snum in range(args[0].number_of_sources): + switch_serial_shot(args[0], snum) + if _comm.comm.rank == 0: + func(*args, **dict(kwargs, shot_ids=[snum])) return wrapper @@ -67,13 +49,23 @@ def wrapper(*args, **kwargs): args[0].current_time = starting_time u, u_r = func(*args, **dict(kwargs, source_nums=[snum])) save_serial_data(args[0], snum) - + return u, u_r return wrapper def save_serial_data(wave, propagation_id): + """ + Save serial data to numpy files. + + Args: + wave (Wave): The wave object containing the forward solution. + propagation_id (int): The propagation ID. + + Returns: + None + """ arrays_list = [obj.dat.data[:] for obj in wave.forward_solution] stacked_arrays = np.stack(arrays_list, axis=0) spatialcomm = wave.comm.comm.rank @@ -82,6 +74,16 @@ def save_serial_data(wave, propagation_id): def switch_serial_shot(wave, propagation_id): + """ + Switches the current serial shot for a given wave to shot identified with propagation ID. + + Args: + wave (Wave): The wave object. + propagation_id (int): The propagation ID. + + Returns: + None + """ spatialcomm = wave.comm.comm.rank stacked_shot_arrays = np.load(f'tmp_shot{propagation_id}_comm{spatialcomm}.npy') for array_i, array in enumerate(stacked_shot_arrays): @@ -89,7 +91,6 @@ def switch_serial_shot(wave, propagation_id): wave.forward_solution_receivers = np.load(f"tmp_rec{propagation_id}_comm{spatialcomm}.npy") - def ensemble_gradient(func): """Decorator for gradient to distribute shots for ensemble parallelism""" @@ -195,7 +196,7 @@ def create_segy(velocity, filename): @ensemble_save_or_load -def save_shots(Wave_obj, propagation_id=0, file_name="shot_record_"): +def save_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0): """Save a the shot record from last forward solve to a `pickle`. Parameters @@ -212,13 +213,14 @@ def save_shots(Wave_obj, propagation_id=0, file_name="shot_record_"): None """ + file_name = file_name + str(shot_ids) + ".dat" with open(file_name, "wb") as f: - pickle.dump(Wave_obj.forward_solution_receivers[:, propagation_id], f) + pickle.dump(Wave_obj.forward_solution_receivers, f) return None @ensemble_save_or_load -def load_shots(Wave_obj, propagation_id=0, file_name=None): +def load_shots(Wave_obj, file_name=None): """Load a `pickle` to a `numpy.ndarray`. Parameters diff --git a/spyro/plots/plots.py b/spyro/plots/plots.py index 68742439..b0d6c3a6 100644 --- a/spyro/plots/plots.py +++ b/spyro/plots/plots.py @@ -6,16 +6,17 @@ import numpy as np import firedrake import copy -from ..io import ensemble_plot +from ..io import ensemble_save_or_load __all__ = ["plot_shots"] -@ensemble_plot +@ensemble_save_or_load def plot_shots( Wave_object, show=False, - file_name="1", + file_name="plot_of_shot", + shot_ids=[0], vmin=-1e-5, vmax=1e-5, contour_lines=700, @@ -52,7 +53,7 @@ def plot_shots( ------- None """ - + file_name = file_name + str(shot_ids) + "." + file_format num_recvs = Wave_object.number_of_receivers dt = Wave_object.dt @@ -79,7 +80,7 @@ def plot_shots( plt.xlim(start_index, end_index) plt.ylim(tf, 0) plt.subplots_adjust(left=0.18, right=0.95, bottom=0.14, top=0.95) - plt.savefig(file_name + "." + file_format, format=file_format) + plt.savefig(file_name, format=file_format) # plt.axis("image") if show: plt.show() diff --git a/test_forward_multiple_serial_shots.py b/test_forward_multiple_serial_shots.py index c9a7f528..c4cfff23 100644 --- a/test_forward_multiple_serial_shots.py +++ b/test_forward_multiple_serial_shots.py @@ -88,11 +88,11 @@ def test_forward_3_shots(): spyro.io.switch_serial_shot(Wave_obj, i) rec_out = Wave_obj.forward_solution_receivers if i == 0: - rec0 = rec_out[:, 0].flatten() + rec0 = rec_out[:, 0].flatten() elif i == 1: - rec0 = rec_out[:, 99].flatten() + rec0 = rec_out[:, 99].flatten() elif i == 2: - rec0 = rec_out[:, 199].flatten() + rec0 = rec_out[:, 199].flatten() plt.plot(time_vector[:cutoff], rec0[:cutoff], label="numerical") plt.title(f"Source {i}") plt.legend() @@ -102,7 +102,7 @@ def test_forward_3_shots(): error /= comm.comm.size errors.append(error) print(f"Shot {i} produced error of {error}", flush=True) - + error_all = (errors[0] + errors[1] + errors[2]) / 3 comm.comm.barrier() diff --git a/test_parallel/test_forward_multiple_serial_shots.py b/test_parallel/test_forward_multiple_serial_shots.py index c9a7f528..1746ab66 100644 --- a/test_parallel/test_forward_multiple_serial_shots.py +++ b/test_parallel/test_forward_multiple_serial_shots.py @@ -84,15 +84,15 @@ def test_forward_3_shots(): for i in range(Wave_obj.number_of_sources): plt.close() - plt.plot(time_vector[:cutoff], analytical_p[:cutoff], "--",label="analyt") + plt.plot(time_vector[:cutoff], analytical_p[:cutoff], "--", label="analyt") spyro.io.switch_serial_shot(Wave_obj, i) rec_out = Wave_obj.forward_solution_receivers if i == 0: - rec0 = rec_out[:, 0].flatten() + rec0 = rec_out[:, 0].flatten() elif i == 1: - rec0 = rec_out[:, 99].flatten() + rec0 = rec_out[:, 99].flatten() elif i == 2: - rec0 = rec_out[:, 199].flatten() + rec0 = rec_out[:, 199].flatten() plt.plot(time_vector[:cutoff], rec0[:cutoff], label="numerical") plt.title(f"Source {i}") plt.legend() @@ -102,7 +102,7 @@ def test_forward_3_shots(): error /= comm.comm.size errors.append(error) print(f"Shot {i} produced error of {error}", flush=True) - + error_all = (errors[0] + errors[1] + errors[2]) / 3 comm.comm.barrier()