Skip to content

Commit

Permalink
Add tests to verify performance of SmartFields (#1217)
Browse files Browse the repository at this point in the history
* Add tests to verify performance of SmartFields

* Switch template function
  • Loading branch information
psakievich authored Sep 28, 2023
1 parent 4347e2d commit 90b42ca
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 43 deletions.
80 changes: 37 additions & 43 deletions include/SmartField.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,10 @@ class SmartField<FieldType, tags::LEGACY, ACCESS>

private:
static constexpr bool is_read_{
std::is_same<ACCESS, READ>::value ||
std::is_same<ACCESS, READ_WRITE>::value};
std::is_same_v<ACCESS, READ> || std::is_same_v<ACCESS, READ_WRITE>};

static constexpr bool is_write_{
std::is_same<ACCESS, WRITE_ALL>::value ||
std::is_same<ACCESS, READ_WRITE>::value};
std::is_same_v<ACCESS, WRITE_ALL> || std::is_same_v<ACCESS, READ_WRITE>};

FieldType& stkField_;

Expand All @@ -99,29 +97,29 @@ class SmartField<FieldType, tags::LEGACY, ACCESS>

// --- Default Accessors
template <typename A = ACCESS>
inline typename std::enable_if_t<!std::is_same<A, READ>::value, T>*
inline typename std::enable_if_t<!std::is_same_v<A, READ>, T>*
get(const stk::mesh::Entity& entity) const
{
return stk::mesh::field_data(stkField_, entity);
}

template <typename A = ACCESS>
inline typename std::enable_if_t<!std::is_same<A, READ>::value, T>*
inline typename std::enable_if_t<!std::is_same_v<A, READ>, T>*
operator()(const stk::mesh::Entity& entity) const
{
return stk::mesh::field_data(stkField_, entity);
}

// --- Const Accessors
template <typename A = ACCESS>
inline const typename std::enable_if_t<std::is_same<A, READ>::value, T>*
inline const typename std::enable_if_t<std::is_same_v<A, READ>, T>*
get(const stk::mesh::Entity& entity) const
{
return stk::mesh::field_data(stkField_, entity);
}

template <typename A = ACCESS>
inline const typename std::enable_if_t<std::is_same<A, READ>::value, T>*
inline const typename std::enable_if_t<std::is_same_v<A, READ>, T>*
operator()(const stk::mesh::Entity& entity) const
{
return stk::mesh::field_data(stkField_, entity);
Expand Down Expand Up @@ -149,12 +147,10 @@ class SmartField<FieldType, tags::DEVICE, ACCESS>
{
private:
static constexpr bool is_read_{
std::is_same<ACCESS, READ>::value ||
std::is_same<ACCESS, READ_WRITE>::value};
std::is_same_v<ACCESS, READ> || std::is_same_v<ACCESS, READ_WRITE>};

static constexpr bool is_write_{
std::is_same<ACCESS, WRITE_ALL>::value ||
std::is_same<ACCESS, READ_WRITE>::value};
std::is_same_v<ACCESS, WRITE_ALL> || std::is_same_v<ACCESS, READ_WRITE>};

FieldType stkField_;
const bool is_copy_constructed_{false};
Expand Down Expand Up @@ -190,81 +186,81 @@ class SmartField<FieldType, tags::DEVICE, ACCESS>
//************************************************************
// Device functions
//************************************************************
KOKKOS_INLINE_FUNCTION
KOKKOS_FORCEINLINE_FUNCTION
unsigned get_ordinal() const { return stkField_.get_ordinal(); }

// --- Default Accessors
template <typename Mesh, typename A = ACCESS>
KOKKOS_INLINE_FUNCTION std::enable_if_t<!std::is_same<A, READ>::value, T>&
KOKKOS_FORCEINLINE_FUNCTION std::enable_if_t<!std::is_same_v<A, READ>, T>&
get(const Mesh& ngpMesh, stk::mesh::Entity entity, int component) const
{
return stkField_.get(ngpMesh, entity, component);
}

template <typename A = ACCESS>
KOKKOS_INLINE_FUNCTION std::enable_if_t<!std::is_same<A, READ>::value, T>&
KOKKOS_FORCEINLINE_FUNCTION std::enable_if_t<!std::is_same_v<A, READ>, T>&
get(stk::mesh::FastMeshIndex& index, int component) const
{
return stkField_.get(index, component);
}

template <typename MeshIndex, typename A = ACCESS>
KOKKOS_INLINE_FUNCTION std::enable_if_t<!std::is_same<A, READ>::value, T>&
KOKKOS_FORCEINLINE_FUNCTION std::enable_if_t<!std::is_same_v<A, READ>, T>&
get(MeshIndex index, int component) const
{
return stkField_.get(index, component);
}

template <typename A = ACCESS>
KOKKOS_INLINE_FUNCTION std::enable_if_t<!std::is_same<A, READ>::value, T>&
KOKKOS_FORCEINLINE_FUNCTION std::enable_if_t<!std::is_same_v<A, READ>, T>&
operator()(const stk::mesh::FastMeshIndex& index, int component) const
{
return stkField_(index, component);
}

template <typename MeshIndex, typename A = ACCESS>
KOKKOS_INLINE_FUNCTION std::enable_if_t<!std::is_same<A, READ>::value, T>&
KOKKOS_FORCEINLINE_FUNCTION std::enable_if_t<!std::is_same_v<A, READ>, T>&
operator()(const MeshIndex index, int component) const
{
return stkField_(index, component);
}

// --- Const Accessors
template <typename Mesh, typename A = ACCESS>
KOKKOS_INLINE_FUNCTION const
std::enable_if_t<std::is_same<A, READ>::value, T>&
KOKKOS_FORCEINLINE_FUNCTION const
std::enable_if_t<std::is_same_v<A, READ>, T>&
get(const Mesh& ngpMesh, stk::mesh::Entity entity, int component) const
{
return stkField_.get(ngpMesh, entity, component);
}

template <typename A = ACCESS>
KOKKOS_INLINE_FUNCTION const
std::enable_if_t<std::is_same<A, READ>::value, T>&
KOKKOS_FORCEINLINE_FUNCTION const
std::enable_if_t<std::is_same_v<A, READ>, T>&
get(stk::mesh::FastMeshIndex& index, int component) const
{
return stkField_.get(index, component);
}

template <typename MeshIndex, typename A = ACCESS>
KOKKOS_INLINE_FUNCTION const
std::enable_if_t<std::is_same<A, READ>::value, T>&
KOKKOS_FORCEINLINE_FUNCTION const
std::enable_if_t<std::is_same_v<A, READ>, T>&
get(MeshIndex index, int component) const
{
return stkField_.get(index, component);
}

template <typename A = ACCESS>
KOKKOS_INLINE_FUNCTION const
std::enable_if_t<std::is_same<A, READ>::value, T>&
KOKKOS_FORCEINLINE_FUNCTION const
std::enable_if_t<std::is_same_v<A, READ>, T>&
operator()(const stk::mesh::FastMeshIndex& index, int component) const
{
return stkField_(index, component);
}

template <typename MeshIndex, typename A = ACCESS>
KOKKOS_INLINE_FUNCTION const
std::enable_if_t<std::is_same<A, READ>::value, T>&
KOKKOS_FORCEINLINE_FUNCTION const
std::enable_if_t<std::is_same_v<A, READ>, T>&
operator()(const MeshIndex index, int component) const
{
return stkField_(index, component);
Expand All @@ -281,12 +277,10 @@ class SmartField<FieldType, tags::HOST, ACCESS>
{
private:
static constexpr bool is_read_{
std::is_same<ACCESS, READ>::value ||
std::is_same<ACCESS, READ_WRITE>::value};
std::is_same_v<ACCESS, READ> || std::is_same_v<ACCESS, READ_WRITE>};

static constexpr bool is_write_{
std::is_same<ACCESS, WRITE_ALL>::value ||
std::is_same<ACCESS, READ_WRITE>::value};
std::is_same_v<ACCESS, WRITE_ALL> || std::is_same_v<ACCESS, READ_WRITE>};

FieldType stkField_;
const bool is_copy_constructed_{false};
Expand Down Expand Up @@ -320,77 +314,77 @@ class SmartField<FieldType, tags::HOST, ACCESS>
}

//************************************************************
// Host functions (Remove KOKKOS_INLINE_FUNCTION decorators)
// Host functions (Remove KOKKOS_FORCEINLINE_FUNCTION decorators)
//************************************************************
inline unsigned get_ordinal() const { return stkField_.get_ordinal(); }

// --- Default Accessors
template <typename Mesh, typename A = ACCESS>
inline std::enable_if_t<!std::is_same<A, READ>::value, T>&
inline std::enable_if_t<!std::is_same_v<A, READ>, T>&
get(const Mesh& ngpMesh, stk::mesh::Entity entity, int component) const
{
return stkField_.get(ngpMesh, entity, component);
}

template <typename A = ACCESS>
inline std::enable_if_t<!std::is_same<A, READ>::value, T>&
inline std::enable_if_t<!std::is_same_v<A, READ>, T>&
get(stk::mesh::FastMeshIndex& index, int component) const
{
return stkField_.get(index, component);
}

template <typename MeshIndex, typename A = ACCESS>
inline std::enable_if_t<!std::is_same<A, READ>::value, T>&
inline std::enable_if_t<!std::is_same_v<A, READ>, T>&
get(MeshIndex index, int component) const
{
return stkField_.get(index, component);
}

template <typename A = ACCESS>
inline std::enable_if_t<!std::is_same<A, READ>::value, T>&
inline std::enable_if_t<!std::is_same_v<A, READ>, T>&
operator()(const stk::mesh::FastMeshIndex& index, int component) const
{
return stkField_(index, component);
}

template <typename MeshIndex, typename A = ACCESS>
inline std::enable_if_t<!std::is_same<A, READ>::value, T>&
inline std::enable_if_t<!std::is_same_v<A, READ>, T>&
operator()(const MeshIndex index, int component) const
{
return stkField_(index, component);
}

// --- Const Accessors
template <typename Mesh, typename A = ACCESS>
inline const std::enable_if_t<std::is_same<A, READ>::value, T>&
inline const std::enable_if_t<std::is_same_v<A, READ>, T>&
get(const Mesh& ngpMesh, stk::mesh::Entity entity, int component) const
{
return stkField_.get(ngpMesh, entity, component);
}

template <typename A = ACCESS>
inline const std::enable_if_t<std::is_same<A, READ>::value, T>&
inline const std::enable_if_t<std::is_same_v<A, READ>, T>&
get(stk::mesh::FastMeshIndex& index, int component) const
{
return stkField_.get(index, component);
}

template <typename MeshIndex, typename A = ACCESS>
inline const std::enable_if_t<std::is_same<A, READ>::value, T>&
inline const std::enable_if_t<std::is_same_v<A, READ>, T>&
get(MeshIndex index, int component) const
{
return stkField_.get(index, component);
}

template <typename A = ACCESS>
inline const std::enable_if_t<std::is_same<A, READ>::value, T>&
inline const std::enable_if_t<std::is_same_v<A, READ>, T>&
operator()(const stk::mesh::FastMeshIndex& index, int component) const
{
return stkField_(index, component);
}

template <typename MeshIndex, typename A = ACCESS>
inline const std::enable_if_t<std::is_same<A, READ>::value, T>&
inline const std::enable_if_t<std::is_same_v<A, READ>, T>&
operator()(const MeshIndex index, int component) const
{
return stkField_(index, component);
Expand Down
41 changes: 41 additions & 0 deletions unit_tests/UnitTestSmartField.C
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,24 @@ lambda_loop_assign(
});
}

template <typename T>
void
lambda_loop_performance(
stk::mesh::BulkData& bulk,
stk::mesh::PartVector partVec,
T& ptr,
double val = 300.0)
{
stk::mesh::NgpMesh& ngpMesh = stk::mesh::get_updated_ngp_mesh(bulk);
stk::mesh::Selector sel = stk::mesh::selectUnion(partVec);
stk::mesh::for_each_entity_run(
ngpMesh, stk::topology::NODE_RANK, sel,
KOKKOS_LAMBDA(const stk::mesh::FastMeshIndex& entity) {
for (int i = 0; i < 1e6; ++i)
ptr(entity, 0) = val;
});
}

//*****************************************************************************
// Tests
//*****************************************************************************
Expand Down Expand Up @@ -112,6 +130,29 @@ TEST_F(TestSmartField, device_read_mod_no_sync_with_lambda)
EXPECT_EQ(initSyncsHost_ + 0, ngpField_->num_syncs_to_host());
}

TEST_F(TestSmartField, device_performance_smart_field)
{
ngpField_->modify_on_host();

ASSERT_TRUE(ngpField_->need_sync_to_device());

auto sPtr = MakeSmartField<DEVICE, READ_WRITE>()(*ngpField_);

double assignmentValue = 300.0;
lambda_loop_performance(*bulk, partVec, sPtr, assignmentValue);
}

TEST_F(TestSmartField, device_performance_ngp_field)
{
ngpField_->modify_on_host();

ASSERT_TRUE(ngpField_->need_sync_to_device());
ngpField_->sync_to_device();

double assignmentValue = 300.0;
lambda_loop_performance(*bulk, partVec, *ngpField_, assignmentValue);
}

TEST_F(TestSmartField, update_field_on_device_check_on_host)
{
ngpField_->modify_on_host();
Expand Down

0 comments on commit 90b42ca

Please sign in to comment.