Skip to content

Commit

Permalink
Vtk exporter fix to save more than one product (#1439)
Browse files Browse the repository at this point in the history
Co-authored-by: Sylwester Arabas <[email protected]>
  • Loading branch information
olastrz and slayoo authored Dec 8, 2024
1 parent 5641771 commit 2a4d2d5
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
2 changes: 1 addition & 1 deletion PySDM/exporters/vtk_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def export_products(self, particulator):

if isinstance(v, np.ndarray):
if v.shape == particulator.mesh.grid:
payload[k] = v[:, :, np.newaxis]
payload[k] = v[:, :, np.newaxis].copy()
else:
if self.verbose:
print(
Expand Down
Empty file.
57 changes: 57 additions & 0 deletions tests/unit_tests/exporters/test_vtk_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
""" checks for VTK exporter """

from collections import namedtuple

import numpy as np

from PySDM.exporters import VTKExporter


def test_vtk_exporter_copies_product_data(tmp_path):
"""note: since VTK files contain unencoded binary data, we cannot use XML parsers;
not to introduce a new dependency to PySDM, we read the binary data with NumPy"""
# arrange
productc_filename = tmp_path / "prod"
sut = VTKExporter(products_filename=productc_filename)

grid = (1, 1)
arr = np.zeros(shape=grid, dtype=float)

incr = 666

def plusplus(arr):
arr += incr
return arr

prod = namedtuple(typename="MockProductA", field_names=("get",))(
get=lambda: plusplus(arr)
)

particulator = namedtuple(
typename="MockParticulator", field_names=("products", "n_steps", "dt", "mesh")
)(
n_steps=1,
products={
"a": prod,
"b": prod,
},
dt=0,
mesh=namedtuple(typename="MockMesh", field_names=("dimension", "grid", "size"))(
dimension=2,
grid=grid,
size=(1, 1),
),
)

# act
sut.export_products(particulator)

# assert
offsets = (113, 129)
with open(str(productc_filename) + "_num0000000001.vts", mode="rb") as vtk:
binary_data = vtk.readlines()[14]
for i, off in enumerate(offsets):
assert (
np.frombuffer(binary_data[off : off + 8], dtype=np.float64)
== (i + 1) * incr
)

0 comments on commit 2a4d2d5

Please sign in to comment.