Skip to content

Commit

Permalink
debugging paralellism
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Aug 6, 2024
1 parent 068e54b commit 43b6656
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
6 changes: 4 additions & 2 deletions spyro/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
parallel_print,
saving_source_and_receiver_location_in_csv,
switch_serial_shot,
ensemble_save_or_load,
ensemble_save,
ensemble_load,
delete_tmp_files,
ensemble_shot_record,
ensemble_functional
Expand Down Expand Up @@ -44,7 +45,8 @@
"boundary_layer_io",
"saving_source_and_receiver_location_in_csv",
"switch_serial_shot",
"ensemble_save_or_load",
"ensemble_save",
"ensemble_load",
"delete_tmp_files",
"ensemble_shot_record",
"ensemble_functional",
Expand Down
26 changes: 23 additions & 3 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def wrapper(*args, **kwargs):
return wrapper


def ensemble_save_or_load(func):
def ensemble_save(func):
"""Decorator for read and write shots for ensemble parallelism"""

def wrapper(*args, **kwargs):
Expand All @@ -54,6 +54,26 @@ def wrapper(*args, **kwargs):
return wrapper


def ensemble_load(func):
"""Decorator for read and write shots for ensemble parallelism"""

def wrapper(*args, **kwargs):
_comm = args[0].comm

if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1:
shot_ids_per_propagation_list = args[0].shot_ids_per_propagation
for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list):
if is_owner(_comm, propagation_id):
func(*args, **dict(kwargs, shot_ids=shot_ids_in_propagation))
elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1:
for snum in range(args[0].number_of_sources):
switch_serial_shot(args[0], snum)
if _comm.comm.rank == 0:
func(*args, **dict(kwargs, shot_ids=[snum]))

return wrapper


def ensemble_propagator(func):
"""Decorator for forward to distribute shots for ensemble parallelism"""

Expand Down Expand Up @@ -275,7 +295,7 @@ def create_segy(function, V, grid_spacing, filename):
f.trace[tr] = velocity[:, tr]


@ensemble_save_or_load
@ensemble_save
def save_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0):
"""Save a the shot record from last forward solve to a `pickle`.
Expand Down Expand Up @@ -305,7 +325,7 @@ def rebuild_empty_forward_solution(wave, time_steps):
wave.forward_solution.append(fire.Function(wave.function_space))


@ensemble_save_or_load
@ensemble_load
def load_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0):
"""Load a `pickle` to a `numpy.ndarray`.
Expand Down
4 changes: 2 additions & 2 deletions spyro/plots/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import numpy as np
import firedrake
import copy
from ..io import ensemble_save_or_load
from ..io import ensemble_save

__all__ = ["plot_shots"]


@ensemble_save_or_load
@ensemble_save
def plot_shots(
Wave_object,
show=False,
Expand Down
2 changes: 1 addition & 1 deletion spyro/solvers/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def run_fwi(self, **kwargs):

vmin = parameters["vmin"]
vmax = parameters["vmax"]
vp_0 = self.initial_velocity_model.vector().gather()
vp_0 = self.initial_velocity_model.vector()
bounds = [(vmin, vmax) for _ in range(len(vp_0))]
options = parameters["scipy_options"]

Expand Down

0 comments on commit 43b6656

Please sign in to comment.