From c72d87baf3e284ddc46b77bf34d4a7fa4c2c8357 Mon Sep 17 00:00:00 2001 From: olender Date: Tue, 7 Feb 2023 10:38:06 -0300 Subject: [PATCH] added parallel support to custom file names --- spyro/io/io.py | 25 ++++++++++++++++++++----- spyro/plots/plots.py | 2 +- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/spyro/io/io.py b/spyro/io/io.py index 5ba47cd8..218ffd24 100644 --- a/spyro/io/io.py +++ b/spyro/io/io.py @@ -30,8 +30,13 @@ def wrapper(*args, **kwargs): ) ) else: - func(*args, **dict(kwargs, file_name=custom_file_name)) - + func( + *args, + **dict( + kwargs, + file_name=custom_file_name+"shot_record_" + str(snum + 1) + ".dat" + ) + ) return wrapper @@ -54,7 +59,13 @@ def wrapper(*args, **kwargs): ) ) else: - values = func(*args, **dict(kwargs, file_name=custom_file_name)) + values = func( + *args, + **dict( + kwargs, + file_name=custom_file_name+"shot_record_" + str(snum + 1) + ".dat" + ) + ) return values return wrapper @@ -67,9 +78,13 @@ def wrapper(*args, **kwargs): acq = args[0].get("acquisition") num = len(acq["source_pos"]) _comm = args[1] + custom_file_name = kwargs.get("file_name") for snum in range(num): if is_owner(_comm, snum) and _comm.comm.rank == 0: - func(*args, **dict(kwargs, file_name=str(snum + 1))) + if custom_file_name is None: + func(*args, **dict(kwargs, file_name="shot_number_" + str(snum + 1))) + else: + func(*args, **dict(kwargs, file_name=custom_file_name + str(snum + 1))) return wrapper @@ -215,7 +230,7 @@ def save_shots(model, comm, array, file_name=None): Parameters ---------- - filename: str, optional by default shot_number_#.dat + file_name: str, optional by default shot_record_#.dat The filename to save the data as a `pickle` array: `numpy.ndarray` The data to save a pickle (e.g., a shot) diff --git a/spyro/plots/plots.py b/spyro/plots/plots.py index dade1732..7215fc8b 100644 --- a/spyro/plots/plots.py +++ b/spyro/plots/plots.py @@ -73,7 +73,7 @@ def plot_shots( plt.xlim(start_index, end_index) plt.ylim(tf, 0) plt.subplots_adjust(left=0.18, right=0.95, bottom=0.14, top=0.95) - plt.savefig("shot_number_" + file_name + "." + file_format, format=file_format) + plt.savefig(file_name + "." + file_format, format=file_format) # plt.axis("image") if show: plt.show()