Skip to content

Commit

Permalink
added fwi with serialshots
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 17, 2024
1 parent d0c47d9 commit 33f53e4
Show file tree
Hide file tree
Showing 6 changed files with 359 additions and 11 deletions.
4 changes: 4 additions & 0 deletions spyro/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
switch_serial_shot,
ensemble_save_or_load,
delete_tmp_files,
ensemble_shot_record,
ensemble_functional
)
from .model_parameters import Model_parameters
from .backwards_compatibility_io import Dictionary_conversion
Expand Down Expand Up @@ -44,4 +46,6 @@
"switch_serial_shot",
"ensemble_save_or_load",
"delete_tmp_files",
"ensemble_shot_record",
"ensemble_functional",
]
48 changes: 47 additions & 1 deletion spyro/io/basicio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import with_statement

import pickle

from mpi4py import MPI
import firedrake as fire
import h5py
import numpy as np
Expand All @@ -19,6 +19,21 @@ def delete_tmp_files(wave):
os.remove(file)


def ensemble_shot_record(func):
"""Decorator for read and write shots for ensemble parallelism"""

def wrapper(*args, **kwargs):
if args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1:
output_list = []
for snum in range(args[0].number_of_sources):
switch_serial_shot(args[0], snum)
output_list.append(func(*args, **kwargs))

return output_list

return wrapper


def ensemble_save_or_load(func):
"""Decorator for read and write shots for ensemble parallelism"""

Expand Down Expand Up @@ -106,6 +121,37 @@ def switch_serial_shot(wave, propagation_id):
wave.receivers_output = wave.forward_solution_receivers


def ensemble_functional(func):
"""Decorator for gradient to distribute shots for ensemble parallelism"""

def wrapper(*args, **kwargs):
comm = args[0].comm
if args[0].parallelism_type != "spatial" or args[0].number_of_sources == 1:
J = func(*args, **kwargs)
J_total = np.zeros((1))
J_total[0] += J
J_total = fire.COMM_WORLD.allreduce(J_total, op=MPI.SUM)
J_total[0] /= comm.comm.size

elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1:
num = args[0].number_of_sources
residual_list = args[1]
J_total = np.zeros((1))

for snum in range(args[0].number_of_sources):
switch_serial_shot(args[0], snum)
current_residual = residual_list[snum]
J = func(args[0], current_residual)
J_total += J
J_total[0] /= comm.comm.size

comm.comm.barrier()

return J_total[0]

return wrapper


def ensemble_gradient(func):
"""Decorator for gradient to distribute shots for ensemble parallelism"""

Expand Down
29 changes: 24 additions & 5 deletions spyro/solvers/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from ..utils import compute_functional
from ..utils import Gradient_mask_for_pml, Mask
from ..plots import plot_model as spyro_plot_model
from ..io.basicio import ensemble_shot_record
from ..io.basicio import switch_serial_shot


try:
from ROL.firedrake_vector import FiredrakeVector as FireVector
Expand Down Expand Up @@ -181,10 +184,19 @@ def calculate_misfit(self, c=None):
self.forward_solve()
output = fire.File("control_" + str(self.current_iteration)+".pvd")
output.write(self.c)
self.guess_shot_record = self.forward_solution_receivers
self.guess_forward_solution = self.forward_solution

self.misfit = self.real_shot_record - self.guess_shot_record
if self.parallelism_type == "spatial" and self.number_of_sources > 1:
misfit_list = []
guess_shot_record_list = []
for snum in range (self.number_of_sources):
switch_serial_shot(self, snum)
guess_shot_record_list.append(self.forward_solution_receivers)
misfit_list.append(self.real_shot_record[snum] - self.forward_solution_receivers)
self.guess_shot_record = guess_shot_record_list
self.misfit = misfit_list
else:
self.guess_shot_record = self.forward_solution_receivers
self.guess_forward_solution = self.forward_solution
self.misfit = self.real_shot_record - self.guess_shot_record
return self.misfit

def generate_real_shot_record(self, plot_model=False, filename=None, abc_points=None):
Expand Down Expand Up @@ -572,4 +584,11 @@ def __init__(self, dictionary=None, comm=None):

def forward_solve(self):
super().forward_solve()
self.real_shot_record = self.receivers_output
if self.parallelism_type == "spatial" and self.number_of_sources > 1:
real_shot_record_list = []
for snum in range (self.number_of_sources):
switch_serial_shot(self, snum)
real_shot_record_list.append(self.receivers_output)
self.real_shot_record = real_shot_record_list
else:
self.real_shot_record = self.receivers_output
8 changes: 3 additions & 5 deletions spyro/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mpi4py import MPI
from scipy.signal import butter, filtfilt
import warnings
from ..io import ensemble_functional


def butter_lowpass_filter(shot, cutoff, fs, order=2):
Expand Down Expand Up @@ -37,6 +38,7 @@ def butter_lowpass_filter(shot, cutoff, fs, order=2):
return filtered_shot


@ensemble_functional
def compute_functional(Wave_object, residual):
"""Compute the functional to be optimized.
Accepts the velocity optionally and uses
Expand All @@ -52,11 +54,7 @@ def compute_functional(Wave_object, residual):

J *= 0.5

J_total = np.zeros((1))
J_total[0] += J
J_total = COMM_WORLD.allreduce(J_total, op=MPI.SUM)
J_total[0] /= comm.comm.size
return J_total[0]
return J


def evaluate_misfit(model, guess, exact):
Expand Down
138 changes: 138 additions & 0 deletions temp_serialshot_fwi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import numpy as np
import firedrake as fire
import spyro
import pytest
import warnings


warnings.filterwarnings("ignore")


def is_rol_installed():
try:
import ROL
return True
except ImportError:
return False


final_time = 0.9

dictionary = {}
dictionary["options"] = {
"cell_type": "T", # simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q)
"variant": "lumped", # lumped, equispaced or DG, default is lumped
"degree": 4, # p order
"dimension": 2, # dimension
}
dictionary["parallelism"] = {
"type": "spatial", # options: automatic (same number of cores for evey processor) or spatial
}
dictionary["mesh"] = {
"Lz": 2.0, # depth in km - always positive # Como ver isso sem ler a malha?
"Lx": 2.0, # width in km - always positive
"Ly": 0.0, # thickness in km - always positive
"mesh_file": None,
"mesh_type": "firedrake_mesh",
}
dictionary["acquisition"] = {
"source_type": "ricker",
"source_locations": spyro.create_transect((-0.55, 0.7), (-0.55, 1.3), 6),
"frequency": 5.0,
"delay": 0.2,
"delay_type": "time",
"receiver_locations": spyro.create_transect((-1.45, 0.7), (-1.45, 1.3), 200),
}
dictionary["time_axis"] = {
"initial_time": 0.0, # Initial time for event
"final_time": final_time, # Final time for event
"dt": 0.001, # timestep size
"amplitude": 1, # the Ricker has an amplitude of 1.
"output_frequency": 100, # how frequently to output solution to pvds - Perguntar Daiane ''post_processing_frequnecy'
"gradient_sampling_frequency": 1, # how frequently to save solution to RAM - Perguntar Daiane 'gradient_sampling_frequency'
}
dictionary["visualization"] = {
"forward_output": False,
"forward_output_filename": "results/forward_output.pvd",
"fwi_velocity_model_output": False,
"velocity_model_filename": None,
"gradient_output": False,
"gradient_filename": "results/Gradient.pvd",
"adjoint_output": False,
"adjoint_filename": None,
"debug_output": False,
}
dictionary["inversion"] = {
"perform_fwi": True, # switch to true to make a FWI
"initial_guess_model_file": None,
"shot_record_file": None,
}


def test_fwi(load_real_shot=False, use_rol=False):
"""
Run the Full Waveform Inversion (FWI) test.
Parameters
----------
load_real_shot (bool, optional): Whether to load a real shot record or not. Defaults to False.
"""

# Setting up to run synthetic real problem
FWI_obj = spyro.FullWaveformInversion(dictionary=dictionary)

FWI_obj.set_real_mesh(mesh_parameters={"dx": 0.1})
center_z = -1.0
center_x = 1.0
mesh_z = FWI_obj.mesh_z
mesh_x = FWI_obj.mesh_x
cond = fire.conditional((mesh_z-center_z)**2 + (mesh_x-center_x)**2 < .2**2, 3.0, 2.5)

FWI_obj.set_real_velocity_model(conditional=cond, output=True, dg_velocity_model=False)
FWI_obj.generate_real_shot_record(
plot_model=True,
filename="True_experiment.png",
abc_points=[(-0.5, 0.5), (-1.5, 0.5), (-1.5, 1.5), (-0.5, 1.5)]
)
np.save("real_shot_record", FWI_obj.real_shot_record)

# Setting up initial guess problem
FWI_obj.set_guess_mesh(mesh_parameters={"dx": 0.1})
FWI_obj.set_guess_velocity_model(constant=2.5)
mask_boundaries = {
"z_min": -1.3,
"z_max": -0.7,
"x_min": 0.7,
"x_max": 1.3,
}
FWI_obj.set_gradient_mask(boundaries=mask_boundaries)
if use_rol:
FWI_obj.run_fwi_rol(vmin=2.5, vmax=3.0, maxiter=2)
else:
FWI_obj.run_fwi(vmin=2.5, vmax=3.0, maxiter=5)

# simple mask test
grad_test = FWI_obj.gradient
test0 = np.isclose(grad_test.at((-0.1, 0.1)), 0.0)
print(f"PML looks masked: {test0}", flush=True)
test1 = np.abs(grad_test.at((-1.0, 1.0))) > 1e-5
print(f"Center looks unmasked: {test1}", flush=True)

# quick look at functional and if it reduced
test2 = FWI_obj.functional < 1e-3
print(f"Last functional small: {test2}", flush=True)
test3 = FWI_obj.functional_history[-1]/FWI_obj.functional_history[0] < 1e-2
print(f"Considerable functional reduction during test: {test3}", flush=True)

print("END", flush=True)
assert all([test0, test1, test2, test3])


@pytest.mark.skipif(not is_rol_installed(), reason="ROL is not installed")
def test_fwi_with_rol(load_real_shot=False, use_rol=True):
test_fwi(load_real_shot=load_real_shot, use_rol=use_rol)


if __name__ == "__main__":
test_fwi(load_real_shot=False)
test_fwi_with_rol()
Loading

0 comments on commit 33f53e4

Please sign in to comment.