Skip to content

Commit

Permalink
fixing minor bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 8, 2024
1 parent 05de317 commit 0fdcf3c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
6 changes: 1 addition & 5 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,6 @@ def wrapper(*args, **kwargs):
grad_total = comm.allreduce(grad, grad_total)
grad_total /= comm.ensemble_comm.size

if comm.comm.size > 1:
grad_total /= comm.comm.size

return grad_total
elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1:
num = args[0].number_of_sources
Expand All @@ -123,8 +120,6 @@ def wrapper(*args, **kwargs):
grad_total += grad

grad_total /= num
if comm.comm.size > 1:
grad_total /= comm.comm.size

return grad_total

Expand Down Expand Up @@ -253,6 +248,7 @@ def load_shots(Wave_obj, file_name=None, shot_ids=0):
"""
array = np.zeros(())
file_name = file_name + str(shot_ids) + ".dat"

with open(file_name, "rb") as f:
array = np.asarray(pickle.load(f), dtype=float)
Expand Down
2 changes: 1 addition & 1 deletion test_parallel/test_supershot_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_forward_model(load_true=False):
conditional=cond,
# output=True
)
spyro.plots.plot_model(Wave_obj_exact, abc_points=[(-1, 1), (-2, 1), (-2, 4), (-1, 2)])
# spyro.plots.plot_model(Wave_obj_exact, abc_points=[(-1, 1), (-2, 1), (-2, 4), (-1, 2)])
Wave_obj_exact.forward_solve()
# forward_solution_exact = Wave_obj_exact.forward_solution
rec_out_exact = Wave_obj_exact.receivers_output
Expand Down

0 comments on commit 0fdcf3c

Please sign in to comment.