Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/issue_0107_compatibility' into i…
Browse files Browse the repository at this point in the history
…ssue_0112-add-habc-from-ruben-salas
  • Loading branch information
Olender committed Jun 28, 2024
2 parents c92cbc4 + 790ee6a commit 9a8998b
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 295 deletions.
2 changes: 1 addition & 1 deletion spyro/examples/rectangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
rectangle_dictionary = {}
rectangle_dictionary["options"] = {
# simplexes such as triangles or tetrahedra (T) or quadrilaterals (Q)
"cell_type": "Q",
"cell_type": "T",
"variant": "lumped", # lumped, equispaced or DG, default is lumped
"degree": 4, # p order
"dimension": 2, # dimension
Expand Down
7 changes: 6 additions & 1 deletion spyro/solvers/acoustic_solver_construction_no_pml.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def construct_solver_or_matrix_no_pml(Wave_object):

u_nm1 = fire.Function(V, name="pressure t-dt")
u_n = fire.Function(V, name="pressure")
u_np1 = fire.Function(V, name="pressure t+dt")
Wave_object.u_nm1 = u_nm1
Wave_object.u_n = u_n
Wave_object.u_np1 = u_np1

Wave_object.current_time = 0.0
dt = Wave_object.dt
Expand All @@ -35,7 +37,7 @@ def construct_solver_or_matrix_no_pml(Wave_object):
)
a = dot(grad(u_n), grad(v)) * dx(scheme=quad_rule) # explicit

B = fire.Function(V)
B = fire.Cofunction(V.dual())

form = m1 + a
lhs = fire.lhs(form)
Expand All @@ -46,6 +48,9 @@ def construct_solver_or_matrix_no_pml(Wave_object):
Wave_object.solver = fire.LinearSolver(
A, solver_parameters=Wave_object.solver_parameters
)
# lin_var = fire.LinearVariationalProblem(lhs, rhs + B, u_np1)
# solver_parameters = {"mat_type": "matfree", "ksp_type": "preonly", "pc_type": "jacobi"}
# Wave_object.solver = fire.LinearVariationalSolver(lin_var,solver_parameters=solver_parameters)

Wave_object.rhs = rhs
Wave_object.B = B
2 changes: 1 addition & 1 deletion spyro/solvers/backward_time_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def backward_wave_propagator_no_pml(Wave_obj, dt=None):
u_n = Wave_obj.u_n
u_np1 = fire.Function(Wave_obj.function_space)

rhs_forcing = fire.Function(Wave_obj.function_space)
rhs_forcing = fire.Cofunction(Wave_obj.function_space.dual())

B = Wave_obj.B
rhs = Wave_obj.rhs
Expand Down
3 changes: 1 addition & 2 deletions spyro/solvers/mms_acoustic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def mms_source_in_space(self):
if self.dimension == 2:
# xy = fire.project(sin(pi*x)*sin(pi*y), V)
# self.q_xy.assign(xy)
xy = fire.project((-(x**2) - x - y**2 + y), V)
self.q_xy.assign(xy)
self.q_xy.interpolate(-(x**2) - x - y**2 + y)
elif self.dimension == 3:
z = self.mesh_y
# xyz = fire.project(sin(pi*x)*sin(pi*y)*sin(pi*z), V)
Expand Down
11 changes: 7 additions & 4 deletions spyro/solvers/time_integration_central_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def central_difference(Wave_object, source_id=0):
u_n = Wave_object.u_n
u_np1 = fire.Function(Wave_object.function_space)

rhs_forcing = fire.Function(Wave_object.function_space)
rhs_forcing = fire.Cofunction(Wave_object.function_space.dual())
usol = [
fire.Function(Wave_object.function_space, name="pressure")
for t in range(nt)
Expand Down Expand Up @@ -328,9 +328,12 @@ def central_difference_MMS(Wave_object, source_id=0):

u_nm1 = Wave_object.u_nm1
u_n = Wave_object.u_n
u_np1 = Wave_object.u_np1
u_nm1.assign(Wave_object.analytical_solution(t - 2 * dt))
u_n.assign(Wave_object.analytical_solution(t - dt))
u_np1 = fire.Function(Wave_object.function_space, name="pressure t +dt")
# u_np1 = fire.Function(Wave_object.function_space, name="pressure t +dt")
# u_nm1.dat.data[:] = np.load("old_u_nm1.npy")
# u_n.dat.data[:] = np.load("old_u_n.npy")
u = fire.TrialFunction(Wave_object.function_space)
v = fire.TestFunction(Wave_object.function_space)

Expand Down Expand Up @@ -364,9 +367,9 @@ def central_difference_MMS(Wave_object, source_id=0):

B = fire.assemble(rhs, tensor=B)

Wave_object.solver.solve(X, B)
Wave_object.solver.solve(u_np1, B)

u_np1.assign(X)
# u_np1.assign(X)

usol_recv.append(
Wave_object.receivers.interpolate(u_np1.dat.data_ro_with_halos[:])
Expand Down
51 changes: 24 additions & 27 deletions test/test_MMS.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,10 @@
import spyro
import time

from .model import dictionary as model

from model import dictionary as model
model["acquisition"]["source_type"] = "MMS"


@pytest.fixture(params=["triangle", "square"])
def mesh_type(request):
if mesh_type == "triangle":
model["cell_type"] = "triangles"
elif mesh_type == "square":
model["cell_type"] = "quadrilaterals"
return request.param


@pytest.fixture(params=["lumped", "equispaced"])
def method_type(request):
if method_type == "lumped":
model["variant"] = "lumped"
elif method_type == "equispaced":
model["variant"] = "equispaced"
return request.param


def run_solve(model):
testmodel = deepcopy(model)

Expand All @@ -42,12 +23,28 @@ def run_solve(model):
return errornorm(u_num, u_an)


def test_method(mesh_type, method_type):
def run_method(mesh_type, method_type):
model["options"]["cell_type"] = mesh_type
model["options"]["variant"] = method_type
print(f"For {mesh_type} and {method_type}")
error = run_solve(model)
print(error)
print(mesh_type)
print(method_type)
print(version.__version__)
time.sleep(10)
test = math.isclose(error, 0.0, abs_tol=1e-7)
print(f"Error is {error}")
print(f"Test: {test}")

assert test


def test_method_triangles_lumped():
run_method("triangles", "lumped")


def test_method_quads_lumped():
run_method("quadrilaterals", "lumped")


if __name__ == "__main__":
test_method_triangles_lumped()
test_method_quads_lumped()

assert math.isclose(error, 0.0, abs_tol=1e-7)
print("END")
7 changes: 4 additions & 3 deletions test/test_cpw_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_cpw_calc():
"velocity_model_file_name": None,
# FEM to evaluate such as `KMV` or `spectral`
# (GLL nodes on quads and hexas)
"FEM_method_to_evaluate": "spectral_quadrilateral",
"FEM_method_to_evaluate": "mass_lumped_triangle",
"dimension": 2, # Domain dimension. Either 2 or 3.
# Either near or line. Near defines a receiver grid near to the source,
"receiver_setup": "near",
Expand All @@ -32,7 +32,7 @@ def test_cpw_calc():
# grid point density to use in the reference case (float)
"C_reference": None,
"desired_degree": 4, # degree we are calculating G for. (int)
"C_initial": 2.4, # Initial G for line search (float)
"C_initial": 2.2, # Initial G for line search (float)
"accepted_error_threshold": 0.05,
"C_accuracy": 0.1,
}
Expand Down Expand Up @@ -60,7 +60,8 @@ def test_cpw_calc():

# Check if cpw is within error TOL, starting search at min
min = Cpw_calc.find_minimum()
test3 = np.isclose(2.5, min)
print(f"Minimum of {min}")
test3 = np.isclose(2.3, min)

print("END")
assert all([test1, test2, test3])
Expand Down
Loading

0 comments on commit 9a8998b

Please sign in to comment.