Skip to content

Commit

Permalink
Enable vtk write cofunction (#3892)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci authored Nov 29, 2024
1 parent 86c4c44 commit 081789d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
17 changes: 12 additions & 5 deletions firedrake/output/vtk_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def __init__(self, filename, project_output=False, comm=None, mode="w",
@no_annotations
def _prepare_output(self, function, max_elem):
from firedrake import FunctionSpace, VectorFunctionSpace, \
TensorFunctionSpace, Function
TensorFunctionSpace, Function, Cofunction

name = function.name()
# Need to project/interpolate?
Expand All @@ -477,16 +477,23 @@ def _prepare_output(self, function, max_elem):
shape=shape)
else:
raise ValueError("Unsupported shape %s" % (shape, ))
output = Function(V)
if isinstance(function, Function):
output = Function(V)
else:
assert isinstance(function, Cofunction)
output = Function(V.dual())

if self.project:
if isinstance(function, Cofunction):
raise ValueError("Can not project Cofunctions")
output.project(function)
else:
output.interpolate(function)

return OFunction(array=get_array(output), name=name, function=output)

def _write_vtu(self, *functions):
from firedrake.function import Function
from firedrake import Function, Cofunction

# Check if the user has requested to write out a plain mesh
if len(functions) == 1 and isinstance(functions[0], ufl.Mesh):
Expand All @@ -496,8 +503,8 @@ def _write_vtu(self, *functions):
functions = [Function(V)]

for f in functions:
if not isinstance(f, Function):
raise ValueError("Can only output Functions or a single mesh, not %r" % type(f))
if not isinstance(f, (Function, Cofunction)):
raise ValueError(f"Can only output Functions, Cofunctions or a single mesh, not {type(f).__name__}")
meshes = tuple(extract_unique_domain(f) for f in functions)
if not all(m == meshes[0] for m in meshes):
raise ValueError("All functions must be on same mesh")
Expand Down
22 changes: 16 additions & 6 deletions tests/firedrake/output/test_pvd_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,16 @@ def test_bad_file_name(tmpdir):
VTKFile(str(tmpdir.join("foo.vtu")))


def test_different_functions(mesh, pvd):
@pytest.mark.parametrize("space",
["primal", "dual"])
def test_different_functions(mesh, pvd, space):
V = FunctionSpace(mesh, "DG", 0)

f = Function(V, name="foo")
g = Function(V, name="bar")
if space == "primal":
f = Function(V, name="foo")
g = Function(V, name="bar")
else:
f = Cofunction(V.dual(), name="foo")
g = Cofunction(V.dual(), name="bar")

pvd.write(f)

Expand Down Expand Up @@ -136,9 +141,14 @@ def test_not_function(mesh, pvd):
pvd.write(grad(f))


def test_append(mesh, tmpdir):
@pytest.mark.parametrize("space",
["primal", "dual"])
def test_append(mesh, tmpdir, space):
V = FunctionSpace(mesh, "DG", 0)
g = Function(V)
if space == "primal":
g = Function(V)
else:
g = Cofunction(V.dual())

outfile = VTKFile(str(tmpdir.join("restart_test.pvd")))
outfile.write(g)
Expand Down

0 comments on commit 081789d

Please sign in to comment.