diff --git a/spyro/io/basicio.py b/spyro/io/basicio.py index 52294c29..9eccdc56 100644 --- a/spyro/io/basicio.py +++ b/spyro/io/basicio.py @@ -107,9 +107,6 @@ def wrapper(*args, **kwargs): grad_total = comm.allreduce(grad, grad_total) grad_total /= comm.ensemble_comm.size - if comm.comm.size > 1: - grad_total /= comm.comm.size - return grad_total elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: num = args[0].number_of_sources @@ -123,8 +120,6 @@ def wrapper(*args, **kwargs): grad_total += grad grad_total /= num - if comm.comm.size > 1: - grad_total /= comm.comm.size return grad_total @@ -253,6 +248,7 @@ def load_shots(Wave_obj, file_name=None, shot_ids=0): """ array = np.zeros(()) + file_name = file_name + str(shot_ids) + ".dat" with open(file_name, "rb") as f: array = np.asarray(pickle.load(f), dtype=float) diff --git a/test_parallel/test_supershot_grad.py b/test_parallel/test_supershot_grad.py index 9ce9e540..a2a8b07b 100644 --- a/test_parallel/test_supershot_grad.py +++ b/test_parallel/test_supershot_grad.py @@ -127,7 +127,7 @@ def get_forward_model(load_true=False): conditional=cond, # output=True ) - spyro.plots.plot_model(Wave_obj_exact, abc_points=[(-1, 1), (-2, 1), (-2, 4), (-1, 2)]) + # spyro.plots.plot_model(Wave_obj_exact, abc_points=[(-1, 1), (-2, 1), (-2, 4), (-1, 2)]) Wave_obj_exact.forward_solve() # forward_solution_exact = Wave_obj_exact.forward_solution rec_out_exact = Wave_obj_exact.receivers_output