Skip to content

Commit

Permalink
Add more ravels
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Nov 28, 2024
1 parent 8b067f1 commit 1e12f53
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
14 changes: 7 additions & 7 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1966,7 +1966,7 @@ def _renumber_entities(self, reorder):
if reorder:
swarm = self.topology_dm
parent = self._parent_mesh.topology_dm
swarm_parent_cell_nums = swarm.getField("DMSwarm_cellid")
swarm_parent_cell_nums = swarm.getField("DMSwarm_cellid").ravel()
parent_renum = self._parent_mesh._dm_renumbering.getIndices()
pStart, _ = parent.getChart()
parent_renum_inv = np.empty_like(parent_renum)
Expand Down Expand Up @@ -2063,7 +2063,7 @@ def cell_parent_cell_list(self):
"""Return a list of parent mesh cells numbers in vertex only
mesh cell order.
"""
cell_parent_cell_list = np.copy(self.topology_dm.getField("parentcellnum"))
cell_parent_cell_list = np.copy(self.topology_dm.getField("parentcellnum").ravel())
self.topology_dm.restoreField("parentcellnum")
return cell_parent_cell_list[self.cell_closure[:, -1]]

Expand All @@ -2082,7 +2082,7 @@ def cell_parent_base_cell_list(self):
"""
if not isinstance(self._parent_mesh, ExtrudedMeshTopology):
raise AttributeError("Parent mesh is not extruded")
cell_parent_base_cell_list = np.copy(self.topology_dm.getField("parentcellbasenum"))
cell_parent_base_cell_list = np.copy(self.topology_dm.getField("parentcellbasenum").ravel())
self.topology_dm.restoreField("parentcellbasenum")
return cell_parent_base_cell_list[self.cell_closure[:, -1]]

Expand All @@ -2103,7 +2103,7 @@ def cell_parent_extrusion_height_list(self):
"""
if not isinstance(self._parent_mesh, ExtrudedMeshTopology):
raise AttributeError("Parent mesh is not extruded.")
cell_parent_extrusion_height_list = np.copy(self.topology_dm.getField("parentcellextrusionheight"))
cell_parent_extrusion_height_list = np.copy(self.topology_dm.getField("parentcellextrusionheight").ravel())
self.topology_dm.restoreField("parentcellextrusionheight")
return cell_parent_extrusion_height_list[self.cell_closure[:, -1]]

Expand All @@ -2123,7 +2123,7 @@ def mark_entities(self, tf, label_value, label_name=None):
@utils.cached_property # TODO: Recalculate if mesh moves
def cell_global_index(self):
"""Return a list of unique cell IDs in vertex only mesh cell order."""
cell_global_index = np.copy(self.topology_dm.getField("globalindex"))
cell_global_index = np.copy(self.topology_dm.getField("globalindex").ravel())
self.topology_dm.restoreField("globalindex")
return cell_global_index

Expand Down Expand Up @@ -3895,8 +3895,8 @@ def _dmswarm_create(
swarm.restoreField("DMSwarm_cellid")

if extruded:
field_base_parent_cell_nums = swarm.getField("parentcellbasenum")
field_extrusion_heights = swarm.getField("parentcellextrusionheight")
field_base_parent_cell_nums = swarm.getField("parentcellbasenum").ravel()
field_extrusion_heights = swarm.getField("parentcellextrusionheight").ravel()
field_base_parent_cell_nums[...] = base_parent_cell_nums
field_extrusion_heights[...] = extrusion_heights
swarm.restoreField("parentcellbasenum")
Expand Down
18 changes: 9 additions & 9 deletions tests/firedrake/vertexonly/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
exclude_halos = True

# Get point coords on current MPI rank
localpointcoords = np.copy(swarm.getField("DMSwarmPIC_coor"))
localpointcoords = np.copy(swarm.getField("DMSwarmPIC_coor").ravel())
swarm.restoreField("DMSwarmPIC_coor")
if len(inputpointcoords.shape) > 1:
localpointcoords = np.reshape(localpointcoords, (-1, inputpointcoords.shape[1]))
Expand All @@ -218,11 +218,11 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
nptslocal = len(localpointcoords)
nptsglobal = MPI.COMM_WORLD.allreduce(nptslocal, op=MPI.SUM)
# Get parent PETSc cell indices on current MPI rank
localparentcellindices = np.copy(swarm.getField("DMSwarm_cellid"))
localparentcellindices = np.copy(swarm.getField("DMSwarm_cellid").ravel())
swarm.restoreField("DMSwarm_cellid")

# also get the global coordinate numbering
globalindices = np.copy(swarm.getField("globalindex"))
globalindices = np.copy(swarm.getField("globalindex").ravel())
swarm.restoreField("globalindex")

# Tests
Expand All @@ -233,7 +233,7 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):

# get custom fields on swarm - will fail if didn't get created
for name, size, dtype in other_fields:
f = swarm.getField(name)
f = swarm.getField(name).ravel()
assert len(f) == size*nptslocal
assert f.dtype == dtype
swarm.restoreField(name)
Expand Down Expand Up @@ -330,7 +330,7 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
# Check that the rank numbering is correct. Since we know all points are at
# the midpoints of cells, there should be no disagreement about cell
# ownership and the voting algorithm should have no effect.
owned_ranks = np.copy(swarm.getField("DMSwarm_rank"))
owned_ranks = np.copy(swarm.getField("DMSwarm_rank").ravel())
swarm.restoreField("DMSwarm_rank")
if exclude_halos:
assert np.array_equal(owned_ranks, inputlocalpointcoordranks)
Expand All @@ -339,7 +339,7 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
assert np.all(np.isin(inputlocalpointcoordranks, owned_ranks))

# check that the input rank is correct
input_ranks = np.copy(swarm.getField("inputrank"))
input_ranks = np.copy(swarm.getField("inputrank").ravel())
swarm.restoreField("inputrank")
if exclude_halos:
assert np.all(input_ranks == input_rank)
Expand All @@ -351,7 +351,7 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
assert np.all(input_ranks < parentmesh.comm.size)

# check that the input index is correct
input_indices = np.copy(swarm.getField("inputindex"))
input_indices = np.copy(swarm.getField("inputindex").ravel())
swarm.restoreField("inputindex")
if exclude_halos:
assert np.array_equal(input_indices, input_local_coord_indices)
Expand All @@ -365,7 +365,7 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):

# check we have unique parent cell numbers, which we should since we have
# points at cell midpoints
parentcellnums = np.copy(swarm.getField("parentcellnum"))
parentcellnums = np.copy(swarm.getField("parentcellnum").ravel())
swarm.restoreField("parentcellnum")
assert len(np.unique(parentcellnums)) == len(parentcellnums)

Expand All @@ -378,7 +378,7 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
):
swarm.setPointCoordinates(localpointcoords, redundant=False,
mode=PETSc.InsertMode.INSERT_VALUES)
petsclocalparentcellindices = np.copy(swarm.getField("DMSwarm_cellid"))
petsclocalparentcellindices = np.copy(swarm.getField("DMSwarm_cellid").ravel())
swarm.restoreField("DMSwarm_cellid")
if exclude_halos:
assert np.all(petsclocalparentcellindices == localparentcellindices)
Expand Down
4 changes: 2 additions & 2 deletions tests/firedrake/vertexonly/test_vertex_only_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def functionspace_tests(vm):
h.dat.data_wo_with_halos[:] = -1
h.interpolate(g)
# Exclude points which we know are missing - these should all be equal to -1
input_ordering_parent_cell_nums = vm.input_ordering.topology_dm.getField("parentcellnum")
input_ordering_parent_cell_nums = vm.input_ordering.topology_dm.getField("parentcellnum").ravel()
vm.input_ordering.topology_dm.restoreField("parentcellnum")
idxs_to_include = input_ordering_parent_cell_nums != -1
assert np.allclose(h.dat.data_ro_with_halos[idxs_to_include], np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension()), axis=1))
Expand Down Expand Up @@ -221,7 +221,7 @@ def vectorfunctionspace_tests(vm):
h.dat.data_wo_with_halos[:] = -1
h.interpolate(g)
# Exclude points which we know are missing - these should all be equal to -1
input_ordering_parent_cell_nums = vm.input_ordering.topology_dm.getField("parentcellnum")
input_ordering_parent_cell_nums = vm.input_ordering.topology_dm.getField("parentcellnum").ravel()
vm.input_ordering.topology_dm.restoreField("parentcellnum")
idxs_to_include = input_ordering_parent_cell_nums != -1
assert np.allclose(h.dat.data_ro[idxs_to_include], 2*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def verify_vertexonly_mesh(m, vm, inputvertexcoords, name):
# Correct parent cell numbers
stored_vertex_coords = np.copy(vm.topology_dm.getField("DMSwarmPIC_coor")).reshape((vm.num_cells(), gdim))
vm.topology_dm.restoreField("DMSwarmPIC_coor")
stored_parent_cell_nums = np.copy(vm.topology_dm.getField("parentcellnum"))
stored_parent_cell_nums = np.copy(vm.topology_dm.getField("parentcellnum").ravel())
vm.topology_dm.restoreField("parentcellnum")
assert len(stored_vertex_coords) == len(stored_parent_cell_nums)
if MPI.COMM_WORLD.size == 1:
Expand Down

0 comments on commit 1e12f53

Please sign in to comment.