Skip to content

Commit

Permalink
Foward solver working with the source built with vom.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Dec 20, 2023
1 parent ca69e5a commit ef19dd0
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 89 deletions.
14 changes: 3 additions & 11 deletions demos/with_automatic_differentiation/forward_circle_vp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from firedrake import *
from scipy.optimize import *
import spyro
import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -17,7 +16,7 @@
}

model["parallelism"] = {
"type": "automatic", # options: automatic (same number of cores for evey processor) or spatial
"type": "spatial", # options: automatic (same number of cores for evey processor) or spatial
}

# Define the domain size without the ABL.
Expand All @@ -42,7 +41,7 @@

model["acquisition"] = {
"source_type": "Ricker",
"source_pos": [(0.7, 0.7)],
"source_pos": spyro.create_transect((0.2, 0.2), (0.2, 0.8), 1),
"frequency": 10.0,
"delay": 1.0,
"receiver_locations": spyro.create_transect((0.9, 0.2), (0.9, 0.8), 10),
Expand Down Expand Up @@ -71,17 +70,10 @@
z, x = SpatialCoordinate(mesh)

vp_exact = Function(V).interpolate(1.0 + 0.0 * x)
source_position = model["acquisition"]["source_pos"]
wavelet = spyro.full_ricker_wavelet(
dt=model["timeaxis"]["dt"],
tf=model["timeaxis"]["tf"],
freq=model["acquisition"]["frequency"],
)

source = spyro.Sources(model, mesh, V, comm)
f = source.apply_source_based_in_vom(max(wavelet))
outfile = File("output.pvd")
outfile.write(f)


# spyro.solvers.forward_AD(model, model, mesh, comm, vp_exact, excitations, wavelet, receivers, source_num=0)
spyro.solvers.forward_AD(model, mesh, comm, vp_exact, wavelet, debug=True)
141 changes: 65 additions & 76 deletions spyro/solvers/forward_AD.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,49 @@
from firedrake import *

# from .. import utils
from ..domains import quadrature, space

# from ..pml import damping
# from ..io import ensemble_forward
from . import helpers
from ..sources.Sources import Sources

# Note this turns off non-fatal warnings
set_log_level(ERROR)


# @ensemble_forward
def forward(
model,
mesh,
comm,
c,
excitations,
wavelet,
receivers,
source_num=0,
output=False,
**kwargs
):
"""Secord-order in time fully-explicit scheme
with implementation of a Perfectly Matched Layer (PML) using
CG FEM with or without higher order mass lumping (KMV type elements).
def forward(model, mesh, comm, c, wavelet, source_num=0, fwi=False, **kwargs):
"""Secord-order in time fully-explicit scheme.
Parameters
----------
model: Python `dictionary`
Contains model options and parameters
mesh: Firedrake.mesh object
model: dict
Contains model options and parameters.
mesh: firedrake.mesh
The 2D/3D triangular mesh
comm: Firedrake.ensemble_communicator
comm: firedrake.ensemble_communicator
The MPI communicator for parallelism
c: Firedrake.Function
c: firedrake.Function
The velocity model interpolated onto the mesh.
excitations: A list Firedrake.Functions
wavelet: array-like
Time series data that's injected at the source location.
receivers: A :class:`spyro.Receivers` object.
Contains the receiver locations and sparse interpolation methods.
source_num: `int`, optional
The source number you wish to simulate
output: `boolean`, optional
Whether or not to write results to pvd files.
fwi: `bool`, optional
Whether this forward simulation is for FWI or not.
Returns
-------
usol: list of Firedrake.Functions
The full field solution at `fspool` timesteps
usol_recv: array-like
The solution interpolated to the receivers at all timesteps
usol_recv: list
The receiver data.
J: float
The functional for FWI. Only returned if `fwi=True`.
"""

method = model["opts"]["method"]
degree = model["opts"]["degree"]
dim = model["opts"]["dimension"]
dt = model["timeaxis"]["dt"]
tf = model["timeaxis"]["tf"]
nspool = model["timeaxis"]["nspool"]
receiver_points = model["acquisition"]["receiver_locations"]
nt = int(tf / dt) # number of timesteps
excitations.current_source = source_num
params = set_params(method)
params = set_params(method, mesh)
element = space.FE_method(mesh, method, degree)

V = FunctionSpace(mesh, element)
Expand All @@ -84,10 +62,6 @@ def forward(
u_n = Function(V)
u_np1 = Function(V)

if output:
outfile = helpers.create_output_file("forward.pvd", comm, source_num)

t = 0.0
m = 1 / (c * c)
m1 = m * ((u - 2.0 * u_n + u_nm1) / Constant(dt**2)) * v * dx(scheme=qr_x)
a = dot(grad(u_n), grad(v)) * dx(scheme=qr_x) # explicit
Expand All @@ -97,8 +71,7 @@ def forward(
if model["BCs"]["outer_bc"] == "non-reflective":
nf = c * ((u_n - u_nm1) / dt) * v * ds(scheme=qr_s)

h = CellSize(mesh)
FF = m1 + a + nf - (1 / (h / degree * h / degree)) * f * v * dx(scheme=qr_x)
FF = m1 + a + nf - f * v * dx(scheme=qr_x)
X = Function(V)

lhs_ = lhs(FF)
Expand All @@ -107,58 +80,74 @@ def forward(
problem = LinearVariationalProblem(lhs_, rhs_, X)
solver = LinearVariationalSolver(problem, solver_parameters=params)

# This part of the code is the base for receivers and sources.
# definition
usol_recv = []

P = FunctionSpace(receivers, "DG", 0)
interpolator = Interpolator(u_np1, P)
J0 = 0.0

# Source object.
source = Sources(model, mesh, V, comm)
# Receiver mesh.
vom = VertexOnlyMesh(mesh, receiver_points)
# P0DG is the only function space you can make on a vertex-only mesh.
P0DG = FunctionSpace(vom, "DG", 0)
interpolator = Interpolator(u_np1, P0DG)
if fwi:
# Get the true receiver data.
# In FWI, we need to calculate the objective function,
# which requires the true receiver data.
true_receivers = kwargs.get("true_receiver")
# cost function
J = 0.0
for step in range(nt):

excitations.apply_source(f, wavelet[step])

f.assign(source.apply_source_based_in_vom(wavelet[step], source_num))
solver.solve()
u_np1.assign(X)

rec = Function(P)
interpolator.interpolate(output=rec)

fwi = kwargs.get("fwi")
p_true_rec = kwargs.get("true_rec")

usol_recv.append(rec.dat.data)

# receiver function
receivers = Function(P0DG)
interpolator.interpolate(output=receivers)
usol_recv.append(receivers)
if fwi:
J0 += calc_objective_func(rec, p_true_rec[step], step, dt, P)

J += compute_functional(receivers, true_receivers[step])
if step % nspool == 0:
assert (
norm(u_n) < 1
), "Numerical instability. Try reducing dt or building the mesh differently"
if output:
outfile.write(u_n, time=t, name="Pressure")
if t > 0:
helpers.display_progress(comm, t)

if float(step*dt) > 0:
helpers.display_progress(comm, float(step*dt))
u_nm1.assign(u_n)
u_n.assign(u_np1)

t = step * float(dt)
debug = kwargs.get("debug")
if debug:
# Save the solution for debugging.
outfile = File("output.pvd")
outfile.write(u_n)

if fwi:
return usol_recv, J0
return usol_recv, J
else:
return usol_recv


def calc_objective_func(p_rec, p_true_rec, IT, dt, P):
true_rec = Function(P)
true_rec.dat.data[:] = p_true_rec
J = 0.5 * assemble(inner(true_rec - p_rec, true_rec - p_rec) * dx)
def compute_functional(guess_receivers, true_receivers):
"""Compute the functional for FWI.
Parameters
----------
guess_receivers : firedrake.Function
The receivers from the forward simulation.
true_receivers : firedrake.Function
Supposed to be the receivers data from the true model.
Returns
-------
J : float
The functional.
"""
misfit = guess_receivers - true_receivers
J = 0.5 * assemble(inner(misfit, misfit) * dx)
return J


def set_params(method):
def set_params(method, mesh):
if method == "KMV":
params = {"ksp_type": "preonly", "pc_type": "jacobi"}
elif (
Expand Down
4 changes: 2 additions & 2 deletions spyro/sources/Sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def apply_source(self, rhs_forcing, value):

return rhs_forcing

def apply_source_based_in_vom(self, wavelet):
def apply_source_based_in_vom(self, wavelet, source_number):
"""Applie source using VertexOnlyMesh (VOM).
Parameters
Expand All @@ -124,7 +124,7 @@ def apply_source_based_in_vom(self, wavelet):
The forcing function that models the wavelet source in the wave
equation.
"""
vom = VertexOnlyMesh(self.mesh, self.receiver_locations,
vom = VertexOnlyMesh(self.mesh, [self.receiver_locations[source_number]],
redundant=False)
f_vom = FunctionSpace(vom, "DG", 0)
f_vom_input_ordering = FunctionSpace(vom.input_ordering, "DG", 0)
Expand Down

0 comments on commit ef19dd0

Please sign in to comment.