Skip to content

Commit

Permalink
fixing and documenting parallel io
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 4, 2024
1 parent ff13956 commit 76be77e
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 50 deletions.
5 changes: 2 additions & 3 deletions spyro/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
load_shots,
read_mesh,
interpolate,
# ensemble_forward,
# ensemble_forward_ad,
# ensemble_forward_elastic_waves,
ensemble_gradient,
# ensemble_gradient_elastic_waves,
ensemble_plot,
parallel_print,
saving_source_and_receiver_location_in_csv,
switch_serial_shot,
ensemble_save_or_load,
)
from .model_parameters import Model_parameters
from .backwards_compatibility_io import Dictionary_conversion
Expand All @@ -29,7 +28,6 @@
"load_shots",
"read_mesh",
"interpolate",
# "ensemble_forward",
# "ensemble_forward_ad",
# "ensemble_forward_elastic_waves",
"ensemble_gradient",
Expand All @@ -43,4 +41,5 @@
"boundary_layer_io",
"saving_source_and_receiver_location_in_csv",
"switch_serial_shot",
"ensemble_save_or_load",
]
68 changes: 35 additions & 33 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,18 @@ def ensemble_save_or_load(func):
"""Decorator for read and write shots for ensemble parallelism"""

def wrapper(*args, **kwargs):
shot_ids_per_propagation_list = args[0].shot_ids_per_propagation
_comm = args[0].comm
file_name = kwargs.get("file_name")
for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list):
if is_owner(_comm, propagation_id) and _comm.comm.rank == 0:
func(
*args,
**dict(
kwargs,
propagation_id=propagation_id,
file_name="shots/"
+ file_name
+ str(shot_ids_in_propagation)
+ ".dat",
)
)

return wrapper


def ensemble_plot(func):
"""Decorator for `plot_shots` to distribute shots for
ensemble parallelism"""

def wrapper(*args, **kwargs):
num = args[0].number_of_sources
_comm = args[0].comm
for snum in range(num):
if is_owner(_comm, snum) and _comm.comm.rank == 0:
func(*args, **dict(kwargs, file_name=str(snum + 1)))
if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1:
shot_ids_per_propagation_list = args[0].shot_ids_per_propagation
for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list):
if is_owner(_comm, propagation_id) and _comm.comm.rank == 0:
func(*args, **dict(kwargs, shot_ids=shot_ids_in_propagation))
elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1:
for snum in range(args[0].number_of_sources):
switch_serial_shot(args[0], snum)
if _comm.comm.rank == 0:
func(*args, **dict(kwargs, shot_ids=[snum]))

return wrapper

Expand All @@ -67,13 +49,23 @@ def wrapper(*args, **kwargs):
args[0].current_time = starting_time
u, u_r = func(*args, **dict(kwargs, source_nums=[snum]))
save_serial_data(args[0], snum)

return u, u_r

return wrapper


def save_serial_data(wave, propagation_id):
"""
Save serial data to numpy files.
Args:
wave (Wave): The wave object containing the forward solution.
propagation_id (int): The propagation ID.
Returns:
None
"""
arrays_list = [obj.dat.data[:] for obj in wave.forward_solution]
stacked_arrays = np.stack(arrays_list, axis=0)
spatialcomm = wave.comm.comm.rank
Expand All @@ -82,14 +74,23 @@ def save_serial_data(wave, propagation_id):


def switch_serial_shot(wave, propagation_id):
"""
Switches the current serial shot for a given wave to shot identified with propagation ID.
Args:
wave (Wave): The wave object.
propagation_id (int): The propagation ID.
Returns:
None
"""
spatialcomm = wave.comm.comm.rank
stacked_shot_arrays = np.load(f'tmp_shot{propagation_id}_comm{spatialcomm}.npy')
for array_i, array in enumerate(stacked_shot_arrays):
wave.forward_solution[array_i].dat.data[:] = array
wave.forward_solution_receivers = np.load(f"tmp_rec{propagation_id}_comm{spatialcomm}.npy")



def ensemble_gradient(func):
"""Decorator for gradient to distribute shots for ensemble parallelism"""

Expand Down Expand Up @@ -195,7 +196,7 @@ def create_segy(velocity, filename):


@ensemble_save_or_load
def save_shots(Wave_obj, propagation_id=0, file_name="shot_record_"):
def save_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0):
"""Save a the shot record from last forward solve to a `pickle`.
Parameters
Expand All @@ -212,13 +213,14 @@ def save_shots(Wave_obj, propagation_id=0, file_name="shot_record_"):
None
"""
file_name = file_name + str(shot_ids) + ".dat"
with open(file_name, "wb") as f:
pickle.dump(Wave_obj.forward_solution_receivers[:, propagation_id], f)
pickle.dump(Wave_obj.forward_solution_receivers, f)
return None


@ensemble_save_or_load
def load_shots(Wave_obj, propagation_id=0, file_name=None):
def load_shots(Wave_obj, file_name=None):
"""Load a `pickle` to a `numpy.ndarray`.
Parameters
Expand Down
11 changes: 6 additions & 5 deletions spyro/plots/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
import numpy as np
import firedrake
import copy
from ..io import ensemble_plot
from ..io import ensemble_save_or_load

__all__ = ["plot_shots"]


@ensemble_plot
@ensemble_save_or_load
def plot_shots(
Wave_object,
show=False,
file_name="1",
file_name="plot_of_shot",
shot_ids=[0],
vmin=-1e-5,
vmax=1e-5,
contour_lines=700,
Expand Down Expand Up @@ -52,7 +53,7 @@ def plot_shots(
-------
None
"""

file_name = file_name + str(shot_ids) + "." + file_format
num_recvs = Wave_object.number_of_receivers

dt = Wave_object.dt
Expand All @@ -79,7 +80,7 @@ def plot_shots(
plt.xlim(start_index, end_index)
plt.ylim(tf, 0)
plt.subplots_adjust(left=0.18, right=0.95, bottom=0.14, top=0.95)
plt.savefig(file_name + "." + file_format, format=file_format)
plt.savefig(file_name, format=file_format)
# plt.axis("image")
if show:
plt.show()
Expand Down
8 changes: 4 additions & 4 deletions test_forward_multiple_serial_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ def test_forward_3_shots():
spyro.io.switch_serial_shot(Wave_obj, i)
rec_out = Wave_obj.forward_solution_receivers
if i == 0:
rec0 = rec_out[:, 0].flatten()
rec0 = rec_out[:, 0].flatten()
elif i == 1:
rec0 = rec_out[:, 99].flatten()
rec0 = rec_out[:, 99].flatten()
elif i == 2:
rec0 = rec_out[:, 199].flatten()
rec0 = rec_out[:, 199].flatten()
plt.plot(time_vector[:cutoff], rec0[:cutoff], label="numerical")
plt.title(f"Source {i}")
plt.legend()
Expand All @@ -102,7 +102,7 @@ def test_forward_3_shots():
error /= comm.comm.size
errors.append(error)
print(f"Shot {i} produced error of {error}", flush=True)

error_all = (errors[0] + errors[1] + errors[2]) / 3
comm.comm.barrier()

Expand Down
10 changes: 5 additions & 5 deletions test_parallel/test_forward_multiple_serial_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ def test_forward_3_shots():

for i in range(Wave_obj.number_of_sources):
plt.close()
plt.plot(time_vector[:cutoff], analytical_p[:cutoff], "--",label="analyt")
plt.plot(time_vector[:cutoff], analytical_p[:cutoff], "--", label="analyt")
spyro.io.switch_serial_shot(Wave_obj, i)
rec_out = Wave_obj.forward_solution_receivers
if i == 0:
rec0 = rec_out[:, 0].flatten()
rec0 = rec_out[:, 0].flatten()
elif i == 1:
rec0 = rec_out[:, 99].flatten()
rec0 = rec_out[:, 99].flatten()
elif i == 2:
rec0 = rec_out[:, 199].flatten()
rec0 = rec_out[:, 199].flatten()
plt.plot(time_vector[:cutoff], rec0[:cutoff], label="numerical")
plt.title(f"Source {i}")
plt.legend()
Expand All @@ -102,7 +102,7 @@ def test_forward_3_shots():
error /= comm.comm.size
errors.append(error)
print(f"Shot {i} produced error of {error}", flush=True)

error_all = (errors[0] + errors[1] + errors[2]) / 3
comm.comm.barrier()

Expand Down

0 comments on commit 76be77e

Please sign in to comment.