Skip to content

Commit

Permalink
fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 26, 2024
1 parent 497d642 commit dd2bbc6
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 9 deletions.
8 changes: 7 additions & 1 deletion spyro/examples/rectangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from spyro.examples.example_model import Example_model_acoustic
from spyro.examples.example_model import Example_model_acoustic_FWI
import firedrake as fire
import copy

rectangle_optimization_parameters = {
"General": {
Expand Down Expand Up @@ -104,6 +105,11 @@
"gradient_filename": None,
}

rectangle_dictionary_fwi = copy.deepcopy(rectangle_dictionary)
rectangle_dictionary_fwi["inversion"] = {
"perform_fwi": True, # switch to true to make a FWI
}


class Rectangle_acoustic(Example_model_acoustic):
"""
Expand Down Expand Up @@ -204,7 +210,7 @@ class Rectangle_acoustic_FWI(Example_model_acoustic_FWI):
def __init__(
self,
dictionary=None,
example_dictionary=rectangle_dictionary,
example_dictionary=rectangle_dictionary_fwi,
comm=None,
periodic=False,
):
Expand Down
3 changes: 1 addition & 2 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def wrapper(*args, **kwargs):
for snum in range(args[0].number_of_sources):
switch_serial_shot(args[0], snum)
output_list.append(func(*args, **kwargs))

return output_list

return wrapper
Expand Down Expand Up @@ -305,7 +305,6 @@ def rebuild_empty_forward_solution(wave, time_steps):
wave.forward_solution.append(fire.Function(wave.function_space))



@ensemble_save_or_load
def load_shots(Wave_obj, file_name="shots/shot_record_", shot_ids=0):
"""Load a `pickle` to a `numpy.ndarray`.
Expand Down
8 changes: 6 additions & 2 deletions spyro/tools/velocity_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import segyio
import numpy as np
import matplotlib.pyplot as plt
from SeismicMesh import write_velocity_model


def smooth_velocity_field_file(input_filename, output_filename, sigma, show=False):
def smooth_velocity_field_file(input_filename, output_filename, sigma, show=False, write_hdf5=True):
"""Smooths a velocity field using a Gaussian filter.
Parameters
Expand Down Expand Up @@ -40,7 +41,7 @@ def smooth_velocity_field_file(input_filename, output_filename, sigma, show=Fals

for i in range(ni):
for j in range(nj):
if vp[i, j] < 1.51 and i < 400:
if vp[i, j] < 1.51:
vp_smooth[i, j] = vp[i, j]

spec = segyio.spec()
Expand Down Expand Up @@ -72,4 +73,7 @@ def smooth_velocity_field_file(input_filename, output_filename, sigma, show=Fals
ax.axis("equal")
plt.show()

if write_hdf5:
write_velocity_model(output_filename, ofname=output_filename[:-5])

return None
3 changes: 1 addition & 2 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def test_read_and_write_segy():

vp.interpolate(c)

xi, yi, zi = spyro.io.write_function_to_grid(vp, V, 10.0 / 1000.0)
spyro.io.create_segy(zi, segy_file)
spyro.io.create_segy(vp, V, 10.0/1000.0, segy_file)
write_velocity_model(segy_file, vp_name)

model = {}
Expand Down
2 changes: 1 addition & 1 deletion test/test_serialshot_fwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_fwi(load_real_shot=False, use_rol=False):
FWI_obj.set_real_velocity_model(conditional=cond, output=True, dg_velocity_model=False)
FWI_obj.generate_real_shot_record(
plot_model=True,
filename="True_experiment.png",
model_filename="True_experiment.png",
abc_points=[(-0.5, 0.5), (-1.5, 0.5), (-1.5, 1.5), (-0.5, 1.5)]
)
np.save("real_shot_record", FWI_obj.real_shot_record)
Expand Down
34 changes: 33 additions & 1 deletion test_polygon_fwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# debugpy.wait_for_client()
import spyro
import numpy as np
import matplotlib.pyplot as plt


def test_real_shot_record_generation_parallel():
Expand Down Expand Up @@ -56,5 +57,36 @@ def test_real_shot_record_generation_parallel():
assert test


def test_velocity_smoother_in_fwi():
dictionary = {}
dictionary["absorving_boundary_conditions"] = {
"pad_length": 2.0, # True or false
}
dictionary["mesh"] = {
"h": 0.01, # mesh size in km
}
dictionary["polygon_options"] = {
"water_layer_is_present": True,
"upper_layer": 2.0,
"middle_layer": 2.5,
"lower_layer": 3.0,
"polygon_layer_perturbation": 0.3,
}
dictionary["acquisition"] = {
"source_locations": spyro.create_transect((-0.1, 0.1), (-0.1, 0.9), 1),
}
fwi = spyro.examples.Polygon_acoustic_FWI(dictionary=dictionary, periodic=True)
spyro.io.create_segy(
fwi.initial_velocity_model,
fwi.function_space,
10.0/1000.0,
"velocity_models/true_case1.segy",
)
spyro.tools.velocity_smoother.smooth_velocity_field_file("velocity_models/true_case1.segy", "velocity_models/case1_sigma10.segy", 10, show=True, write_hdf5=True)
plt.savefig("velocity_models/case1_sigma10.png")
plt.close()


if __name__ == "__main__":
test_real_shot_record_generation_parallel()
# test_real_shot_record_generation_parallel()
test_velocity_smoother_in_fwi()

0 comments on commit dd2bbc6

Please sign in to comment.