From 05de317a87743a79165564ff2312c1b66c748bcf Mon Sep 17 00:00:00 2001 From: olender Date: Fri, 5 Jul 2024 18:45:07 -0300 Subject: [PATCH] fixed fwi test --- cleanup.sh | 1 + spyro/io/basicio.py | 15 ++++++++------- spyro/solvers/inversion.py | 2 +- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/cleanup.sh b/cleanup.sh index 815e366f..0ef3f65d 100755 --- a/cleanup.sh +++ b/cleanup.sh @@ -7,6 +7,7 @@ rm *.pvtu rm *.pvd rm *.npy rm *.pdf +rm *.dat rm results/*.vtu rm results/*.pvd rm results/*.pvtu diff --git a/spyro/io/basicio.py b/spyro/io/basicio.py index 032f5b39..52294c29 100644 --- a/spyro/io/basicio.py +++ b/spyro/io/basicio.py @@ -101,15 +101,16 @@ def wrapper(*args, **kwargs): for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list): if is_owner(comm, propagation_id): grad = func(*args, **kwargs) - grad_total = fire.Function(args[0].function_space) + grad_total = fire.Function(args[0].function_space) - comm.comm.barrier() - grad_total = comm.allreduce(grad, grad_total) - grad_total /= comm.ensemble_comm.size - if comm.comm.size > 1: - grad_total /= comm.comm.size + comm.comm.barrier() + grad_total = comm.allreduce(grad, grad_total) + grad_total /= comm.ensemble_comm.size - return grad_total + if comm.comm.size > 1: + grad_total /= comm.comm.size + + return grad_total elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: num = args[0].number_of_sources starting_time = args[0].current_time diff --git a/spyro/solvers/inversion.py b/spyro/solvers/inversion.py index 8176eb83..0ae7bcd2 100644 --- a/spyro/solvers/inversion.py +++ b/spyro/solvers/inversion.py @@ -388,7 +388,7 @@ def get_gradient(self, c=None, save=True, calculate_functional=True): # self.gradient_out.write(dJ_total) output = fire.File("gradient_" + str(self.current_iteration)+".pvd") output.write(self.gradient) - print("DEBUG") + self.current_iteration += 1 comm.comm.barrier()