Skip to content

Commit

Permalink
fixed fwi test
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 5, 2024
1 parent 5c1a17f commit 05de317
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
1 change: 1 addition & 0 deletions cleanup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ rm *.pvtu
rm *.pvd
rm *.npy
rm *.pdf
rm *.dat
rm results/*.vtu
rm results/*.pvd
rm results/*.pvtu
Expand Down
15 changes: 8 additions & 7 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,16 @@ def wrapper(*args, **kwargs):
for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list):
if is_owner(comm, propagation_id):
grad = func(*args, **kwargs)
grad_total = fire.Function(args[0].function_space)
grad_total = fire.Function(args[0].function_space)

comm.comm.barrier()
grad_total = comm.allreduce(grad, grad_total)
grad_total /= comm.ensemble_comm.size
if comm.comm.size > 1:
grad_total /= comm.comm.size
comm.comm.barrier()
grad_total = comm.allreduce(grad, grad_total)
grad_total /= comm.ensemble_comm.size

return grad_total
if comm.comm.size > 1:
grad_total /= comm.comm.size

return grad_total
elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1:
num = args[0].number_of_sources
starting_time = args[0].current_time
Expand Down
2 changes: 1 addition & 1 deletion spyro/solvers/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def get_gradient(self, c=None, save=True, calculate_functional=True):
# self.gradient_out.write(dJ_total)
output = fire.File("gradient_" + str(self.current_iteration)+".pvd")
output.write(self.gradient)
print("DEBUG")

self.current_iteration += 1
comm.comm.barrier()

Expand Down

0 comments on commit 05de317

Please sign in to comment.