Skip to content

Commit

Permalink
moving grad parallelism section
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 4, 2024
1 parent 76be77e commit 5c1a17f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
29 changes: 21 additions & 8 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,35 @@ def ensemble_gradient(func):
def wrapper(*args, **kwargs):
if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1:
shot_ids_per_propagation_list = args[0].shot_ids_per_propagation
_comm = args[0].comm
comm = args[0].comm
for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list):
if is_owner(_comm, propagation_id):
if is_owner(comm, propagation_id):
grad = func(*args, **kwargs)
return grad
grad_total = fire.Function(args[0].function_space)

comm.comm.barrier()
grad_total = comm.allreduce(grad, grad_total)
grad_total /= comm.ensemble_comm.size
if comm.comm.size > 1:
grad_total /= comm.comm.size

return grad_total
elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1:
num = args[0].number_of_sources
starting_time = args[0].current_time
grad_total = fire.Function(args[0].function_space)
for snum in range(num):
switch_serial_shot(args[0], snum)
args[0].reset_pressure()
args[0].current_time = starting_time
grad = func(*args, **kwargs)
# arrays_list = [obj.dat.data[:] for obj in u]
# stacked_arrays = np.stack(arrays_list, axis=0)
# np.save(f'tmp_shot{snum}.npy', stacked_arrays)
# np.save(f"tmp_rec{snum}.npy", u_r)
grad_total += grad

grad_total /= num
if comm.comm.size > 1:
grad_total /= comm.comm.size

return grad_total

return wrapper

Expand Down Expand Up @@ -220,7 +233,7 @@ def save_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0):


@ensemble_save_or_load
def load_shots(Wave_obj, file_name=None):
def load_shots(Wave_obj, file_name=None, shot_ids=0):
"""Load a `pickle` to a `numpy.ndarray`.
Parameters
Expand Down
11 changes: 2 additions & 9 deletions spyro/solvers/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,19 +382,12 @@ def get_gradient(self, c=None, save=True, calculate_functional=True):
if calculate_functional:
self.get_functional(c=c)
comm.comm.barrier()
dJ = self.gradient_solve(misfit=self.misfit, forward_solution=self.guess_forward_solution)
dJ_total = fire.Function(self.function_space)
comm.comm.barrier()
dJ_total = comm.allreduce(dJ, dJ_total)
dJ_total /= comm.ensemble_comm.size
if comm.comm.size > 1:
dJ_total /= comm.comm.size
self.gradient = dJ_total
self.gradient = self.gradient_solve(misfit=self.misfit, forward_solution=self.guess_forward_solution)
self._apply_gradient_mask()
if save and comm.comm.rank == 0:
# self.gradient_out.write(dJ_total)
output = fire.File("gradient_" + str(self.current_iteration)+".pvd")
output.write(dJ_total)
output.write(self.gradient)
print("DEBUG")
self.current_iteration += 1
comm.comm.barrier()
Expand Down

0 comments on commit 5c1a17f

Please sign in to comment.