Skip to content

Commit

Permalink
basic saving and loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 4, 2024
1 parent 9dab73e commit 7cf61b8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
3 changes: 3 additions & 0 deletions cleanup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ rm *.png
rm *.vtu
rm *.pvtu
rm *.pvd
rm *.npy
rm *.pdf
rm results/*.vtu
rm results/*.pvd
rm results/*.pvtu
rm shots/*.dat
rm -rf results/shot*
rm -rf results/gradient
rm -rf results/adjoint_shot
Expand Down
21 changes: 17 additions & 4 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,27 @@ def wrapper(*args, **kwargs):
args[0].reset_pressure()
args[0].current_time = starting_time
u, u_r = func(*args, **dict(kwargs, source_nums=[snum]))
arrays_list = [obj.dat.data[:] for obj in u]
stacked_arrays = np.stack(arrays_list, axis=0)
np.save(f'tmp_shot{snum}.npy', stacked_arrays)
np.save(f"tmp_rec{snum}.npy", u_r)
save_serial_data(args[0], snum)
return u, u_r

return wrapper


def save_serial_data(wave, propagation_id):
arrays_list = [obj.dat.data[:] for obj in wave.forward_solution]
stacked_arrays = np.stack(arrays_list, axis=0)
np.save(f'tmp_shot{propagation_id}.npy', stacked_arrays)
np.save(f"tmp_rec{propagation_id}.npy", wave.forward_solution_receivers)


def switch_serial_shot(wave, propagation_id):
stacked_shot_arrays = np.load(f'tmp_shot{propagation_id}.npy')
for array_i, array in enumerate(stacked_shot_arrays):
wave.forward_solution[array_i].dat.data[:] = array
wave.forward_solution_receivers = np.load(f"tmp_rec{propagation_id}.npy")



def ensemble_gradient(func):
"""Decorator for gradient to distribute shots for ensemble parallelism"""

Expand Down

0 comments on commit 7cf61b8

Please sign in to comment.