diff --git a/Makefile b/Makefile index a0d1756d..b02af0dd 100644 --- a/Makefile +++ b/Makefile @@ -25,13 +25,7 @@ clean: @rm -rf dist/ format: - isort -rc spyro/ test/*.py - black spyro/ test/*.py - -black: - black . + autopep8 --in-place --global-config setup.cfg --recursive . lint: - isort --check . spyro/ setup.py test/*.py - black --check spyro/ setup.py test/*.py flake8 setup.py spyro/ test/*.py diff --git a/demos/with_automatic_differentiation/run_forward_ad.py b/demos/with_automatic_differentiation/run_forward_ad.py index 47a789eb..a9ec3276 100644 --- a/demos/with_automatic_differentiation/run_forward_ad.py +++ b/demos/with_automatic_differentiation/run_forward_ad.py @@ -1,7 +1,7 @@ import firedrake as fire import spyro from demos.with_automatic_differentiation.utils import \ - model_settings, make_c_camembert + model_settings, make_c_camembert import os os.environ["OMP_NUM_THREADS"] = "1" @@ -16,7 +16,7 @@ element = fire.FiniteElement( model["opts"]["method"], mesh.ufl_cell(), degree=model["opts"]["degree"], variant=model["opts"]["quadrature"] - ) +) V = fire.FunctionSpace(mesh, element) @@ -43,4 +43,4 @@ sol = forward_solver.solution fire.VTKFile( "solution_" + str(source_number) + ".pvd", comm=my_ensemble.comm - ).write(sol) + ).write(sol) diff --git a/demos/with_automatic_differentiation/run_fwi_ad.py b/demos/with_automatic_differentiation/run_fwi_ad.py index 84fb2393..b38e89ec 100644 --- a/demos/with_automatic_differentiation/run_fwi_ad.py +++ b/demos/with_automatic_differentiation/run_fwi_ad.py @@ -65,7 +65,7 @@ def forward( sol = forward_solver.solution fire.VTKFile( "solution_" + str(source_number) + ".pvd", comm=my_ensemble.comm - ).write(sol) + ).write(sol) return receiver_data, J @@ -77,7 +77,7 @@ def forward( element = fire.FiniteElement( model["opts"]["method"], mesh.ufl_cell(), degree=model["opts"]["degree"], variant=model["opts"]["quadrature"] - ) +) V = fire.FunctionSpace(mesh, element) @@ -97,7 +97,7 @@ def forward( guess_rec, J = forward( c_guess, compute_functional=True, true_data_receivers=true_rec, annotate=True - ) +) # :class:`~.EnsembleReducedFunctional` is employed to recompute in # parallel the functional and its gradient associated with the multiple sources diff --git a/demos/with_automatic_differentiation/utils.py b/demos/with_automatic_differentiation/utils.py index 6d89fcf7..b95c1ca6 100644 --- a/demos/with_automatic_differentiation/utils.py +++ b/demos/with_automatic_differentiation/utils.py @@ -2,6 +2,7 @@ import firedrake as fire import spyro + def model_settings(): """Model settings for forward and Full Waveform Inversion (FWI) simulations. @@ -29,7 +30,7 @@ def model_settings(): # None - no shots parallelism. "type": "shots_parallelism", "num_spacial_cores": 1, # Number of cores to use in the spatial - # parallelism. + # parallelism. } # Define the domain size without the ABL. diff --git a/setup.cfg b/setup.cfg index d96cd6fe..3938897f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,5 +3,10 @@ ignore = E501,F403,F405,E226,E402,E721,E731,E741,W503,F999, N801,N802,N803,N806,N807,N811,N813,N814,N815,N816 exclude = .git,__pycache__ + [coverage:run] omit=*/site-packages/*,*/test/*,*/.eggs/*,/home/alexandre/firedrake/* + +[pep8] +ignore = E501,E226,E731,E741,W503 +exclude = .git,__pycache__ \ No newline at end of file diff --git a/spyro/domains/space.py b/spyro/domains/space.py index c305b2c2..9b6f618e 100644 --- a/spyro/domains/space.py +++ b/spyro/domains/space.py @@ -40,6 +40,6 @@ def FE_method(mesh, method, degree, dim=1): if dim > 1: element = VectorElement(element, dim=dim) - + function_space = FunctionSpace(mesh, element) return function_space diff --git a/spyro/examples/camembert_elastic.py b/spyro/examples/camembert_elastic.py index 082aec6a..83302a68 100644 --- a/spyro/examples/camembert_elastic.py +++ b/spyro/examples/camembert_elastic.py @@ -6,33 +6,33 @@ import numpy as np import spyro -L = 500 # [m] +L = 500 # [m] rho = 7850 # [kg/m3] lambda_in = 6.86e9 # [Pa] -lambda_out = 9.88e9 # [Pa] +lambda_out = 9.88e9 # [Pa] mu_in = 3.86e9 # [Pa] mu_out = 5.86e9 # [Pa] smag = 1e6 -freq = 2 # Central frequency of Ricker wavelet [Hz] -hf = 90 # [m] -hs = 100 # [m] +freq = 2 # Central frequency of Ricker wavelet [Hz] +hf = 90 # [m] +hs = 100 # [m] source_locations = spyro.create_transect((-hf, 0.2*L), (-hf, 0.8*L), 3) receiver_locations = spyro.create_transect((-hs, 0), (-hs, L), 40) source_locations = [[-hf, 0.5*L]] -time_step = 2e-4 # [s] -final_time = 1.5 # [s] +time_step = 2e-4 # [s] +final_time = 1.5 # [s] out_freq = int(0.01/time_step) n = 20 mesh = fire.RectangleMesh(n, n, 0, L, originX=-L, diagonal='crossed') z, x = fire.SpatialCoordinate(mesh) -zc = 250 # [m] -xc = 250 # [m] -ri = 50 # [m] +zc = 250 # [m] +xc = 250 # [m] +ri = 50 # [m] camembert = lambda v_inside, v_outside: fire.conditional( (z - zc) ** 2 + (x - xc) ** 2 < ri**2, v_inside, v_outside) @@ -60,8 +60,8 @@ "delay": 0, "delay_type": "time", "amplitude": np.array([0, smag]), - #"amplitude": smag * np.eye(2), - #"amplitude": smag * np.array([[0, 1], [-1, 0]]), + # "amplitude": smag * np.eye(2), + # "amplitude": smag * np.array([[0, 1], [-1, 0]]), "receiver_locations": receiver_locations, } diff --git a/spyro/examples/cut_marmousi.py b/spyro/examples/cut_marmousi.py index 989c50f3..05cc345c 100644 --- a/spyro/examples/cut_marmousi.py +++ b/spyro/examples/cut_marmousi.py @@ -114,6 +114,7 @@ class Cut_marmousi_acoustic(Example_model_acoustic): Dictionary with the parameters of the model that are different from the default model. The default is None. """ + def __init__( self, dictionary=None, diff --git a/spyro/examples/elastic_cube_3D.py b/spyro/examples/elastic_cube_3D.py index 911ba075..9102ca48 100644 --- a/spyro/examples/elastic_cube_3D.py +++ b/spyro/examples/elastic_cube_3D.py @@ -9,11 +9,11 @@ N = 5 # Number of elements in each direction h = L/N # Element size [m] -c_p = 5000 # P-wave velocity [m/s] -c_s = 2500 # S-wave velocity [m/s] -rho = 1000 # Density [kg/m3] +c_p = 5000 # P-wave velocity [m/s] +c_s = 2500 # S-wave velocity [m/s] +rho = 1000 # Density [kg/m3] -smag = 1e9 # Source magnitude +smag = 1e9 # Source magnitude freq = 1 # Source frequency [Hz] final_time = 2 diff --git a/spyro/examples/marmousi.py b/spyro/examples/marmousi.py index 56df9b4d..ea2db651 100644 --- a/spyro/examples/marmousi.py +++ b/spyro/examples/marmousi.py @@ -122,6 +122,7 @@ class Marmousi_acoustic(Example_model_acoustic): Dictionary with the parameters of the model that are different from the default model. The default is None. """ + def __init__( self, dictionary=None, diff --git a/spyro/examples/rectangle.py b/spyro/examples/rectangle.py index 5536876d..88737585 100644 --- a/spyro/examples/rectangle.py +++ b/spyro/examples/rectangle.py @@ -121,6 +121,7 @@ class Rectangle_acoustic(Example_model_acoustic): If True, the mesh will be periodic in all directions. The default is False. """ + def __init__( self, dictionary=None, diff --git a/spyro/io/dictionaryio.py b/spyro/io/dictionaryio.py index 9c956693..7cb48e08 100644 --- a/spyro/io/dictionaryio.py +++ b/spyro/io/dictionaryio.py @@ -526,7 +526,6 @@ def get_mesh_type(self): def _derive_mesh_type(self): dictionary = self.mesh_dictionary - user_mesh_in_dictionary = False if "user_mesh" not in dictionary: dictionary["user_mesh"] = None diff --git a/spyro/io/field_logger.py b/spyro/io/field_logger.py index a6a3c862..194984b7 100644 --- a/spyro/io/field_logger.py +++ b/spyro/io/field_logger.py @@ -4,15 +4,17 @@ from .basicio import parallel_print + class Field: def __init__(self, name, file, callback): self.name = name self.file = file self.callback = callback - + def write(self, t): self.file.write(self.callback(), time=t, name=self.name) + class FieldLogger: def __init__(self, comm, vis_dict): self.comm = comm @@ -21,14 +23,14 @@ def __init__(self, comm, vis_dict): self.__source_id = None self.__enabled_fields = [] self.__wave_data = [] - + def add_field(self, key, name, callback): self.__wave_data.append((key, name, callback)) - + def start_logging(self, source_id): - if self.__source_id != None: + if self.__source_id is not None: warnings.warn("Started a new record without stopping the previous one") - + self.__source_id = source_id self.__enabled_fields = [] for key, name, callback in self.__wave_data: @@ -42,10 +44,10 @@ def start_logging(self, source_id): file = VTKFile(filename, comm=self.comm.comm) self.__enabled_fields.append(Field(name, file, callback)) - + def stop_logging(self): self.__source_id = None - + def log(self, t): for field in self.__enabled_fields: - field.write(t) \ No newline at end of file + field.write(t) diff --git a/spyro/io/model_parameters.py b/spyro/io/model_parameters.py index 44cce24f..df368466 100644 --- a/spyro/io/model_parameters.py +++ b/spyro/io/model_parameters.py @@ -570,9 +570,9 @@ def _sanitize_optimization_and_velocity(self): if "velocity_conditional" not in dictionary["synthetic_data"]: self.velocity_model_type = None warnings.warn( - "No velocity model set initially. If using " \ - "user defined conditional or expression, please " \ - "input it in the Wave object." + "No velocity model set initially. If using " + "user defined conditional or expression, please " + "input it in the Wave object." ) if "velocity_conditional" in dictionary["synthetic_data"]: @@ -599,7 +599,7 @@ def _sanitize_optimization_and_velocity_for_fwi(self): else: default_optimization_parameters = { "General": {"Secant": {"Type": "Limited-Memory BFGS", - "Maximum Storage": 10}}, + "Maximum Storage": 10}}, "Step": { "Type": "Augmented Lagrangian", "Augmented Lagrangian": { diff --git a/spyro/meshing/__init__.py b/spyro/meshing/__init__.py index 2e8ebf57..1e7a04ef 100644 --- a/spyro/meshing/__init__.py +++ b/spyro/meshing/__init__.py @@ -1,6 +1,6 @@ -from .meshing_functions import RectangleMesh -from .meshing_functions import PeriodicRectangleMesh, BoxMesh -from .meshing_functions import AutomaticMesh +from .meshing_functions import RectangleMesh # noqa: F401 +from .meshing_functions import PeriodicRectangleMesh, BoxMesh # noqa: F401 +from .meshing_functions import AutomaticMesh # noqa: F401 all = [ "RectangleMesh", diff --git a/spyro/meshing/meshing_functions.py b/spyro/meshing/meshing_functions.py index 4ca7514b..26dff53e 100644 --- a/spyro/meshing/meshing_functions.py +++ b/spyro/meshing/meshing_functions.py @@ -1,4 +1,3 @@ -import os import firedrake as fire import SeismicMesh import meshio @@ -78,77 +77,77 @@ class AutomaticMesh: """ def __init__( - self, comm=None, mesh_parameters=None - ): - """ - Initialize the MeshingFunctions class. - - Parameters - ---------- - comm : MPI communicator, optional - MPI communicator. The default is None. - mesh_parameters : dict, optional - Dictionary containing the mesh parameters. The default is None. - - Raises - ------ - ValueError - If `abc_pad_length` is negative. - - Notes - ----- - The `mesh_parameters` dictionary should contain the following keys: - - 'dimension': int, optional. Dimension of the mesh. The default is 2. - - 'length_z': float, optional. Length of the mesh in the z-direction. - - 'length_x': float, optional. Length of the mesh in the x-direction. - - 'length_y': float, optional. Length of the mesh in the y-direction. - - 'cell_type': str, optional. Type of the mesh cells. - - 'mesh_type': str, optional. Type of the mesh. - - For mesh with absorbing layer only: - - 'abc_pad_length': float, optional. Length of the absorbing boundary condition padding. - - For Firedrake mesh only: - - 'dx': float, optional. Mesh element size. - - 'periodic': bool, optional. Whether the mesh is periodic. - - 'edge_length': float, optional. Length of the mesh edges. - - For SeismicMesh only: - - 'cells_per_wavelength': float, optional. Number of cells per wavelength. - - 'source_frequency': float, optional. Frequency of the source. - - 'minimum_velocity': float, optional. Minimum velocity. - - 'velocity_model_file': str, optional. File containing the velocity model. - - 'edge_length': float, optional. Length of the mesh edges. - """ - self.dimension = mesh_parameters["dimension"] - self.length_z = mesh_parameters["length_z"] - self.length_x = mesh_parameters["length_x"] - self.length_y = mesh_parameters["length_y"] - self.cell_type = mesh_parameters["cell_type"] - self.comm = comm - if mesh_parameters["abc_pad_length"] is None: - self.abc_pad = 0.0 - elif mesh_parameters["abc_pad_length"] >= 0.0: - self.abc_pad = mesh_parameters["abc_pad_length"] - else: - raise ValueError("abc_pad must be positive") - self.mesh_type = mesh_parameters["mesh_type"] - - # Firedrake mesh only parameters - self.dx = mesh_parameters["dx"] - self.quadrilateral = False - self.periodic = mesh_parameters["periodic"] - if self.dx is None: - self.dx = mesh_parameters["edge_length"] - - # SeismicMesh only parameters - self.cpw = mesh_parameters["cells_per_wavelength"] - self.source_frequency = mesh_parameters["source_frequency"] - self.minimum_velocity = mesh_parameters["minimum_velocity"] - self.lbda = None - self.velocity_model = mesh_parameters["velocity_model_file"] - self.edge_length = mesh_parameters["edge_length"] - self.output_file_name = "automatic_mesh.msh" + self, comm=None, mesh_parameters=None + ): + """ + Initialize the MeshingFunctions class. + + Parameters + ---------- + comm : MPI communicator, optional + MPI communicator. The default is None. + mesh_parameters : dict, optional + Dictionary containing the mesh parameters. The default is None. + + Raises + ------ + ValueError + If `abc_pad_length` is negative. + + Notes + ----- + The `mesh_parameters` dictionary should contain the following keys: + - 'dimension': int, optional. Dimension of the mesh. The default is 2. + - 'length_z': float, optional. Length of the mesh in the z-direction. + - 'length_x': float, optional. Length of the mesh in the x-direction. + - 'length_y': float, optional. Length of the mesh in the y-direction. + - 'cell_type': str, optional. Type of the mesh cells. + - 'mesh_type': str, optional. Type of the mesh. + + For mesh with absorbing layer only: + - 'abc_pad_length': float, optional. Length of the absorbing boundary condition padding. + + For Firedrake mesh only: + - 'dx': float, optional. Mesh element size. + - 'periodic': bool, optional. Whether the mesh is periodic. + - 'edge_length': float, optional. Length of the mesh edges. + + For SeismicMesh only: + - 'cells_per_wavelength': float, optional. Number of cells per wavelength. + - 'source_frequency': float, optional. Frequency of the source. + - 'minimum_velocity': float, optional. Minimum velocity. + - 'velocity_model_file': str, optional. File containing the velocity model. + - 'edge_length': float, optional. Length of the mesh edges. + """ + self.dimension = mesh_parameters["dimension"] + self.length_z = mesh_parameters["length_z"] + self.length_x = mesh_parameters["length_x"] + self.length_y = mesh_parameters["length_y"] + self.cell_type = mesh_parameters["cell_type"] + self.comm = comm + if mesh_parameters["abc_pad_length"] is None: + self.abc_pad = 0.0 + elif mesh_parameters["abc_pad_length"] >= 0.0: + self.abc_pad = mesh_parameters["abc_pad_length"] + else: + raise ValueError("abc_pad must be positive") + self.mesh_type = mesh_parameters["mesh_type"] + + # Firedrake mesh only parameters + self.dx = mesh_parameters["dx"] + self.quadrilateral = False + self.periodic = mesh_parameters["periodic"] + if self.dx is None: + self.dx = mesh_parameters["edge_length"] + + # SeismicMesh only parameters + self.cpw = mesh_parameters["cells_per_wavelength"] + self.source_frequency = mesh_parameters["source_frequency"] + self.minimum_velocity = mesh_parameters["minimum_velocity"] + self.lbda = None + self.velocity_model = mesh_parameters["velocity_model_file"] + self.edge_length = mesh_parameters["edge_length"] + self.output_file_name = "automatic_mesh.msh" def set_mesh_size(self, length_z=None, length_x=None, length_y=None): """ @@ -281,8 +280,8 @@ def create_firedrake_2D_mesh(self): Creates a 2D mesh based on Firedrake meshing utilities. """ if self.abc_pad: - nx = int( (self.length_x + 2*self.abc_pad) / self.dx) - nz = int( (self.length_z + self.abc_pad)/ self.dx) + nx = int((self.length_x + 2*self.abc_pad) / self.dx) + nz = int((self.length_z + self.abc_pad) / self.dx) else: nx = int(self.length_x / self.dx) nz = int(self.length_z / self.dx) diff --git a/spyro/plots/plots.py b/spyro/plots/plots.py index fbc33ac1..04a2b4cb 100644 --- a/spyro/plots/plots.py +++ b/spyro/plots/plots.py @@ -1,7 +1,5 @@ # from scipy.io import savemat import matplotlib.pyplot as plt -import matplotlib.patches as patches -from matplotlib.ticker import MultipleLocator from PIL import Image import numpy as np import firedrake @@ -59,7 +57,7 @@ def plot_shots( dt = Wave_object.dt tf = Wave_object.final_time - if out_index == None: + if out_index is None: arr = Wave_object.receivers_output else: arr = Wave_object.receivers_output[:, :, out_index] diff --git a/spyro/pml/damping.py b/spyro/pml/damping.py index 34519169..3ef21e38 100644 --- a/spyro/pml/damping.py +++ b/spyro/pml/damping.py @@ -1,5 +1,4 @@ import math -import warnings from firedrake import * # noqa: F403 @@ -33,7 +32,6 @@ def functions(Wave_obj): x = Wave_obj.mesh_x x1 = 0.0 x2 = Wave_obj.length_x - z1 = 0.0 z2 = -Wave_obj.length_z bar_sigma = ((3.0 * cmax) / (2.0 * pad_length)) * math.log10(1.0 / R) diff --git a/spyro/receivers/__init__.py b/spyro/receivers/__init__.py index 8b137891..e69de29b 100644 --- a/spyro/receivers/__init__.py +++ b/spyro/receivers/__init__.py @@ -1 +0,0 @@ - diff --git a/spyro/receivers/changing_coordinates.py b/spyro/receivers/changing_coordinates.py index 9589204c..a7d38eaa 100644 --- a/spyro/receivers/changing_coordinates.py +++ b/spyro/receivers/changing_coordinates.py @@ -763,4 +763,4 @@ def change_to_reference_hexa(p, cell_vertices): # pny = px * a21 + py * a22 + pz * a23 + a24 # pnz = px * a31 + py * a32 + pz * a33 + a34 - return (pnx, pny, pnz) + # return (pnx, pny, pnz) diff --git a/spyro/receivers/dirac_delta_projector.py b/spyro/receivers/dirac_delta_projector.py index 9eb3a6c6..d132870f 100644 --- a/spyro/receivers/dirac_delta_projector.py +++ b/spyro/receivers/dirac_delta_projector.py @@ -55,6 +55,7 @@ class Delta_projector: is_local: list List of cell IDs local to the processor """ + def __init__(self, wave_object): """ Initializes the class @@ -166,7 +167,7 @@ def new_at(self, udat, receiver_id): def __func_build_cell_tabulations(self, order): if order != 0 and order != 1: raise NotImplementedError - + element = self.choose_element() if order == 0: @@ -199,7 +200,7 @@ def __func_build_cell_tabulations(self, order): cell_tabulations[receiver_id, :] = tab return cell_tabulations - + def __reference_element(self, id): if self.dimension == 2 and self.quadrilateral is False: n_v = 3 @@ -399,7 +400,7 @@ def __point_locator_3D(self): cellVertices[receiver_id][vertex_number] = (z, x, y) return cellId_maps, cellVertices, cellNodeMaps - + def choose_element(self): if not self.quadrilateral: element = choosing_element(self.space, self.degree) @@ -417,6 +418,7 @@ def choose_element(self): raise NotImplementedError return element + def choosing_geometry(cell_geometry): """ Chooses UFC reference element geometry based on desired function space diff --git a/spyro/solvers/acoustic_wave.py b/spyro/solvers/acoustic_wave.py index 5a7c3509..7f52397a 100644 --- a/spyro/solvers/acoustic_wave.py +++ b/spyro/solvers/acoustic_wave.py @@ -1,9 +1,12 @@ import firedrake as fire import warnings +import os +from SeismicMesh import write_velocity_model from .wave import Wave from ..io.basicio import ensemble_gradient +from ..io import interpolate from .acoustic_solver_construction_no_pml import ( construct_solver_or_matrix_no_pml, ) @@ -16,6 +19,7 @@ from ..domains.space import FE_method from ..utils.typing import override + class AcousticWave(Wave): def save_current_velocity_model(self, file_name=None): if self.c is None: @@ -85,7 +89,7 @@ def reset_pressure(self): try: self.u_nm1.assign(0.0) self.u_n.assign(0.0) - except: + except Exception: warnings.warn("No pressure to reset") @override @@ -115,7 +119,7 @@ def _initialize_model_parameters(self): fire.File("initial_velocity_model.pvd").write( self.initial_velocity_model, name="velocity" ) - + self.c = self.initial_velocity_model @override @@ -159,7 +163,7 @@ def _get_next_vstate(self): return self.X_np1 else: return self.u_np1 - + @override def get_receivers_output(self): if self.abc_boundary_layer_type == "PML": @@ -167,18 +171,18 @@ def get_receivers_output(self): else: data_with_halos = self.u_n.dat.data_ro_with_halos[:] return self.receivers.interpolate(data_with_halos) - + @override def get_function(self): if self.abc_boundary_layer_type == "PML": return self.X_n.sub(0) else: return self.u_n - + @override def get_function_name(self): return "Pressure" - + @override def _create_function_space(self): return FE_method(self.mesh, self.method, self.degree) @@ -188,4 +192,4 @@ def rhs_no_pml(self): if self.abc_boundary_layer_type == "PML": return self.B.sub(0) else: - return self.B \ No newline at end of file + return self.B diff --git a/spyro/solvers/backward_time_integration.py b/spyro/solvers/backward_time_integration.py index b3727fb6..f4fe9f81 100644 --- a/spyro/solvers/backward_time_integration.py +++ b/spyro/solvers/backward_time_integration.py @@ -65,7 +65,7 @@ def backward_wave_propagator_no_pml(Wave_obj, dt=None): dt = Wave_obj.dt t = Wave_obj.current_time if t != final_time: - print(f"Current time of {t}, different than final_time of {final_time}. Setting final_time to current time in backwards propagation.", flush= True) + print(f"Current time of {t}, different than final_time of {final_time}. Setting final_time to current time in backwards propagation.", flush=True) nt = int(t / dt) + 1 # number of timesteps u_nm1 = Wave_obj.u_nm1 @@ -194,7 +194,7 @@ def mixed_space_backward_wave_propagator(Wave_obj, dt=None): dt = Wave_obj.dt t = Wave_obj.current_time if t != final_time: - print(f"Current time of {t}, different than final_time of {final_time}. Setting final_time to current time in backwards propagation.", flush= True) + print(f"Current time of {t}, different than final_time of {final_time}. Setting final_time to current time in backwards propagation.", flush=True) nt = int(t / dt) + 1 # number of timesteps X_nm1 = Wave_obj.X_nm1 @@ -216,7 +216,7 @@ def mixed_space_backward_wave_propagator(Wave_obj, dt=None): uadj = fire.Function(Wave_obj.function_space) # auxiliarly function for the gradient compt. # ffG = -2 * (Wave_obj.c)**(-3) * fire.dot(dufordt2, uadj) * m_v * fire.dx(scheme=Wave_obj.quadrature_rule) - ffG = 2.0 * Wave_obj.c * fire.dot(fire.grad(uadj), fire.grad(ufor)) * m_v * fire.dx(scheme=Wave_obj.quadrature_rule) + ffG = 2.0 * Wave_obj.c * fire.dot(fire.grad(uadj), fire.grad(ufor)) * m_v * fire.dx(scheme=Wave_obj.quadrature_rule) lhsG = mgrad rhsG = ffG diff --git a/spyro/solvers/elastic_wave/elastic_wave.py b/spyro/solvers/elastic_wave/elastic_wave.py index d201afe1..14c13bdb 100644 --- a/spyro/solvers/elastic_wave/elastic_wave.py +++ b/spyro/solvers/elastic_wave/elastic_wave.py @@ -4,12 +4,14 @@ from ..wave import Wave from ...utils.typing import override + class ElasticWave(Wave, metaclass=ABCMeta): '''Base class for elastic wave propagators''' + def __init__(self, dictionary, comm=None): super().__init__(dictionary, comm=comm) - self.time = Constant(0) # Time variable - + self.time = Constant(0) # Time variable + @override def _initialize_model_parameters(self): d = self.input_dictionary.get("synthetic_data", False) @@ -22,15 +24,15 @@ def _initialize_model_parameters(self): raise Exception(f"Invalid synthetic data type: {d['type']}") else: raise Exception("Input dictionary must contain ['synthetic_data']['type']") - + @abstractmethod def initialize_model_parameters_from_object(self, synthetic_data_dict): pass - + @abstractmethod def initialize_model_parameters_from_file(self, synthetic_data_dict): pass @override def update_source_expression(self, t): - self.time.assign(t) \ No newline at end of file + self.time.assign(t) diff --git a/spyro/solvers/elastic_wave/forms.py b/spyro/solvers/elastic_wave/forms.py index e14e8ed5..598078ae 100644 --- a/spyro/solvers/elastic_wave/forms.py +++ b/spyro/solvers/elastic_wave/forms.py @@ -1,10 +1,9 @@ -import numpy as np - from firedrake import (assemble, Cofunction, Constant, div, dot, dx, grad, inner, lhs, LinearSolver, rhs, TestFunction, TrialFunction) from .local_abc import clayton_engquist_A1 + def isotropic_elastic_without_pml(wave): V = wave.function_space quad_rule = wave.quadrature_rule @@ -14,7 +13,7 @@ def isotropic_elastic_without_pml(wave): u_nm1 = wave.u_nm1 u_n = wave.u_n - + dt = Constant(wave.dt) rho = wave.rho lmbda = wave.lmbda @@ -24,7 +23,7 @@ def isotropic_elastic_without_pml(wave): eps = lambda v: 0.5*(grad(v) + grad(v).T) F_k = lmbda*div(u_n)*div(v)*dx(scheme=quad_rule) \ - + 2*mu*inner(eps(u_n), eps(v))*dx(scheme=quad_rule) + + 2*mu*inner(eps(u_n), eps(v))*dx(scheme=quad_rule) F_s = 0 b = wave.body_forces @@ -50,5 +49,6 @@ def isotropic_elastic_without_pml(wave): wave.rhs = rhs(F) wave.B = Cofunction(V.dual()) + def isotropic_elastic_with_pml(): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/spyro/solvers/elastic_wave/isotropic_wave.py b/spyro/solvers/elastic_wave/isotropic_wave.py index 6415db32..565db458 100644 --- a/spyro/solvers/elastic_wave/isotropic_wave.py +++ b/spyro/solvers/elastic_wave/isotropic_wave.py @@ -9,19 +9,21 @@ from ...domains.space import FE_method from ...utils.typing import override + class IsotropicWave(ElasticWave): '''Isotropic elastic wave propagator''' + def __init__(self, dictionary, comm=None): super().__init__(dictionary, comm=comm) self.rho = None # Density - self.lmbda = None # First Lame parameter + self.lmbda = None # First Lame parameter self.mu = None # Second Lame parameter self.c_s = None # Secondary wave velocity self.u_n = None # Current displacement field - self.u_nm1 = None # Displacement field in previous iteration - self.u_np1 = None # Displacement field in next iteration + self.u_nm1 = None # Displacement field in previous iteration + self.u_np1 = None # Displacement field in next iteration # Volumetric sourcers (defined through UFL) self.body_forces = None @@ -40,7 +42,7 @@ def __init__(self, dictionary, comm=None): self.C_h = None self.field_logger.add_field("s-wave", "S-wave", lambda: self.update_s_wave()) - + @override def initialize_model_parameters_from_object(self, synthetic_data_dict: dict): def constant_wrapper(value): @@ -48,28 +50,28 @@ def constant_wrapper(value): return Constant(value) else: return value - + def get_value(key, default=None): return constant_wrapper(synthetic_data_dict.get(key, default)) - + self.rho = get_value("density") self.lmbda = get_value("lambda", default=get_value("lame_first")) self.mu = get_value("mu", get_value("lame_second")) self.c = get_value("p_wave_velocity") self.c_s = get_value("s_wave_velocity") - + # Check if {rho, lambda, mu} is set and {c, c_s} are not option_1 = bool(self.rho) and \ - bool(self.lmbda) and \ - bool(self.mu) and \ - not bool(self.c) and \ - not bool(self.c_s) + bool(self.lmbda) and \ + bool(self.mu) and \ + not bool(self.c) and \ + not bool(self.c_s) # Check if {rho, c, c_s} is set and {lambda, mu} are not option_2 = bool(self.rho) and \ - bool(self.c) and \ - bool(self.c_s) and \ - not bool(self.lmbda) and \ - not bool(self.mu) + bool(self.c) and \ + bool(self.c_s) and \ + not bool(self.lmbda) and \ + not bool(self.mu) if option_1: self.c = ((self.lmbda + 2*self.mu)/self.rho)**0.5 @@ -78,19 +80,19 @@ def get_value(key, default=None): self.mu = self.rho*self.c_s**2 self.lmbda = self.rho*self.c**2 - 2*self.mu else: - raise Exception(f"Inconsistent selection of isotropic elastic wave parameters:\n" \ - f" Density : {bool(self.rho)}\n"\ - f" Lame first : {bool(self.lmbda)}\n"\ - f" Lame second : {bool(self.mu)}\n"\ - f" P-wave velocity: {bool(self.c)}\n"\ - f" S-wave velocity: {bool(self.c_s)}\n"\ - "The valid options are {Density, Lame first, Lame second} "\ + raise Exception(f"Inconsistent selection of isotropic elastic wave parameters:\n" + f" Density : {bool(self.rho)}\n" + f" Lame first : {bool(self.lmbda)}\n" + f" Lame second : {bool(self.mu)}\n" + f" P-wave velocity: {bool(self.c)}\n" + f" S-wave velocity: {bool(self.c_s)}\n" + "The valid options are {Density, Lame first, Lame second} " "or (exclusive) {Density, P-wave velocity, S-wave velocity}") - + @override def initialize_model_parameters_from_file(self, synthetic_data_dict): raise NotImplementedError - + @override def _create_function_space(self): return FE_method(self.mesh, self.method, self.degree, @@ -119,7 +121,7 @@ def _set_next_vstate(self, vstate): @override def _get_next_vstate(self): return self.u_np1 - + @override def get_receivers_output(self): if self.abc_boundary_layer_type == "PML": @@ -146,23 +148,23 @@ def matrix_building(self): name=self.get_function_name()) self.u_np1 = Function(self.function_space, name=self.get_function_name()) - + self.parse_initial_conditions() self.parse_boundary_conditions() self.parse_volumetric_forces() - + if self.abc_boundary_layer_type is None: isotropic_elastic_without_pml(self) elif self.abc_boundary_layer_type == "PML": isotropic_elastic_with_pml(self) - + @override def rhs_no_pml(self): if self.abc_boundary_layer_type == "PML": raise NotImplementedError else: return self.B - + def parse_initial_conditions(self): time_dict = self.input_dictionary["time_axis"] initial_condition = time_dict.get("initial_condition", None) @@ -170,7 +172,7 @@ def parse_initial_conditions(self): x_vec = self.get_spatial_coordinates() self.u_n.interpolate(initial_condition(x_vec, 0 - self.dt)) self.u_nm1.interpolate(initial_condition(x_vec, 0 - 2*self.dt)) - + def parse_boundary_conditions(self): bc_list = self.input_dictionary.get("boundary_conditions", []) for tag, id, value in bc_list: @@ -185,31 +187,31 @@ def parse_boundary_conditions(self): else: raise Exception(f"Unsupported boundary condition with tag: {tag}") self.bcs.append(DirichletBC(subspace, value, id)) - + def parse_volumetric_forces(self): acquisition_dict = self.input_dictionary["acquisition"] body_forces_data = acquisition_dict.get("body_forces", None) if body_forces_data is not None: x_vec = self.get_spatial_coordinates() self.body_forces = body_forces_data(x_vec, self.time) - + def update_p_wave(self): - if self.p_wave == None: + if self.p_wave is None: self.D_h = FunctionSpace(self.mesh, "DG", 0) self.p_wave = Function(self.D_h) - + self.p_wave.assign(project(div(self.get_function()), self.D_h)) - + return self.p_wave def update_s_wave(self): - if self.s_wave == None: + if self.s_wave is None: if self.dimension == 2: self.C_h = FunctionSpace(self.mesh, "DG", 0) else: self.C_h = VectorFunctionSpace(self.mesh, "DG", 0) self.s_wave = Function(self.C_h) - + self.s_wave.assign(project(curl(self.get_function()), self.C_h)) - - return self.s_wave \ No newline at end of file + + return self.s_wave diff --git a/spyro/solvers/elastic_wave/local_abc.py b/spyro/solvers/elastic_wave/local_abc.py index 55a6380c..a81744b4 100644 --- a/spyro/solvers/elastic_wave/local_abc.py +++ b/spyro/solvers/elastic_wave/local_abc.py @@ -1,11 +1,12 @@ from firedrake import (Constant, ds, TestFunction) + def clayton_engquist_A1(wave): ''' - Returns the linear form associated with the traction loads + Returns the linear form associated with the traction loads when combined with the Clayton-Engquist A1 relations. ''' - F_t = 0 # linear form + F_t = 0 # linear form V = wave.function_space v = TestFunction(V) @@ -20,7 +21,9 @@ def clayton_engquist_A1(wave): qr_s = wave.surface_quadrature_rule # Index of each coordinate - iz = 0; ix = 1; iy = 2 + iz = 0 + ix = 1 + iy = 2 # Partial derivatives uz_dt = (u_n[iz] - u_nm1[iz])/dt @@ -46,7 +49,7 @@ def clayton_engquist_A1(wave): if wave.dimension == 3: sig_yz = rho*c_s*uy_dt + rho*(c_s**2)*uz_dy F_t += -sig_yz*v[iy]*ds(1, scheme=qr_s) - + # Plane z = 0 sig_zz = -rho*c_p*uz_dt + rho*(c_p**2 - 2*c_s**2)*ux_dx if wave.dimension == 3: @@ -90,4 +93,4 @@ def clayton_engquist_A1(wave): sig_yy = -rho*c_p*uy_dt + rho*(c_p**2 - 2*c_s**2)*(uz_dz + ux_dx) F_t += (sig_zy*v[iz] + sig_xy*v[ix] + sig_yy*v[iy])*ds(6, scheme=qr_s) - return F_t \ No newline at end of file + return F_t diff --git a/spyro/solvers/forward_ad.py b/spyro/solvers/forward_ad.py index db21bedb..7e54e325 100644 --- a/spyro/solvers/forward_ad.py +++ b/spyro/solvers/forward_ad.py @@ -1,9 +1,7 @@ import firedrake as fire import firedrake.adjoint as fire_ad -from ..domains import quadrature from .time_integration_ad import central_difference_acoustic from firedrake.__future__ import interpolate -import finat # Note this turns off non-fatal warnings fire.set_log_level(fire.ERROR) @@ -33,7 +31,7 @@ def __init__(self, model, mesh, function_space): def execute_acoustic( self, c, source_number, wavelet, compute_functional=False, true_data_receivers=None - ): + ): """Time-stepping acoustic forward solver. The time integration is done using a central difference scheme. @@ -73,7 +71,7 @@ def execute_acoustic( source_mesh = fire.VertexOnlyMesh( self.mesh, [self.model["acquisition"]["source_pos"][source_number]] - ) + ) # Source function space. V_s = fire.FunctionSpace(source_mesh, "DG", 0) d_s = fire.Function(V_s) @@ -125,13 +123,13 @@ def _solver_parameters(self): params = {"ksp_type": "preonly", "pc_type": "jacobi"} elif ( self.model["opts"]["method"] == "CG" - and self.mesh.ufl_cell() != quadrilateral - and self.mesh.ufl_cell() != hexahedron + and self.mesh.ufl_cell() != quadrilateral # noqa: F821 + and self.mesh.ufl_cell() != hexahedron # noqa: F821 ): params = {"ksp_type": "cg", "pc_type": "jacobi"} elif self.model["opts"]["method"] == "CG" and ( - self.mesh.ufl_cell() == quadrilateral - or self.mesh.ufl_cell() == hexahedron + self.mesh.ufl_cell() == quadrilateral # noqa: F821 + or self.mesh.ufl_cell() == hexahedron # noqa: F821 ): params = {"ksp_type": "preonly", "pc_type": "jacobi"} else: diff --git a/spyro/solvers/inversion.py b/spyro/solvers/inversion.py index 98b56021..c3f89f7b 100644 --- a/spyro/solvers/inversion.py +++ b/spyro/solvers/inversion.py @@ -1,7 +1,7 @@ import firedrake as fire import warnings from scipy.optimize import minimize as scipy_minimize -from mpi4py import MPI +from mpi4py import MPI # noqa: F401 import numpy as np from .acoustic_wave import AcousticWave @@ -507,7 +507,7 @@ def set_gradient_mask(self, boundaries=None): Sets the gradient mask for zeroing gradient values outside defined boundaries. Args: - boundaries (list, optional): List of boundary values for the mask. If not provided, + boundaries (list, optional): List of boundary values for the mask. If not provided, the method expects the abc_active to be True and uses PML locations for boundary values. @@ -516,7 +516,7 @@ def set_gradient_mask(self, boundaries=None): ValueError: If mask options do not make sense. Warnings: - UserWarning: If abc_active is True and boundaries is not None, the boundaries will + UserWarning: If abc_active is True and boundaries is not None, the boundaries will override the PML boundaries for the mask. """ @@ -537,23 +537,23 @@ def set_gradient_mask(self, boundaries=None): self.mask_obj = mask_obj def _apply_gradient_mask(self): - """ - Applies a gradient mask to the gradient if it exists. + """ + Applies a gradient mask to the gradient if it exists. - If a gradient mask is available, this method applies the mask to the gradient - using the `apply_mask` method of the `mask_obj`. If no gradient mask is available, - this method does nothing. + If a gradient mask is available, this method applies the mask to the gradient + using the `apply_mask` method of the `mask_obj`. If no gradient mask is available, + this method does nothing. - Parameters: - None + Parameters: + None - Returns: - None - """ - if self.has_gradient_mask: - self.gradient = self.mask_obj.apply_mask(self.gradient) - else: - pass + Returns: + None + """ + if self.has_gradient_mask: + self.gradient = self.mask_obj.apply_mask(self.gradient) + else: + pass class SyntheticRealAcousticWave(AcousticWave): @@ -574,6 +574,7 @@ class SyntheticRealAcousticWave(AcousticWave): forward_solve(): Solves the forward problem. """ + def __init__(self, dictionary=None, comm=None): super().__init__(dictionary=dictionary, comm=comm) diff --git a/spyro/solvers/mms_acoustic.py b/spyro/solvers/mms_acoustic.py index c085ba60..6bdbb404 100644 --- a/spyro/solvers/mms_acoustic.py +++ b/spyro/solvers/mms_acoustic.py @@ -2,6 +2,7 @@ from .acoustic_wave import AcousticWave from ..utils.typing import override + class AcousticWaveMMS(AcousticWave): """Class for solving the acoustic wave equation in 2D or 3D using the finite element method. This class inherits from the AcousticWave class @@ -69,7 +70,7 @@ def analytical_solution(self, t): # self.analytical.assign(analytical) return self.analytical - + @override def update_source_expression(self, t): self.q_t.assign(2*t) diff --git a/spyro/solvers/time_integration_central_difference.py b/spyro/solvers/time_integration_central_difference.py index a50fb68a..0b73095d 100644 --- a/spyro/solvers/time_integration_central_difference.py +++ b/spyro/solvers/time_integration_central_difference.py @@ -23,12 +23,12 @@ def central_difference(wave, source_id=0): rhs_forcing = fire.Cofunction(wave.function_space.dual()) wave.field_logger.start_logging(source_id) - + wave.comm.comm.barrier() t = wave.current_time nt = int(wave.final_time / wave.dt) + 1 # number of timesteps - + usol = [ fire.Function(wave.function_space, name=wave.get_function_name()) for t in range(nt) @@ -46,7 +46,7 @@ def central_difference(wave, source_id=0): f = wave.sources.apply_source(rhs_forcing, step) B0 = wave.rhs_no_pml() B0 += f - + wave.solver.solve(wave.next_vstate, wave.B) wave.prev_vstate = wave.vstate @@ -67,7 +67,7 @@ def central_difference(wave, source_id=0): helpers.display_progress(wave.comm, t) t = step * float(wave.dt) - + wave.current_time = t helpers.display_progress(wave.comm, t) diff --git a/spyro/solvers/wave.py b/spyro/solvers/wave.py index f4cc3afd..62d73f43 100644 --- a/spyro/solvers/wave.py +++ b/spyro/solvers/wave.py @@ -1,13 +1,11 @@ -import os from abc import abstractmethod, ABCMeta import warnings import firedrake as fire from firedrake import sin, cos, pi, tanh, sqrt # noqa: F401 -from SeismicMesh import write_velocity_model from .time_integration_central_difference import central_difference as time_integrator from ..domains.quadrature import quadrature_rules -from ..io import Model_parameters, interpolate +from ..io import Model_parameters from ..io.basicio import ensemble_propagator from ..io.field_logger import FieldLogger from .. import utils @@ -117,42 +115,42 @@ def matrix_building(self): pass def set_mesh( - self, - user_mesh=None, - mesh_parameters=None, - ): - """ - Set the mesh for the solver. - - Args: - user_mesh (optional): User-defined mesh. Defaults to None. - mesh_parameters (optional): Parameters for generating a mesh. Defaults to None. - """ - super().set_mesh( - user_mesh=user_mesh, - mesh_parameters=mesh_parameters, - ) + self, + user_mesh=None, + mesh_parameters=None, + ): + """ + Set the mesh for the solver. - self.mesh = self.get_mesh() - self._build_function_space() - self._map_sources_and_receivers() + Args: + user_mesh (optional): User-defined mesh. Defaults to None. + mesh_parameters (optional): Parameters for generating a mesh. Defaults to None. + """ + super().set_mesh( + user_mesh=user_mesh, + mesh_parameters=mesh_parameters, + ) + + self.mesh = self.get_mesh() + self._build_function_space() + self._map_sources_and_receivers() def set_solver_parameters(self, parameters=None): - """ - Set the solver parameters. - - Args: - parameters (dict): A dictionary containing the solver parameters. - - Returns: - None - """ - if parameters is not None: - self.solver_parameters = parameters - elif parameters is None: - self.solver_parameters = get_default_parameters_for_method( - self.method - ) + """ + Set the solver parameters. + + Args: + parameters (dict): A dictionary containing the solver parameters. + + Returns: + None + """ + if parameters is not None: + self.solver_parameters = parameters + elif parameters is None: + self.solver_parameters = get_default_parameters_for_method( + self.method + ) def get_spatial_coordinates(self): if self.dimension == 2: @@ -227,8 +225,8 @@ def set_initial_velocity_model( self.initial_velocity_model = vp else: raise ValueError( - "Please specify either a conditional, expression, firedrake " \ - "function or new file name (segy or hdf5)." + "Please specify either a conditional, expression, firedrake " + "function or new file name (segy or hdf5)." ) if output: fire.File("initial_velocity_model.pvd").write( @@ -239,7 +237,7 @@ def _map_sources_and_receivers(self): if self.source_type == "ricker": self.sources = Sources(self) self.receivers = Receivers(self) - + @abstractmethod def _initialize_model_parameters(self): pass @@ -256,7 +254,7 @@ def _build_function_space(self): self.stiffness_quadrature_rule = k_rule self.surface_quadrature_rule = s_rule - # TO REVIEW: why are the mesh coordinates assigned here? I believe they + # TO REVIEW: why are the mesh coordinates assigned here? I believe they # should be copied when the mesh is assigned if self.dimension == 2: z, x = fire.SpatialCoordinate(self.mesh) @@ -317,7 +315,7 @@ def set_last_solve_as_real_shot_record(self): if self.current_time == 0.0: raise ValueError("No previous solve to set as real shot record.") self.real_shot_record = self.forward_solution_receivers - + @abstractmethod def _set_vstate(self, vstate): pass @@ -349,7 +347,7 @@ def _get_next_vstate(self): fset=lambda self, value: self._set_prev_vstate(value)) next_vstate = property(fget=lambda self: self._get_next_vstate(), fset=lambda self, value: self._set_next_vstate(value)) - + @abstractmethod def get_receivers_output(self): pass @@ -362,12 +360,12 @@ def get_function(self): @abstractmethod def get_function_name(self): - '''Returns the string representing the function of the wave object + '''Returns the string representing the function of the wave object (e.g., "pressure" or "displacement")''' pass def update_source_expression(self, t): - '''Update the source expression during wave propagation. This method must be + '''Update the source expression during wave propagation. This method must be implemented only by subclasses that make use of the source term''' pass @@ -401,10 +399,10 @@ def wave_propagator(self, dt=None, final_time=None, source_num=0): usol, usol_recv = time_integrator(self, source_num) return usol, usol_recv - + def get_dt(self): return self._dt - + def set_dt(self, dt): self._dt = dt if self.sources is not None: diff --git a/spyro/tools/cells_per_wavelength_calculator.py b/spyro/tools/cells_per_wavelength_calculator.py index af6e7895..2b41d526 100644 --- a/spyro/tools/cells_per_wavelength_calculator.py +++ b/spyro/tools/cells_per_wavelength_calculator.py @@ -82,6 +82,7 @@ class Meshing_parameter_calculator: build_current_object(cpw, degree=None): Builds the current acoustic wave solver object. """ + def __init__(self, parameters_dictionary): """ Initializes the Meshing_parameter_calculator class with a dictionary of parameters. @@ -347,7 +348,7 @@ def find_minimum(self, starting_cpw=None, TOL=None, accuracy=None, savetxt=False # Running forward model Wave_obj = self.build_current_object(cpw) - Wave_obj._initialize_model_parameters() # TO REVIEW: call to protected method + Wave_obj._initialize_model_parameters() # TO REVIEW: call to protected method # Setting up time-step if self.timestep_calculation != "float": diff --git a/spyro/tools/input_models.py b/spyro/tools/input_models.py index 12b85ce3..cf4feb7a 100644 --- a/spyro/tools/input_models.py +++ b/spyro/tools/input_models.py @@ -51,7 +51,7 @@ def build_on_top_of_base_dictionary(variables): "dimension": variables["dimension"], "automatic_adjoint": False, } - model_dictionary["parallelism"] = {"type": "automatic",} + model_dictionary["parallelism"] = {"type": "automatic", } model_dictionary["mesh"] = { "Lz": variables["Lz"], "Lx": variables["Lx"], @@ -173,7 +173,7 @@ def create_initial_model_for_meshing_parameter_2D_heterogeneous(Meshing_calc_obj ---------- Meshing_calc_obj : spyro.Meshing_parameter_calculator The meshing calculation object. - + Returns ------- model_dictionary : dict @@ -186,7 +186,7 @@ def create_initial_model_for_meshing_parameter_2D_heterogeneous(Meshing_calc_obj method = Meshing_calc_obj.FEM_method_to_evaluate degree = Meshing_calc_obj.desired_degree - reduced = Meshing_calc_obj.reduced_obj_for_testing + reduced = Meshing_calc_obj.reduced_obj_for_testing # noqa: F841 # Domain calculations lbda = c_value / frequency @@ -218,7 +218,7 @@ def create_initial_model_for_meshing_parameter_2D_heterogeneous(Meshing_calc_obj ) # Time axis calculations - tmin = 1.0 / frequency + tmin = 1.0 / frequency # noqa: F841 final_time = 7.5 variables = { diff --git a/spyro/utils/typing.py b/spyro/utils/typing.py index b393ccb6..6d91d4af 100644 --- a/spyro/utils/typing.py +++ b/spyro/utils/typing.py @@ -3,4 +3,4 @@ def override(func): This decorator should be replaced by typing.override when Python version is updated to 3.12 ''' - return func \ No newline at end of file + return func diff --git a/test/test_MMS.py b/test/test_MMS.py index a374d923..3f0b8357 100644 --- a/test/test_MMS.py +++ b/test/test_MMS.py @@ -1,6 +1,5 @@ import math from copy import deepcopy -import pytest from firedrake import * import spyro @@ -50,7 +49,7 @@ def test_isotropic_wave_2D(): b1 = lambda x, t: -(2*x[0]**2 + 6*x[1]**2 - 16*x[0]*x[1] + 10*x[0] - 14*x[1] + 4)*t b2 = lambda x, t: -(-12*x[0]**2 - 4*x[1]**2 + 8*x[0]*x[1] - 16*x[0] + 8*x[1] - 2)*t b = lambda x, t: as_vector([b1(x, t), b2(x, t)]) - + dt = 1e-3 fo = int(0.1/dt) @@ -86,8 +85,9 @@ def test_isotropic_wave_2D(): assert math.isclose(e1, 0.0, abs_tol=1e-7) assert math.isclose(e2, 0.0, abs_tol=1e-7) + if __name__ == "__main__": test_method_triangles_lumped() test_method_quads_lumped() - print("END") \ No newline at end of file + print("END") diff --git a/test/test_field_logger.py b/test/test_field_logger.py index 0ad32bfd..6037ab29 100644 --- a/test/test_field_logger.py +++ b/test/test_field_logger.py @@ -8,6 +8,7 @@ comm = fire.Ensemble(MPI.COMM_WORLD, 1) + @pytest.fixture def logger(): mesh = fire.UnitIntervalMesh(2) @@ -25,6 +26,7 @@ def logger(): logger.add_field("c", "3rd", lambda: u) return logger + def test_writing(logger): logger.start_logging(0) logger.log(0) @@ -33,15 +35,17 @@ def test_writing(logger): assert os.path.isfile("bsn0.pvd") assert not os.path.isfile("csn0.pvd") + def test_warning(logger): logger.start_logging(0) with pytest.warns(UserWarning): logger.start_logging(1) + def test_no_warning(logger): logger.start_logging(0) logger.stop_logging() # Assert that no warnings are emitted with warnings.catch_warnings(): warnings.simplefilter("error") - logger.start_logging(1) \ No newline at end of file + logger.start_logging(1) diff --git a/test/test_forward_examples.py b/test/test_forward_examples.py index 9a50d6bb..77e89062 100644 --- a/test/test_forward_examples.py +++ b/test/test_forward_examples.py @@ -49,14 +49,17 @@ def test_rectangle_forward(): assert all([test1, test2, test3]) + def test_camembert_elastic(): from spyro.examples.camembert_elastic import wave wave.forward_solve() + def test_elastic_cube_3D(): from spyro.examples.elastic_cube_3D import wave wave.forward_solve() + if __name__ == "__main__": test_camembert_forward() test_rectangle_forward() diff --git a/test/test_gradient_2d.py b/test/test_gradient_2d.py index 3b900894..003aa1cc 100644 --- a/test/test_gradient_2d.py +++ b/test/test_gradient_2d.py @@ -52,7 +52,7 @@ def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False): # Checking if every error is less than 1 percent - test1 =abs(errors[-1]) < 1 + test1 = abs(errors[-1]) < 1 print(f"Last gradient error less than 1 percent: {test1}") # Checking if error follows expected finite difference error convergence diff --git a/test/test_gradient_2d_pml.py b/test/test_gradient_2d_pml.py index 8af7e981..21416195 100644 --- a/test/test_gradient_2d_pml.py +++ b/test/test_gradient_2d_pml.py @@ -12,7 +12,7 @@ def __init__(self, Wave_obj=None): pass # Gatting necessary data from wave object - pad = Wave_obj.abc_pad_length + pad = Wave_obj.abc_pad_length # noqa: F841 z = Wave_obj.mesh_z x = Wave_obj.mesh_x V = Wave_obj.function_space @@ -78,7 +78,7 @@ def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False): # Checking if every error is less than 5 percent - test1 = (abs(errors[-1]) < 5 ) + test1 = (abs(errors[-1]) < 5) print(f"Gradient error less than 5 percent: {test1}") print(f"Error of {errors}") @@ -165,7 +165,7 @@ def get_forward_model(dictionary=None): conditional=cond, dg_velocity_model=False, ) - spyro.plots.plot_model(Wave_obj_exact, filename="pml_grad_test_model.png",abc_points=[(-0, 0), (-1, 0), (-1, 1), (-0, 1)]) + spyro.plots.plot_model(Wave_obj_exact, filename="pml_grad_test_model.png", abc_points=[(-0, 0), (-1, 0), (-1, 1), (-0, 1)]) Wave_obj_exact.forward_solve() rec_out_exact = Wave_obj_exact.receivers_output diff --git a/test/test_gradient_ad.py b/test/test_gradient_ad.py index 55b8b286..89698755 100644 --- a/test/test_gradient_ad.py +++ b/test/test_gradient_ad.py @@ -112,7 +112,7 @@ def test_taylor(): element = fire.FiniteElement( model["opts"]["method"], mesh.ufl_cell(), degree=model["opts"]["degree"], variant=model["opts"]["quadrature"] - ) + ) V = fire.FunctionSpace(mesh, element) fwd_solver = spyro.solvers.forward_ad.ForwardSolver(model, mesh, V) diff --git a/test/test_isotropic_wave.py b/test/test_isotropic_wave.py index 5899c049..fb48997a 100644 --- a/test/test_isotropic_wave.py +++ b/test/test_isotropic_wave.py @@ -26,14 +26,16 @@ }, } + def test_initialize_model_parameters_from_object_missing_parameters(): synthetic_dict = { "type": "object", } wave = IsotropicWave(dummy_dict) - with pytest.raises(Exception) as e: + with pytest.raises(Exception) as e: # noqa: F841 wave.initialize_model_parameters_from_object(synthetic_dict) + def test_initialize_model_parameters_from_object_first_option(): synthetic_dict = { "type": "object", @@ -44,6 +46,7 @@ def test_initialize_model_parameters_from_object_first_option(): wave = IsotropicWave(dummy_dict) wave.initialize_model_parameters_from_object(synthetic_dict) + def test_initialize_model_parameters_from_object_second_option(): synthetic_dict = { "type": "object", @@ -54,6 +57,7 @@ def test_initialize_model_parameters_from_object_second_option(): wave = IsotropicWave(dummy_dict) wave.initialize_model_parameters_from_object(synthetic_dict) + def test_initialize_model_parameters_from_object_redundant(): synthetic_dict = { "type": "object", @@ -64,9 +68,10 @@ def test_initialize_model_parameters_from_object_redundant(): "s_wave_velocity": 3, } wave = IsotropicWave(dummy_dict) - with pytest.raises(Exception) as e: + with pytest.raises(Exception) as e: # noqa: F841 wave.initialize_model_parameters_from_object(synthetic_dict) + def test_parse_boundary_conditions(): d = dummy_dict.copy() d["mesh"] = { @@ -77,7 +82,7 @@ def test_parse_boundary_conditions(): "mesh_type": "firedrake_mesh", } d["boundary_conditions"] = [ - ("u", 1, fire.Constant((1, 1, 1))), # x == 0: 1 (z in spyro) + ("u", 1, fire.Constant((1, 1, 1))), # x == 0: 1 (z in spyro) ("uz", 2, fire.Constant(2)), # x == Lx: 2 (z in spyro) ("ux", 3, fire.Constant(3)), # y == 0: 3 (x in spyro) ("uy", 4, fire.Constant(4)), # y == Ly: 4 (x in spyro) @@ -88,11 +93,12 @@ def test_parse_boundary_conditions(): u = fire.Function(wave.function_space) for bc in wave.bcs: bc.apply(u) - - assert np.allclose([1, 1, 1], u.at( 0.0, 0.5, 0.5)) - assert np.allclose([2, 0, 0], u.at(-1.0, 0.5, 0.5)) - assert np.allclose([0, 3, 0], u.at(-0.5, 0.0, 0.5)) - assert np.allclose([0, 0, 4], u.at(-0.5, 1.0, 0.5)) + + assert np.allclose([1, 1, 1], u.at(0.0, 0.5, 0.5)) + assert np.allclose([2, 0, 0], u.at(-1.0, 0.5, 0.5)) + assert np.allclose([0, 3, 0], u.at(-0.5, 0.0, 0.5)) + assert np.allclose([0, 0, 4], u.at(-0.5, 1.0, 0.5)) + def test_parse_boundary_conditions_exception(): d = dummy_dict.copy() @@ -108,13 +114,14 @@ def test_parse_boundary_conditions_exception(): ] wave = IsotropicWave(d) wave.set_mesh(mesh_parameters={"dx": 0.2, "periodic": True}) - with pytest.raises(Exception) as e: + with pytest.raises(Exception) as e: # noqa: F841 wave.parse_boundary_conditions() + def test_initialize_model_parameters_from_file_notimplemented(): synthetic_dict = { "type": "file", } wave = IsotropicWave(dummy_dict) - with pytest.raises(NotImplementedError) as e: - wave.initialize_model_parameters_from_file(synthetic_dict) \ No newline at end of file + with pytest.raises(NotImplementedError) as e: # noqa: F841 + wave.initialize_model_parameters_from_file(synthetic_dict) diff --git a/test/test_mask.py b/test/test_mask.py index 2b3c2eb1..e7ce519a 100644 --- a/test/test_mask.py +++ b/test/test_mask.py @@ -1,4 +1,3 @@ -import spyro from spyro.utils import Mask from spyro.utils import Gradient_mask_for_pml from spyro.examples.rectangle import Rectangle_acoustic @@ -36,10 +35,10 @@ def test_mask(): } Wave_obj = Rectangle_acoustic(dictionary=dictionary) boundaries = { - "z_min":-0.9, - "z_max":-0.1, - "x_min":0.2, - "x_max":0.8, + "z_min": -0.9, + "z_max": -0.1, + "x_min": 0.2, + "x_max": 0.8, } # Points we are going to check diff --git a/test/test_model_parameters.py b/test/test_model_parameters.py index 28715733..a2e44d65 100644 --- a/test/test_model_parameters.py +++ b/test/test_model_parameters.py @@ -349,47 +349,47 @@ def test_dictionary_conversion(): assert same -def test_degree_exception_2d(): +def test_degree_exception_2d(): # TODO: improve ex_dictionary = deepcopy(dictionary) with pytest.raises(Exception): ex_dictionary["options"]["dimension"] = 2 ex_dictionary["options"]["degree"] = 6 - model = Model_parameters(dictionary=ex_dictionary) + model = Model_parameters(dictionary=ex_dictionary) # noqa: F841 -def test_degree_exception_3d(): +def test_degree_exception_3d(): # TODO: improve ex_dictionary = deepcopy(dictionary) with pytest.raises(Exception): ex_dictionary["options"]["dimension"] = 3 ex_dictionary["options"]["degree"] = 5 - model = Model_parameters(dictionary=ex_dictionary) + model = Model_parameters(dictionary=ex_dictionary) # noqa: F841 -def test_time_exception(): +def test_time_exception(): # TODO: improve ex_dictionary = deepcopy(dictionary) with pytest.raises(Exception): ex_dictionary["time_axis"]["final_time"] = -0.5 - model = Model_parameters(dictionary=ex_dictionary) + model = Model_parameters(dictionary=ex_dictionary) # noqa: F841 -def test_source_exception(): +def test_source_exception(): # TODO: improve ex_dictionary = deepcopy(dictionary) with pytest.raises(Exception): ex_dictionary["acquistion"]["source_locations"] = [ (-0.1, 0.5), (1.0, 0.5), ] - model = Model_parameters(dictionary=ex_dictionary) + model = Model_parameters(dictionary=ex_dictionary) # noqa: F841 -def test_receiver_exception(): +def test_receiver_exception(): # TODO: improve ex_dictionary = deepcopy(dictionary) with pytest.raises(Exception): ex_dictionary["acquistion"]["receiver_locations"] = [ (-0.1, 0.5), (1.0, 0.5), ] - model = Model_parameters(dictionary=ex_dictionary) + model = Model_parameters(dictionary=ex_dictionary) # noqa: F841 if __name__ == "__main__": diff --git a/test/test_pml_2d.py b/test/test_pml_2d.py index 7ed511ef..b23632ab 100644 --- a/test/test_pml_2d.py +++ b/test/test_pml_2d.py @@ -1,5 +1,4 @@ import spyro -import matplotlib.pyplot as plt import numpy as np import time as timer import firedrake as fire diff --git a/test/test_sources.py b/test/test_sources.py index 236f201f..d48da68a 100644 --- a/test/test_sources.py +++ b/test/test_sources.py @@ -13,7 +13,7 @@ def test_ricker_varies_in_time(): and if the applied ricker function behaves correctly """ - ### initial ricker tests + # initial ricker tests modelRicker = deepcopy(oldmodel) frequency = 2 amplitude = 3