Skip to content

Commit

Permalink
Adding io test
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 2, 2024
1 parent dc06177 commit bcdb07d
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 20 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
mpiexec -n 6 pytest test_parallel/test_forward.py
mpiexec -n 6 pytest test_parallel/test_fwi.py
mpiexec -n 6 pytest test_parallel/test_forward_supershot.py
mpiexec -n 2 pytest test_parallel/test_parallel_io.py
- name: Covering parallel 3D forward test
continue-on-error: true
run: |
Expand All @@ -46,6 +47,11 @@ jobs:
run: |
source /home/olender/Firedrakes/newest3/firedrake/bin/activate
mpiexec -n 6 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_forward_supershot.py
- name: Covering parallel io test
continue-on-error: true
run: |
source /home/olender/Firedrakes/newest3/firedrake/bin/activate
mpiexec -n 2 pytest --cov-report=xml --cov-append --cov=spyro test_parallel/test_parallel_io.py
# - name: Running serial tests for adjoint
# run: |
# source /home/olender/Firedrakes/main/firedrake/bin/activate
Expand Down
29 changes: 29 additions & 0 deletions test_parallel/test_parallel_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import spyro


def test_saving_and_loading_shot_record():
from test.inputfiles.model import dictionary

dictionary["parallelism"]["type"] = "custom"
dictionary["parallelism"]["shot_ids_per_propagation"] = [[0, 1]]
dictionary["time_axis"]["final_time"] = 0.5
dictionary["acquisition"]["source_locations"] = [(-0.5, 0.4), (-0.5, 0.6)]
dictionary["acquisition"]["receiver_locations"] = spyro.create_transect((-0.55, 0.1), (-0.55, 0.9), 200)

wave = spyro.AcousticWave(dictionary=dictionary)
wave.set_mesh(mesh_parameters={"dx": 0.02})
wave.set_initial_velocity_model(constant=1.5)
wave.forward_solve()
spyro.io.save_shots(wave, file_name="test_shot_record")
shots1 = wave.forward_solution_receivers

wave2 = spyro.AcousticWave(dictionary=dictionary)
wave2.set_mesh(mesh_parameters={"dx": 0.02})
spyro.io.load_shots(wave2, file_name="test_shot_record")
shots2 = wave.forward_solution_receivers

assert (shots1 == shots2).all()


if __name__ == "__main__":
test_saving_and_loading_shot_record()
33 changes: 13 additions & 20 deletions test_parallel_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import spyro


def test_saving_shot_record():
def test_saving_and_loading_shot_record():
from test.inputfiles.model import dictionary

dictionary["parallelism"]["type"] = "custom"
Expand All @@ -13,27 +13,20 @@ def test_saving_shot_record():
dictionary["acquisition"]["source_locations"] = [(-0.5, 0.4), (-0.5, 0.6)]
dictionary["acquisition"]["receiver_locations"] = spyro.create_transect((-0.55, 0.1), (-0.55, 0.9), 200)

Wave_obj = spyro.AcousticWave(dictionary=dictionary)
Wave_obj.set_mesh(mesh_parameters={"dx": 0.02})
Wave_obj.set_initial_velocity_model(constant=1.5)
Wave_obj.forward_solve()
spyro.io.save_shots(Wave_obj, file_name="test_shot_record")
wave = spyro.AcousticWave(dictionary=dictionary)
wave.set_mesh(mesh_parameters={"dx": 0.02})
wave.set_initial_velocity_model(constant=1.5)
wave.forward_solve()
spyro.io.save_shots(wave, file_name="test_shot_record")
shots1 = wave.forward_solution_receivers

wave2 = spyro.AcousticWave(dictionary=dictionary)
wave2.set_mesh(mesh_parameters={"dx": 0.02})
spyro.io.load_shots(wave2, file_name="test_shot_record")
shots2 = wave.forward_solution_receivers

def test_loading_shot_record():
from test.inputfiles.model import dictionary

dictionary["parallelism"]["type"] = "custom"
dictionary["parallelism"]["shot_ids_per_propagation"] = [[0, 1]]
dictionary["time_axis"]["final_time"] = 0.5
dictionary["acquisition"]["source_locations"] = [(-0.5, 0.4), (-0.5, 0.6)]
dictionary["acquisition"]["receiver_locations"] = spyro.create_transect((-0.55, 0.1), (-0.55, 0.9), 200)

Wave_obj = spyro.AcousticWave(dictionary=dictionary)
Wave_obj.set_mesh(mesh_parameters={"dx": 0.02})
spyro.io.load_shots(Wave_obj, file_name="test_shot_record")
assert (shots1 == shots2).all()


if __name__ == "__main__":
test_saving_shot_record()
test_loading_shot_record()
test_saving_and_loading_shot_record()

0 comments on commit bcdb07d

Please sign in to comment.