From 081789dccc75ca780b0df952b59329bcde6319d6 Mon Sep 17 00:00:00 2001 From: Daiane Iglesia Dolci <63597005+Ig-dolci@users.noreply.github.com> Date: Fri, 29 Nov 2024 10:07:28 +0000 Subject: [PATCH] Enable vtk write cofunction (#3892) --- firedrake/output/vtk_output.py | 17 ++++++++++++----- tests/firedrake/output/test_pvd_output.py | 22 ++++++++++++++++------ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/firedrake/output/vtk_output.py b/firedrake/output/vtk_output.py index 2dc7ab0235..23072c5019 100644 --- a/firedrake/output/vtk_output.py +++ b/firedrake/output/vtk_output.py @@ -452,7 +452,7 @@ def __init__(self, filename, project_output=False, comm=None, mode="w", @no_annotations def _prepare_output(self, function, max_elem): from firedrake import FunctionSpace, VectorFunctionSpace, \ - TensorFunctionSpace, Function + TensorFunctionSpace, Function, Cofunction name = function.name() # Need to project/interpolate? @@ -477,8 +477,15 @@ def _prepare_output(self, function, max_elem): shape=shape) else: raise ValueError("Unsupported shape %s" % (shape, )) - output = Function(V) + if isinstance(function, Function): + output = Function(V) + else: + assert isinstance(function, Cofunction) + output = Function(V.dual()) + if self.project: + if isinstance(function, Cofunction): + raise ValueError("Can not project Cofunctions") output.project(function) else: output.interpolate(function) @@ -486,7 +493,7 @@ def _prepare_output(self, function, max_elem): return OFunction(array=get_array(output), name=name, function=output) def _write_vtu(self, *functions): - from firedrake.function import Function + from firedrake import Function, Cofunction # Check if the user has requested to write out a plain mesh if len(functions) == 1 and isinstance(functions[0], ufl.Mesh): @@ -496,8 +503,8 @@ def _write_vtu(self, *functions): functions = [Function(V)] for f in functions: - if not isinstance(f, Function): - raise ValueError("Can only output Functions or a single mesh, not %r" % type(f)) + if not isinstance(f, (Function, Cofunction)): + raise ValueError(f"Can only output Functions, Cofunctions or a single mesh, not {type(f).__name__}") meshes = tuple(extract_unique_domain(f) for f in functions) if not all(m == meshes[0] for m in meshes): raise ValueError("All functions must be on same mesh") diff --git a/tests/firedrake/output/test_pvd_output.py b/tests/firedrake/output/test_pvd_output.py index 6f17bcbe48..69c7243b7f 100644 --- a/tests/firedrake/output/test_pvd_output.py +++ b/tests/firedrake/output/test_pvd_output.py @@ -79,11 +79,16 @@ def test_bad_file_name(tmpdir): VTKFile(str(tmpdir.join("foo.vtu"))) -def test_different_functions(mesh, pvd): +@pytest.mark.parametrize("space", + ["primal", "dual"]) +def test_different_functions(mesh, pvd, space): V = FunctionSpace(mesh, "DG", 0) - - f = Function(V, name="foo") - g = Function(V, name="bar") + if space == "primal": + f = Function(V, name="foo") + g = Function(V, name="bar") + else: + f = Cofunction(V.dual(), name="foo") + g = Cofunction(V.dual(), name="bar") pvd.write(f) @@ -136,9 +141,14 @@ def test_not_function(mesh, pvd): pvd.write(grad(f)) -def test_append(mesh, tmpdir): +@pytest.mark.parametrize("space", + ["primal", "dual"]) +def test_append(mesh, tmpdir, space): V = FunctionSpace(mesh, "DG", 0) - g = Function(V) + if space == "primal": + g = Function(V) + else: + g = Cofunction(V.dual()) outfile = VTKFile(str(tmpdir.join("restart_test.pvd"))) outfile.write(g)