Skip to content

Commit

Permalink
finished supershot forward and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 1, 2024
1 parent 8c3ffca commit 76125b7
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 22 deletions.
21 changes: 7 additions & 14 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,13 @@ def ensemble_propagator(func):
"""Decorator for forward to distribute shots for ensemble parallelism"""

def wrapper(*args, **kwargs):
if args[0].parallelism_type == "automatic":
num = args[0].number_of_sources
_comm = args[0].comm
for snum in range(num):
if is_owner(_comm, snum):
u, u_r = func(*args, **dict(kwargs, source_nums=[snum]))
return u, u_r
elif args[0].parallelism_type == "custom":
shot_ids_per_propagation_list = args[0].shot_ids_per_propagation
_comm = args[0].comm
for shot_ids_in_propagation in shot_ids_per_propagation_list:
if is_owner(_comm, shot_ids_in_propagation):
u, u_r = func(*args, **dict(kwargs, source_nums=shot_ids_in_propagation))
return u, u_r
# if args[0].parallelism_type == "custom":
shot_ids_per_propagation_list = args[0].shot_ids_per_propagation
_comm = args[0].comm
for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list):
if is_owner(_comm, propagation_id):
u, u_r = func(*args, **dict(kwargs, source_nums=shot_ids_in_propagation))
return u, u_r

return wrapper

Expand Down
4 changes: 1 addition & 3 deletions spyro/io/model_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,9 +549,7 @@ def _sanitize_comm(self, comm):
else:
shot_ids_per_propagation = []
available_cores = COMM_WORLD.size
num_cores_per_propagation = available_cores / self.number_of_sources
for shot in range(self.number_of_sources):

self.shot_ids_per_propagation = [[i] for i in range(0, available_cores)]

if comm is None:
self.comm = utils.mpi_init(self)
Expand Down
3 changes: 2 additions & 1 deletion spyro/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def mpi_init(model):
elif model.parallelism_type == "custom":
shot_ids_per_propagation = model.shot_ids_per_propagation
num_max_shots_per_core = max(len(sublist) for sublist in shot_ids_per_propagation)
num_cores_per_propagation = len(shot_ids_per_propagation)
num_propagations = len(shot_ids_per_propagation)
num_cores_per_propagation = available_cores / num_propagations

comm_ens = Ensemble(COMM_WORLD, num_cores_per_propagation) # noqa: F405
return comm_ens
Expand Down
8 changes: 4 additions & 4 deletions temp_forward_shot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from mpi4py.MPI import COMM_WORLD
import debugpy
debugpy.listen(3000 + COMM_WORLD.rank)
debugpy.wait_for_client()
# from mpi4py.MPI import COMM_WORLD
# import debugpy
# debugpy.listen(3000 + COMM_WORLD.rank)
# debugpy.wait_for_client()
import spyro
import numpy as np
import math
Expand Down
110 changes: 110 additions & 0 deletions test_forward_supershot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from mpi4py.MPI import COMM_WORLD
from mpi4py import MPI
# import debugpy
# debugpy.listen(3000 + COMM_WORLD.rank)
# debugpy.wait_for_client()
import spyro
import numpy as np
import math


def error_calc(p_numerical, p_analytical, nt):
norm = np.linalg.norm(p_numerical, 2) / np.sqrt(nt)
error_time = np.linalg.norm(p_analytical - p_numerical, 2) / np.sqrt(nt)
div_error_time = error_time / norm
return div_error_time


def test_forward_supershot():
dt = 0.0005

final_time = 1.0

dictionary = {}
dictionary["options"] = {
"cell_type": "Q", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q)
"variant": "lumped", # lumped, equispaced or DG, default is lumped "method":"MLT", # (MLT/spectral_quadrilateral/DG_triangle/DG_quadrilateral) You can either specify a cell_type+variant or a method
"degree": 4, # p order
"dimension": 2, # dimension
}

# Number of cores for the shot. For simplicity, we keep things serial.
# spyro however supports both spatial parallelism and "shot" parallelism.
dictionary["parallelism"] = {
"type": "custom", # options: automatic (same number of cores for evey processor) or spatial
"shot_ids_per_propagation": [[0, 1, 2]],
}

# Define the domain size without the PML. Here we'll assume a 1.00 x 1.00 km
# domain and reserve the remaining 250 m for the Perfectly Matched Layer (PML) to absorb
# outgoing waves on three sides (eg., -z, +-x sides) of the domain.
dictionary["mesh"] = {
"Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha?
"Lx": 2.0, # width in km - always positive
"Ly": 0.0, # thickness in km - always positive
"mesh_file": None,
"mesh_type": "firedrake_mesh",
}
dictionary["acquisition"] = {
"source_type": "ricker",
"source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 2),
"frequency": 5.0,
"delay": 0.2,
"delay_type": "time",
"receiver_locations": spyro.create_transect((-0.55, 0.5), (-0.55, 1.5), 200),
}

# Simulate for 2.0 seconds.
dictionary["time_axis"] = {
"initial_time": 0.0, # Initial time for event
"final_time": final_time, # Final time for event
"dt": dt, # timestep size
"amplitude": 1, # the Ricker has an amplitude of 1.
"output_frequency": 100, # how frequently to output solution to pvds
"gradient_sampling_frequency": 100, # how frequently to save solution to RAM
}

dictionary["visualization"] = {
"forward_output": True,
"forward_output_filename": "results/forward_output.pvd",
"fwi_velocity_model_output": False,
"velocity_model_filename": None,
"gradient_output": False,
"gradient_filename": None,
}

Wave_obj = spyro.AcousticWave(dictionary=dictionary)
Wave_obj.set_mesh(mesh_parameters={"dx": 0.02, "periodic": True})

Wave_obj.set_initial_velocity_model(constant=1.5)
Wave_obj.forward_solve()
comm = Wave_obj.comm

rec_out = Wave_obj.receivers_output
if comm.comm.rank == 0:
analytical_p = spyro.utils.nodal_homogeneous_analytical(Wave_obj, 0.2, 1.5, n_extra=100)
else:
analytical_p = None

analytical_p = comm.comm.bcast(analytical_p, root=0)

arr0 = rec_out[:, 0]
arr0 = arr0.flatten()
arr199 = rec_out[:, 199]
arr199 = arr199.flatten()

error0 = error_calc(arr0[:430], analytical_p[:430], 430)
error199 = error_calc(arr199[:430], analytical_p[:430], 430)
error = error0 + error199
error_all = COMM_WORLD.allreduce(error, op=MPI.SUM)
error_all /= 2
comm.comm.barrier()

if comm.comm.rank == 0:
print(f"Combined error for shots {Wave_obj.current_sources} is {error_all} and test has passed equals {np.abs(error_all) < 0.01}", flush=True)

return rec_out


if __name__ == "__main__":
test_forward_supershot()
110 changes: 110 additions & 0 deletions test_parallel/test_forward_supershot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from mpi4py.MPI import COMM_WORLD
from mpi4py import MPI
# import debugpy
# debugpy.listen(3000 + COMM_WORLD.rank)
# debugpy.wait_for_client()
import spyro
import numpy as np
import math


def error_calc(p_numerical, p_analytical, nt):
norm = np.linalg.norm(p_numerical, 2) / np.sqrt(nt)
error_time = np.linalg.norm(p_analytical - p_numerical, 2) / np.sqrt(nt)
div_error_time = error_time / norm
return div_error_time


def test_forward_supershot():
dt = 0.0005

final_time = 1.0

dictionary = {}
dictionary["options"] = {
"cell_type": "Q", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q)
"variant": "lumped", # lumped, equispaced or DG, default is lumped "method":"MLT", # (MLT/spectral_quadrilateral/DG_triangle/DG_quadrilateral) You can either specify a cell_type+variant or a method
"degree": 4, # p order
"dimension": 2, # dimension
}

# Number of cores for the shot. For simplicity, we keep things serial.
# spyro however supports both spatial parallelism and "shot" parallelism.
dictionary["parallelism"] = {
"type": "custom", # options: automatic (same number of cores for evey processor) or spatial
"shot_ids_per_propagation": [[0, 1, 2]],
}

# Define the domain size without the PML. Here we'll assume a 1.00 x 1.00 km
# domain and reserve the remaining 250 m for the Perfectly Matched Layer (PML) to absorb
# outgoing waves on three sides (eg., -z, +-x sides) of the domain.
dictionary["mesh"] = {
"Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha?
"Lx": 2.0, # width in km - always positive
"Ly": 0.0, # thickness in km - always positive
"mesh_file": None,
"mesh_type": "firedrake_mesh",
}
dictionary["acquisition"] = {
"source_type": "ricker",
"source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 2),
"frequency": 5.0,
"delay": 0.2,
"delay_type": "time",
"receiver_locations": spyro.create_transect((-0.55, 0.5), (-0.55, 1.5), 200),
}

# Simulate for 2.0 seconds.
dictionary["time_axis"] = {
"initial_time": 0.0, # Initial time for event
"final_time": final_time, # Final time for event
"dt": dt, # timestep size
"amplitude": 1, # the Ricker has an amplitude of 1.
"output_frequency": 100, # how frequently to output solution to pvds
"gradient_sampling_frequency": 100, # how frequently to save solution to RAM
}

dictionary["visualization"] = {
"forward_output": True,
"forward_output_filename": "results/forward_output.pvd",
"fwi_velocity_model_output": False,
"velocity_model_filename": None,
"gradient_output": False,
"gradient_filename": None,
}

Wave_obj = spyro.AcousticWave(dictionary=dictionary)
Wave_obj.set_mesh(mesh_parameters={"dx": 0.02, "periodic": True})

Wave_obj.set_initial_velocity_model(constant=1.5)
Wave_obj.forward_solve()
comm = Wave_obj.comm

rec_out = Wave_obj.receivers_output
if comm.comm.rank == 0:
analytical_p = spyro.utils.nodal_homogeneous_analytical(Wave_obj, 0.2, 1.5, n_extra=100)
else:
analytical_p = None

analytical_p = comm.comm.bcast(analytical_p, root=0)

arr0 = rec_out[:, 0]
arr0 = arr0.flatten()
arr199 = rec_out[:, 199]
arr199 = arr199.flatten()

error0 = error_calc(arr0[:430], analytical_p[:430], 430)
error199 = error_calc(arr199[:430], analytical_p[:430], 430)
error = error0 + error199
error_all = COMM_WORLD.allreduce(error, op=MPI.SUM)
error_all /= 2
comm.comm.barrier()

if comm.comm.rank == 0:
print(f"Combined error for shots {Wave_obj.current_sources} is {error_all} and test has passed equals {np.abs(error_all) < 0.01}", flush=True)

return rec_out


if __name__ == "__main__":
test_forward_supershot()

0 comments on commit 76125b7

Please sign in to comment.