diff --git a/spyro/io/basicio.py b/spyro/io/basicio.py index 6ca87df9..032f5b39 100644 --- a/spyro/io/basicio.py +++ b/spyro/io/basicio.py @@ -97,22 +97,35 @@ def ensemble_gradient(func): def wrapper(*args, **kwargs): if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1: shot_ids_per_propagation_list = args[0].shot_ids_per_propagation - _comm = args[0].comm + comm = args[0].comm for propagation_id, shot_ids_in_propagation in enumerate(shot_ids_per_propagation_list): - if is_owner(_comm, propagation_id): + if is_owner(comm, propagation_id): grad = func(*args, **kwargs) - return grad + grad_total = fire.Function(args[0].function_space) + + comm.comm.barrier() + grad_total = comm.allreduce(grad, grad_total) + grad_total /= comm.ensemble_comm.size + if comm.comm.size > 1: + grad_total /= comm.comm.size + + return grad_total elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1: num = args[0].number_of_sources starting_time = args[0].current_time + grad_total = fire.Function(args[0].function_space) for snum in range(num): + switch_serial_shot(args[0], snum) args[0].reset_pressure() args[0].current_time = starting_time grad = func(*args, **kwargs) - # arrays_list = [obj.dat.data[:] for obj in u] - # stacked_arrays = np.stack(arrays_list, axis=0) - # np.save(f'tmp_shot{snum}.npy', stacked_arrays) - # np.save(f"tmp_rec{snum}.npy", u_r) + grad_total += grad + + grad_total /= num + if comm.comm.size > 1: + grad_total /= comm.comm.size + + return grad_total return wrapper @@ -220,7 +233,7 @@ def save_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0): @ensemble_save_or_load -def load_shots(Wave_obj, file_name=None): +def load_shots(Wave_obj, file_name=None, shot_ids=0): """Load a `pickle` to a `numpy.ndarray`. Parameters diff --git a/spyro/solvers/inversion.py b/spyro/solvers/inversion.py index 98b56021..8176eb83 100644 --- a/spyro/solvers/inversion.py +++ b/spyro/solvers/inversion.py @@ -382,19 +382,12 @@ def get_gradient(self, c=None, save=True, calculate_functional=True): if calculate_functional: self.get_functional(c=c) comm.comm.barrier() - dJ = self.gradient_solve(misfit=self.misfit, forward_solution=self.guess_forward_solution) - dJ_total = fire.Function(self.function_space) - comm.comm.barrier() - dJ_total = comm.allreduce(dJ, dJ_total) - dJ_total /= comm.ensemble_comm.size - if comm.comm.size > 1: - dJ_total /= comm.comm.size - self.gradient = dJ_total + self.gradient = self.gradient_solve(misfit=self.misfit, forward_solution=self.guess_forward_solution) self._apply_gradient_mask() if save and comm.comm.rank == 0: # self.gradient_out.write(dJ_total) output = fire.File("gradient_" + str(self.current_iteration)+".pvd") - output.write(dJ_total) + output.write(self.gradient) print("DEBUG") self.current_iteration += 1 comm.comm.barrier()