Skip to content

Commit

Permalink
added supershot in forward problem
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jun 27, 2024
1 parent 1be6302 commit 6f32eee
Show file tree
Hide file tree
Showing 10 changed files with 342 additions and 34 deletions.
20 changes: 14 additions & 6 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,20 @@ def ensemble_propagator(func):
"""Decorator for forward to distribute shots for ensemble parallelism"""

def wrapper(*args, **kwargs):
num = args[0].number_of_sources
_comm = args[0].comm
for snum in range(num):
if is_owner(_comm, snum):
u, u_r = func(*args, **dict(kwargs, source_num=snum))
return u, u_r
if args[0].parallelism_type == "automatic":
num = args[0].number_of_sources
_comm = args[0].comm
for snum in range(num):
if is_owner(_comm, snum):
u, u_r = func(*args, **dict(kwargs, source_nums=[snum]))
return u, u_r
elif args[0].parallelism_type == "custom":
shots_per_core_list = args[0].shots_per_core
_comm = args[0].comm
for id_shots, shots_in_core in enumerate(shots_per_core_list):
if is_owner(_comm, id_shots):
u, u_r = func(*args, **dict(kwargs, source_num=shots_in_core))
return u, u_r

return wrapper

Expand Down
8 changes: 5 additions & 3 deletions spyro/io/model_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,11 @@ def _sanitize_comm(self, comm):
else:
warnings.warn("No paralellism type listed. Assuming automatic")
self.parallelism_type = "automatic"

if self.source_type == "MMS":
self.parallelism_type = "spatial"

if self.parallelism_type == "custom":
self.shots_per_core = dictionary["parallelism"]["shots_per_core"]
else:
self.shots_per_core = 1

if comm is None:
self.comm = utils.mpi_init(self)
Expand Down
6 changes: 3 additions & 3 deletions spyro/solvers/acoustic_wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def matrix_building(self):
construct_solver_or_matrix_with_pml(self)

@ensemble_propagator
def wave_propagator(self, dt=None, final_time=None, source_num=0):
def wave_propagator(self, dt=None, final_time=None, source_nums=[0]):
"""Propagates the wave forward in time.
Currently uses central differences.
Expand All @@ -109,8 +109,8 @@ def wave_propagator(self, dt=None, final_time=None, source_num=0):
if dt is not None:
self.dt = dt

self.current_source = source_num
usol, usol_recv = time_integrator(self, source_id=source_num)
self.current_sources = source_nums
usol, usol_recv = time_integrator(self, source_ids=source_nums)

return usol, usol_recv

Expand Down
16 changes: 8 additions & 8 deletions spyro/solvers/time_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,24 @@
from .time_integration_central_difference import central_difference_MMS


def time_integrator(Wave_object, source_id=0):
def time_integrator(Wave_object, source_ids=[0]):
if Wave_object.source_type == "ricker":
return time_integrator_ricker(Wave_object, source_id=source_id)
return time_integrator_ricker(Wave_object, source_ids=source_ids)
elif Wave_object.source_type == "MMS":
return time_integrator_mms(Wave_object, source_id=source_id)
return time_integrator_mms(Wave_object, source_ids=source_ids)


def time_integrator_ricker(Wave_object, source_id=0):
def time_integrator_ricker(Wave_object, source_ids=[0]):
if Wave_object.time_integrator == "central_difference":
return central_difference(Wave_object, source_id=source_id)
return central_difference(Wave_object, source_ids=source_ids)
elif Wave_object.time_integrator == "mixed_space_central_difference":
return mixed_space_central_difference(Wave_object, source_id=source_id)
return mixed_space_central_difference(Wave_object, source_ids=source_ids)
else:
raise ValueError("The time integrator specified is not implemented yet")


def time_integrator_mms(Wave_object, source_id=0):
def time_integrator_mms(Wave_object, source_ids=[0]):
if Wave_object.time_integrator == "central_difference":
return central_difference_MMS(Wave_object, source_id=source_id)
return central_difference_MMS(Wave_object, source_ids=source_ids)
else:
raise ValueError("The time integrator specified is not implemented yet")
22 changes: 11 additions & 11 deletions spyro/solvers/time_integration_central_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,30 @@
from .. import utils


def central_difference(Wave_object, source_id=0):
def central_difference(Wave_object, source_ids=[0]):
"""
Perform central difference time integration for wave propagation.
Parameters:
-----------
Wave_object: Spyro object
The Wave object containing the necessary data and parameters.
source_id: int (optional)
The ID of the source being propagated. Defaults to 0.
source_ids: list of ints (optional)
The ID of the sources being propagated. Defaults to [0].
Returns:
--------
tuple:
A tuple containing the forward solution and the receiver output.
"""
excitations = Wave_object.sources
excitations.current_source = source_id
excitations.current_sources = source_ids
receivers = Wave_object.receivers
comm = Wave_object.comm
temp_filename = Wave_object.forward_output_file

filename, file_extension = temp_filename.split(".")
output_filename = filename + "sn" + str(source_id) + "." + file_extension
output_filename = filename + "sn" + str(source_ids) + "." + file_extension
if Wave_object.forward_output:
parallel_print(f"Saving output in: {output_filename}", Wave_object.comm)

Expand Down Expand Up @@ -106,7 +106,7 @@ def central_difference(Wave_object, source_id=0):
return usol, usol_recv


def mixed_space_central_difference(Wave_object, source_id=0):
def mixed_space_central_difference(Wave_object, source_ids=[0]):
"""
Performs central difference time integration for wave propagation.
Solves for a mixed space formulation, for function X. For correctly
Expand All @@ -117,21 +117,21 @@ def mixed_space_central_difference(Wave_object, source_id=0):
-----------
Wave_object: Spyro object
The Wave object containing the necessary data and parameters.
source_id: int (optional)
The ID of the source being propagated. Defaults to 0.
source_ids: list of int (optional)
The ID of the source being propagated. Defaults to [0].
Returns:
--------
tuple:
A tuple containing the forward solution and the receiver output.
"""
excitations = Wave_object.sources
excitations.current_source = source_id
excitations.current_sources = source_ids
receivers = Wave_object.receivers
comm = Wave_object.comm
temp_filename = Wave_object.forward_output_file
filename, file_extension = temp_filename.split(".")
output_filename = filename + "sn" + str(source_id) + "." + file_extension
output_filename = filename + "sn" + str(source_ids) + "." + file_extension
if Wave_object.forward_output:
parallel_print(f"Saving output in: {output_filename}", Wave_object.comm)

Expand Down Expand Up @@ -208,7 +208,7 @@ def mixed_space_central_difference(Wave_object, source_id=0):
return usol, usol_recv


def central_difference_MMS(Wave_object, source_id=0):
def central_difference_MMS(Wave_object, source_ids=[0]):
"""Propagates the wave forward in time.
Currently uses central differences.
Expand Down
4 changes: 2 additions & 2 deletions spyro/sources/Sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, wave_object):
self.point_locations = wave_object.source_locations
self.number_of_points = wave_object.number_of_sources
self.is_local = [0] * self.number_of_points
self.current_source = None
self.current_sources = None

self.build_maps()

Expand All @@ -86,7 +86,7 @@ def apply_source(self, rhs_forcing, value):
The right hand side of the wave equation with the source applied
"""
for source_id in range(self.number_of_points):
if self.is_local[source_id] and source_id == self.current_source:
if self.is_local[source_id] and source_id in self.current_sources:
for i in range(len(self.cellNodeMaps[source_id])):
rhs_forcing.dat.data_with_halos[
int(self.cellNodeMaps[source_id][i])
Expand Down
4 changes: 3 additions & 1 deletion spyro/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def mpi_init(model):
elif model.parallelism_type == "spatial":
num_cores_per_shot = available_cores
elif model.parallelism_type == "custom":
raise ValueError("Custom parallelism not yet implemented")
shots_per_core = model.shots_per_core
num_max_shots_per_core = max(len(sublist) for sublist in shots_per_core)
num_cores_per_shot = len(shots_per_core)

comm_ens = Ensemble(COMM_WORLD, num_cores_per_shot) # noqa: F405
return comm_ens
Expand Down
81 changes: 81 additions & 0 deletions temp_forward_shot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# from mpi4py.MPI import COMM_WORLD
# import debugpy
# debugpy.listen(3000 + COMM_WORLD.rank)
# debugpy.wait_for_client()
import spyro
import numpy as np
import math


def run_forward(dt):
# dt = float(sys.argv[1])

final_time = 1.0

dictionary = {}
dictionary["options"] = {
"cell_type": "Q", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q)
"variant": "lumped", # lumped, equispaced or DG, default is lumped "method":"MLT", # (MLT/spectral_quadrilateral/DG_triangle/DG_quadrilateral) You can either specify a cell_type+variant or a method
"degree": 4, # p order
"dimension": 2, # dimension
}

# Number of cores for the shot. For simplicity, we keep things serial.
# spyro however supports both spatial parallelism and "shot" parallelism.
dictionary["parallelism"] = {
"type": "custom", # options: automatic (same number of cores for evey processor) or spatial
"seperate_shots" : True,
"shots_per_core": [[0, 1]],
}

# Define the domain size without the PML. Here we'll assume a 1.00 x 1.00 km
# domain and reserve the remaining 250 m for the Perfectly Matched Layer (PML) to absorb
# outgoing waves on three sides (eg., -z, +-x sides) of the domain.
dictionary["mesh"] = {
"Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha?
"Lx": 2.0, # width in km - always positive
"Ly": 0.0, # thickness in km - always positive
"mesh_file": None,
"mesh_type": "firedrake_mesh",
}
dictionary["acquisition"] = {
"source_type": "ricker",
"source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 2),
"frequency": 5.0,
"delay": 0.2,
"delay_type": "time",
"receiver_locations": spyro.create_transect((-1.45, 0.7), (-1.45, 1.3), 200),
}

# Simulate for 2.0 seconds.
dictionary["time_axis"] = {
"initial_time": 0.0, # Initial time for event
"final_time": final_time, # Final time for event
"dt": dt, # timestep size
"amplitude": 1, # the Ricker has an amplitude of 1.
"output_frequency": 100, # how frequently to output solution to pvds
"gradient_sampling_frequency": 100, # how frequently to save solution to RAM
}

dictionary["visualization"] = {
"forward_output": True,
"forward_output_filename": "results/forward_output.pvd",
"fwi_velocity_model_output": False,
"velocity_model_filename": None,
"gradient_output": False,
"gradient_filename": None,
}

Wave_obj = spyro.AcousticWave(dictionary=dictionary)
Wave_obj.set_mesh(mesh_parameters={"dx": 0.02, "periodic": True})

Wave_obj.set_initial_velocity_model(constant=1.5)
Wave_obj.forward_solve()

rec_out = Wave_obj.receivers_output

return rec_out


if __name__ == "__main__":
run_forward(0.0005)
Loading

0 comments on commit 6f32eee

Please sign in to comment.