Skip to content

Commit

Permalink
Add SmartField usage to some legacy algorithms (#1236)
Browse files Browse the repository at this point in the history
* Add SmartField usage to some legacy algorithms

* Format

* Fixes for change in API

* Style
  • Loading branch information
psakievich authored Apr 5, 2024
1 parent 30a8102 commit 9e33aeb
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 43 deletions.
2 changes: 2 additions & 0 deletions include/Algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Realm;
class MasterElement;
class SupplementalAlgorithm;
class Kernel;
class FieldManager;

class Algorithm
{
Expand All @@ -43,6 +44,7 @@ class Algorithm

Realm& realm_;
stk::mesh::PartVector partVec_;
const FieldManager& fieldManager_;
std::vector<SupplementalAlgorithm*> supplementalAlg_;

std::vector<Kernel*> activeKernels_;
Expand Down
8 changes: 4 additions & 4 deletions include/FieldManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class FieldManager
const std::string& name,
const stk::mesh::PartVector& parts,
const void* init_val = nullptr,
stk::mesh::FieldState state = stk::mesh::FieldState::StateNone)
stk::mesh::FieldState state = stk::mesh::FieldState::StateNone) const
{
const int numStates = 0;
const int numComponents = 0;
Expand All @@ -70,7 +70,7 @@ class FieldManager
}

/// Check to see if the field has been registered.
bool field_exists(const std::string& name);
bool field_exists(const std::string& name) const;

unsigned size() const { return meta_.get_fields().size(); }
/// Register a Generic field.
Expand All @@ -83,7 +83,7 @@ class FieldManager
const int numStates,
const int numComponents,
const void* init_val = nullptr,
stk::mesh::FieldState state = stk::mesh::FieldState::StateNone)
stk::mesh::FieldState state = stk::mesh::FieldState::StateNone) const
{
register_field(name, parts, numStates, numComponents, init_val);
return get_field_ptr<GenericFieldType::value_type>(name, state);
Expand Down Expand Up @@ -124,7 +124,7 @@ class FieldManager
const stk::mesh::PartVector& parts,
const int numStates = 0,
const int numComponents = 0,
const void* init_val = nullptr);
const void* init_val = nullptr) const;

/// Given the named field that has already been registered on the CPU
/// return the GPU version of the same field.
Expand Down
14 changes: 14 additions & 0 deletions include/SmartField.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ class SmartField<FieldType, tags::LEGACY, ACCESS>
return stk::mesh::field_data(stkField_, entity);
}

template <typename A = ACCESS>
inline typename std::enable_if_t<!std::is_same_v<A, READ>, T>*
operator()(const stk::mesh::Bucket& bucket) const
{
return stk::mesh::field_data(stkField_, bucket);
}

// --- Const Accessors
template <typename A = ACCESS>
inline const typename std::enable_if_t<std::is_same_v<A, READ>, T>*
Expand All @@ -125,6 +132,13 @@ class SmartField<FieldType, tags::LEGACY, ACCESS>
return stk::mesh::field_data(stkField_, entity);
}

template <typename A = ACCESS>
inline const typename std::enable_if_t<std::is_same_v<A, READ>, T>*
operator()(const stk::mesh::Bucket& bucket) const
{
return stk::mesh::field_data(stkField_, bucket);
}

~SmartField()
{
if (is_write_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,11 @@ class ThermalConductivityFromPrandtlPropAlgorithm : public Algorithm
{
public:
ThermalConductivityFromPrandtlPropAlgorithm(
Realm& realm,
const stk::mesh::PartVector& part_vec,
ScalarFieldType* thermalCond,
ScalarFieldType* specificHeat,
ScalarFieldType* viscosity,
const double Pr);
Realm& realm, const stk::mesh::PartVector& part_vec, const double Pr);

virtual ~ThermalConductivityFromPrandtlPropAlgorithm() {}

virtual void execute();

ScalarFieldType* thermalCond_;
ScalarFieldType* specHeat_;
ScalarFieldType* viscosity_;

const double Pr_;
};

Expand Down
9 changes: 7 additions & 2 deletions src/Algorithm.C
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <Algorithm.h>
#include <SupplementalAlgorithm.h>
#include <Realm.h>
#include <kernel/Kernel.h>

namespace sierra {
Expand All @@ -23,14 +24,18 @@ namespace nalu {
//-------- constructor -----------------------------------------------------
//--------------------------------------------------------------------------
Algorithm::Algorithm(Realm& realm, stk::mesh::Part* part)
: realm_(realm), partVec_(1, part)
: realm_(realm),
partVec_(1, part),
fieldManager_(*(realm.fieldManager_.get()))
{
// nothing to do
}

// alternative; provide full partVec
Algorithm::Algorithm(Realm& realm, const stk::mesh::PartVector& partVec)
: realm_(realm), partVec_(partVec)
: realm_(realm),
partVec_(partVec),
fieldManager_(*(realm.fieldManager_.get()))
{
// nothing to do
}
Expand Down
2 changes: 1 addition & 1 deletion src/EnthalpyEquationSystem.C
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ EnthalpyEquationSystem::register_nodal_fields(
"this constant value"
<< std::endl;
Algorithm* propAlg = new ThermalConductivityFromPrandtlPropAlgorithm(
realm_, part_vec, thermalCond_, specHeat_, visc_, providedPr);
realm_, part_vec, providedPr);
propertyAlg_.push_back(propAlg);
} else {
// no Pr provided, simply augment property map and expect lambda to be
Expand Down
4 changes: 2 additions & 2 deletions src/FieldManager.C
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ FieldManager::FieldManager(stk::mesh::MetaData& meta, const int numStates)
}

bool
FieldManager::field_exists(const std::string& name)
FieldManager::field_exists(const std::string& name) const
{
auto definition = FieldRegistry::query(numDimensions_, numStates_, name);

Expand All @@ -37,7 +37,7 @@ FieldManager::register_field(
const stk::mesh::PartVector& parts,
const int numStates,
const int numComponents,
const void* init_val)
const void* init_val) const
{
auto definition = FieldRegistry::query(numDimensions_, numStates_, name);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <Algorithm.h>
#include <property_evaluator/ThermalConductivityFromPrandtlPropAlgorithm.h>
#include <FieldManager.h>
#include <FieldTypeDef.h>
#include <Realm.h>

Expand All @@ -31,19 +32,12 @@ namespace nalu {
//--------------------------------------------------------------------------
ThermalConductivityFromPrandtlPropAlgorithm::
ThermalConductivityFromPrandtlPropAlgorithm(
Realm& realm,
const stk::mesh::PartVector& part_vec,
ScalarFieldType* thermalCond,
ScalarFieldType* specHeat,
ScalarFieldType* viscosity,
const double Pr)
: Algorithm(realm, part_vec),
thermalCond_(thermalCond),
specHeat_(specHeat),
viscosity_(viscosity),
Pr_(Pr)
Realm& realm, const stk::mesh::PartVector& part_vec, const double Pr)
: Algorithm(realm, part_vec), Pr_(Pr)
{
// does nothing
fieldManager_.register_field<double>("thermal_conductivity", part_vec);
fieldManager_.register_field<double>("specific_heat", part_vec);
fieldManager_.register_field<double>("viscosity", part_vec);
}

//--------------------------------------------------------------------------
Expand All @@ -60,25 +54,23 @@ ThermalConductivityFromPrandtlPropAlgorithm::execute()
stk::mesh::BucketVector const& node_buckets =
realm_.get_buckets(stk::topology::NODE_RANK, selector);

thermalCond_->sync_to_host();
specHeat_->sync_to_host();
viscosity_->sync_to_host();
auto thermalCond =
fieldManager_.get_legacy_smart_field<double, tags::READ_WRITE>(
"thermal_conductivity");
const auto specHeat =
fieldManager_.get_legacy_smart_field<double, tags::READ>("specific_heat");
const auto viscosity =
fieldManager_.get_legacy_smart_field<double, tags::READ>("viscosity");

for (stk::mesh::BucketVector::const_iterator ib = node_buckets.begin();
ib != node_buckets.end(); ++ib) {
stk::mesh::Bucket& b = **ib;
const stk::mesh::Bucket::size_type length = b.size();

double* thermalCond = stk::mesh::field_data(*thermalCond_, b);
const double* specHeat = stk::mesh::field_data(*specHeat_, b);
const double* viscosity = stk::mesh::field_data(*viscosity_, b);

for (stk::mesh::Bucket::size_type k = 0; k < length; ++k) {
thermalCond[k] = specHeat[k] * viscosity[k] / Pr_;
thermalCond(b)[k] = specHeat(b)[k] * viscosity(b)[k] / Pr_;
}
}
thermalCond_->modify_on_host();
thermalCond_->sync_to_device();
}

} // namespace nalu
Expand Down

0 comments on commit 9e33aeb

Please sign in to comment.