Skip to content

Commit

Permalink
added spatially parallelizable shots in serial
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 4, 2024
1 parent 7cf61b8 commit ff13956
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 23 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
mpiexec -n 6 pytest test_parallel/test_forward_supershot.py
mpiexec -n 2 pytest test_parallel/test_parallel_io.py
mpiexec -n 3 pytest test_parallel/test_supershot_grad.py
mpiexec -n 2 pytest test_parallel/test_forward_multiple_serial_shots.py
- name: Covering parallel 3D forward test
continue-on-error: true
run: |
Expand Down Expand Up @@ -58,6 +59,11 @@ jobs:
run: |
source /home/olender/Firedrakes/newest3/firedrake/bin/activate
mpiexec -n 3 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_supershot_grad.py
- name: Covering spatially parallelized shots in serial
continue-on-error: true
run: |
source /home/olender/Firedrakes/newest3/firedrake/bin/activate
mpiexec -n 2 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_forward_multiple_serial_shots.py
# - name: Running serial tests for adjoint
# run: |
# source /home/olender/Firedrakes/main/firedrake/bin/activate
Expand Down
2 changes: 2 additions & 0 deletions spyro/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ensemble_plot,
parallel_print,
saving_source_and_receiver_location_in_csv,
switch_serial_shot,
)
from .model_parameters import Model_parameters
from .backwards_compatibility_io import Dictionary_conversion
Expand Down Expand Up @@ -41,4 +42,5 @@
"dictionaryio",
"boundary_layer_io",
"saving_source_and_receiver_location_in_csv",
"switch_serial_shot",
]
13 changes: 8 additions & 5 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,26 @@ def wrapper(*args, **kwargs):
args[0].current_time = starting_time
u, u_r = func(*args, **dict(kwargs, source_nums=[snum]))
save_serial_data(args[0], snum)
return u, u_r

return u, u_r

return wrapper


def save_serial_data(wave, propagation_id):
arrays_list = [obj.dat.data[:] for obj in wave.forward_solution]
stacked_arrays = np.stack(arrays_list, axis=0)
np.save(f'tmp_shot{propagation_id}.npy', stacked_arrays)
np.save(f"tmp_rec{propagation_id}.npy", wave.forward_solution_receivers)
spatialcomm = wave.comm.comm.rank
np.save(f'tmp_shot{propagation_id}_comm{spatialcomm}.npy', stacked_arrays)
np.save(f"tmp_rec{propagation_id}_comm{spatialcomm}.npy", wave.forward_solution_receivers)


def switch_serial_shot(wave, propagation_id):
stacked_shot_arrays = np.load(f'tmp_shot{propagation_id}.npy')
spatialcomm = wave.comm.comm.rank
stacked_shot_arrays = np.load(f'tmp_shot{propagation_id}_comm{spatialcomm}.npy')
for array_i, array in enumerate(stacked_shot_arrays):
wave.forward_solution[array_i].dat.data[:] = array
wave.forward_solution_receivers = np.load(f"tmp_rec{propagation_id}.npy")
wave.forward_solution_receivers = np.load(f"tmp_rec{propagation_id}_comm{spatialcomm}.npy")



Expand Down
49 changes: 31 additions & 18 deletions test_forward_multiple_serial_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import firedrake as fire
import spyro
import matplotlib.pyplot as plt


def error_calc(p_numerical, p_analytical, nt):
Expand All @@ -24,7 +25,7 @@ def test_forward_3_shots():
}
dictionary["parallelism"] = {
"type": "spatial", # options: automatic (same number of cores for evey processor) or spatial
"shot_ids_per_propagation": [[0], [1], [2]],
"shot_ids_per_propagation": [[0], [1]],
}
dictionary["mesh"] = {
"Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha?
Expand All @@ -35,7 +36,7 @@ def test_forward_3_shots():
}
dictionary["acquisition"] = {
"source_type": "ricker",
"source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 2),
"source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 3),
"frequency": 5.0,
"delay": 0.2,
"delay_type": "time",
Expand Down Expand Up @@ -69,8 +70,6 @@ def test_forward_3_shots():

comm = Wave_obj.comm

arr = Wave_obj.receivers_output

if comm.comm.rank == 0:
analytical_p = spyro.utils.nodal_homogeneous_analytical(
Wave_obj, 0.2, 1.5, n_extra=100
Expand All @@ -79,22 +78,36 @@ def test_forward_3_shots():
analytical_p = None
analytical_p = comm.comm.bcast(analytical_p, root=0)

# Checking if error before reflection matches
if comm.ensemble_comm.rank == 0:
rec_id = 0
elif comm.ensemble_comm.rank == 1:
rec_id = 150
elif comm.ensemble_comm.rank == 2:
rec_id = 300

arr0 = arr[:, rec_id]
arr0 = arr0.flatten()
time_vector = np.linspace(0.0, 1.0, 2001)
cutoff = 830
errors = []

for i in range(Wave_obj.number_of_sources):
plt.close()
plt.plot(time_vector[:cutoff], analytical_p[:cutoff], "--",label="analyt")
spyro.io.switch_serial_shot(Wave_obj, i)
rec_out = Wave_obj.forward_solution_receivers
if i == 0:
rec0 = rec_out[:, 0].flatten()
elif i == 1:
rec0 = rec_out[:, 99].flatten()
elif i == 2:
rec0 = rec_out[:, 199].flatten()
plt.plot(time_vector[:cutoff], rec0[:cutoff], label="numerical")
plt.title(f"Source {i}")
plt.legend()
plt.savefig(f"test{i}.png")
error_core = error_calc(rec0[:cutoff], analytical_p[:cutoff], cutoff)
error = COMM_WORLD.allreduce(error_core, op=MPI.SUM)
error /= comm.comm.size
errors.append(error)
print(f"Shot {i} produced error of {error}", flush=True)

error_all = (errors[0] + errors[1] + errors[2]) / 3
comm.comm.barrier()

error = error_calc(arr0[:430], analytical_p[:430], 430)
if comm.comm.rank == 0:
print(f"Error for shot {Wave_obj.current_sources} is {error} and test has passed equals {np.abs(error) < 0.01}", flush=True)
error_all = COMM_WORLD.allreduce(error, op=MPI.SUM)
error_all /= 3
print(f"Combined error for all shots is {error_all} and test has passed equals {np.abs(error_all) < 0.01}", flush=True)

test = np.abs(error_all) < 0.01

Expand Down
118 changes: 118 additions & 0 deletions test_parallel/test_forward_multiple_serial_shots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from mpi4py.MPI import COMM_WORLD
from mpi4py import MPI
import numpy as np
import firedrake as fire
import spyro
import matplotlib.pyplot as plt


def error_calc(p_numerical, p_analytical, nt):
norm = np.linalg.norm(p_numerical, 2) / np.sqrt(nt)
error_time = np.linalg.norm(p_analytical - p_numerical, 2) / np.sqrt(nt)
div_error_time = error_time / norm
return div_error_time


def test_forward_3_shots():
final_time = 1.0

dictionary = {}
dictionary["options"] = {
"cell_type": "Q", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q)
"variant": "lumped", # lumped, equispaced or DG, default is lumped
"degree": 4, # p order
"dimension": 2, # dimension
}
dictionary["parallelism"] = {
"type": "spatial", # options: automatic (same number of cores for evey processor) or spatial
"shot_ids_per_propagation": [[0], [1]],
}
dictionary["mesh"] = {
"Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha?
"Lx": 2.0, # width in km - always positive
"Ly": 0.0, # thickness in km - always positive
"mesh_file": None,
"mesh_type": "firedrake_mesh",
}
dictionary["acquisition"] = {
"source_type": "ricker",
"source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 3),
"frequency": 5.0,
"delay": 0.2,
"delay_type": "time",
"receiver_locations": spyro.create_transect((-0.75, 0.7), (-0.75, 1.3), 200),
}
dictionary["time_axis"] = {
"initial_time": 0.0, # Initial time for event
"final_time": final_time, # Final time for event
"dt": 0.0005, # timestep size
"amplitude": 1, # the Ricker has an amplitude of 1.
"output_frequency": 100, # how frequently to output solution to pvds - Perguntar Daiane ''post_processing_frequnecy'
"gradient_sampling_frequency": 1,
}
dictionary["visualization"] = {
"forward_output": False,
"forward_output_filename": "results/forward_output.pvd",
"fwi_velocity_model_output": False,
"velocity_model_filename": None,
"gradient_output": False,
"gradient_filename": None,
}

Wave_obj = spyro.AcousticWave(dictionary=dictionary)
Wave_obj.set_mesh(mesh_parameters={"dx": 0.1})

mesh_z = Wave_obj.mesh_z
cond = fire.conditional(mesh_z < -1.5, 3.5, 1.5)
Wave_obj.set_initial_velocity_model(conditional=cond, output=True)

Wave_obj.forward_solve()

comm = Wave_obj.comm

if comm.comm.rank == 0:
analytical_p = spyro.utils.nodal_homogeneous_analytical(
Wave_obj, 0.2, 1.5, n_extra=100
)
else:
analytical_p = None
analytical_p = comm.comm.bcast(analytical_p, root=0)

time_vector = np.linspace(0.0, 1.0, 2001)
cutoff = 830
errors = []

for i in range(Wave_obj.number_of_sources):
plt.close()
plt.plot(time_vector[:cutoff], analytical_p[:cutoff], "--",label="analyt")
spyro.io.switch_serial_shot(Wave_obj, i)
rec_out = Wave_obj.forward_solution_receivers
if i == 0:
rec0 = rec_out[:, 0].flatten()
elif i == 1:
rec0 = rec_out[:, 99].flatten()
elif i == 2:
rec0 = rec_out[:, 199].flatten()
plt.plot(time_vector[:cutoff], rec0[:cutoff], label="numerical")
plt.title(f"Source {i}")
plt.legend()
plt.savefig(f"test{i}.png")
error_core = error_calc(rec0[:cutoff], analytical_p[:cutoff], cutoff)
error = COMM_WORLD.allreduce(error_core, op=MPI.SUM)
error /= comm.comm.size
errors.append(error)
print(f"Shot {i} produced error of {error}", flush=True)

error_all = (errors[0] + errors[1] + errors[2]) / 3
comm.comm.barrier()

if comm.comm.rank == 0:
print(f"Combined error for all shots is {error_all} and test has passed equals {np.abs(error_all) < 0.01}", flush=True)

test = np.abs(error_all) < 0.01

assert test


if __name__ == "__main__":
test_forward_3_shots()

0 comments on commit ff13956

Please sign in to comment.