Skip to content

Commit

Permalink
starting serial shots addition
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 2, 2024
1 parent 31d9243 commit 9dab73e
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 14 deletions.
26 changes: 19 additions & 7 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ def ensemble_propagator(func):
"""Decorator for forward to distribute shots for ensemble parallelism"""

def wrapper(*args, **kwargs):
if args[0].parallelism_type != "serial":
if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1:
shot_ids_per_propagation_list = args[0].shot_ids_per_propagation
_comm = args[0].comm
for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list):
if is_owner(_comm, propagation_id):
u, u_r = func(*args, **dict(kwargs, source_nums=shot_ids_in_propagation))
return u, u_r
elif args[0].parallelism_type == "serial":
elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1:
num = args[0].number_of_sources
starting_time = args[0].current_time
for snum in range(num):
Expand All @@ -78,12 +78,24 @@ def ensemble_gradient(func):
"""Decorator for gradient to distribute shots for ensemble parallelism"""

def wrapper(*args, **kwargs):
num = args[0].number_of_sources
_comm = args[0].comm
for snum in range(num):
if is_owner(_comm, snum):
if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1:
shot_ids_per_propagation_list = args[0].shot_ids_per_propagation
_comm = args[0].comm
for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list):
if is_owner(_comm, propagation_id):
grad = func(*args, **kwargs)
return grad
elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1:
num = args[0].number_of_sources
starting_time = args[0].current_time
for snum in range(num):
args[0].reset_pressure()
args[0].current_time = starting_time
grad = func(*args, **kwargs)
return grad
# arrays_list = [obj.dat.data[:] for obj in u]
# stacked_arrays = np.stack(arrays_list, axis=0)
# np.save(f'tmp_shot{snum}.npy', stacked_arrays)
# np.save(f"tmp_rec{snum}.npy", u_r)

return wrapper

Expand Down
5 changes: 3 additions & 2 deletions spyro/io/model_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,10 +547,11 @@ def _sanitize_comm(self, comm):

if self.parallelism_type == "custom":
self.shot_ids_per_propagation = dictionary["parallelism"]["shot_ids_per_propagation"]
else:
shot_ids_per_propagation = []
elif self.parallelism_type == "automatic":
available_cores = COMM_WORLD.size
self.shot_ids_per_propagation = [[i] for i in range(0, available_cores)]
elif self.parallelism_type == "spatial":
self.shot_ids_per_propagation = [[i] for i in range(0, self.number_of_sources)]

if comm is None:
self.comm = utils.mpi_init(self)
Expand Down
1 change: 0 additions & 1 deletion spyro/solvers/acoustic_wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,3 @@ def reset_pressure(self):
self.X_nm1.assign(0.0)
except:
warnings.warn("No mixed space pressure to reset")

8 changes: 4 additions & 4 deletions temp_forward_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def run_forward(dt):
# dt = float(sys.argv[1])

final_time = 1.0
final_time = 0.8

dictionary = {}
dictionary["options"] = {
Expand All @@ -23,8 +23,8 @@ def run_forward(dt):
# Number of cores for the shot. For simplicity, we keep things serial.
# spyro however supports both spatial parallelism and "shot" parallelism.
dictionary["parallelism"] = {
"type": "custom", # options: automatic (same number of cores for evey processor) or spatial
"shot_ids_per_propagation": [[0, 1, 2]],
"type": "spatial", # options: automatic (same number of cores for evey processor) or spatial
"shot_ids_per_propagation": [[0], [1]],
}

# Define the domain size without the PML. Here we'll assume a 1.00 x 1.00 km
Expand All @@ -39,7 +39,7 @@ def run_forward(dt):
}
dictionary["acquisition"] = {
"source_type": "ricker",
"source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 3),
"source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 2),
"frequency": 5.0,
"delay": 0.2,
"delay_type": "time",
Expand Down
105 changes: 105 additions & 0 deletions test_forward_multiple_serial_shots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from mpi4py.MPI import COMM_WORLD
from mpi4py import MPI
import numpy as np
import firedrake as fire
import spyro


def error_calc(p_numerical, p_analytical, nt):
norm = np.linalg.norm(p_numerical, 2) / np.sqrt(nt)
error_time = np.linalg.norm(p_analytical - p_numerical, 2) / np.sqrt(nt)
div_error_time = error_time / norm
return div_error_time


def test_forward_3_shots():
final_time = 1.0

dictionary = {}
dictionary["options"] = {
"cell_type": "Q", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q)
"variant": "lumped", # lumped, equispaced or DG, default is lumped
"degree": 4, # p order
"dimension": 2, # dimension
}
dictionary["parallelism"] = {
"type": "spatial", # options: automatic (same number of cores for evey processor) or spatial
"shot_ids_per_propagation": [[0], [1], [2]],
}
dictionary["mesh"] = {
"Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha?
"Lx": 2.0, # width in km - always positive
"Ly": 0.0, # thickness in km - always positive
"mesh_file": None,
"mesh_type": "firedrake_mesh",
}
dictionary["acquisition"] = {
"source_type": "ricker",
"source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 2),
"frequency": 5.0,
"delay": 0.2,
"delay_type": "time",
"receiver_locations": spyro.create_transect((-0.75, 0.7), (-0.75, 1.3), 200),
}
dictionary["time_axis"] = {
"initial_time": 0.0, # Initial time for event
"final_time": final_time, # Final time for event
"dt": 0.0005, # timestep size
"amplitude": 1, # the Ricker has an amplitude of 1.
"output_frequency": 100, # how frequently to output solution to pvds - Perguntar Daiane ''post_processing_frequnecy'
"gradient_sampling_frequency": 1,
}
dictionary["visualization"] = {
"forward_output": False,
"forward_output_filename": "results/forward_output.pvd",
"fwi_velocity_model_output": False,
"velocity_model_filename": None,
"gradient_output": False,
"gradient_filename": None,
}

Wave_obj = spyro.AcousticWave(dictionary=dictionary)
Wave_obj.set_mesh(mesh_parameters={"dx": 0.1})

mesh_z = Wave_obj.mesh_z
cond = fire.conditional(mesh_z < -1.5, 3.5, 1.5)
Wave_obj.set_initial_velocity_model(conditional=cond, output=True)

Wave_obj.forward_solve()

comm = Wave_obj.comm

arr = Wave_obj.receivers_output

if comm.comm.rank == 0:
analytical_p = spyro.utils.nodal_homogeneous_analytical(
Wave_obj, 0.2, 1.5, n_extra=100
)
else:
analytical_p = None
analytical_p = comm.comm.bcast(analytical_p, root=0)

# Checking if error before reflection matches
if comm.ensemble_comm.rank == 0:
rec_id = 0
elif comm.ensemble_comm.rank == 1:
rec_id = 150
elif comm.ensemble_comm.rank == 2:
rec_id = 300

arr0 = arr[:, rec_id]
arr0 = arr0.flatten()

error = error_calc(arr0[:430], analytical_p[:430], 430)
if comm.comm.rank == 0:
print(f"Error for shot {Wave_obj.current_sources} is {error} and test has passed equals {np.abs(error) < 0.01}", flush=True)
error_all = COMM_WORLD.allreduce(error, op=MPI.SUM)
error_all /= 3

test = np.abs(error_all) < 0.01

assert test


if __name__ == "__main__":
test_forward_3_shots()

0 comments on commit 9dab73e

Please sign in to comment.