-
Notifications
You must be signed in to change notification settings - Fork 12.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Flang][OpenMP] Lowering of host-evaluated clauses #116219
base: users/skatrak/host-eval-05-mlir-llvmir-generic
Are you sure you want to change the base?
[Flang][OpenMP] Lowering of host-evaluated clauses #116219
Conversation
@llvm/pr-subscribers-flang-openmp @llvm/pr-subscribers-mlir Author: Sergio Afonso (skatrak) ChangesThis patch adds support for lowering OpenMP clauses and expressions attached to constructs nested inside of a target region that need to be evaluated in the host device. This is done through the use of the When lowering clauses for a target construct, a more involved The resulting list of host-evaluated values is used to initialize the Afterwards, while lowering nested operations, those that might potentially be evaluated in the host (e.g. Patch is 34.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116219.diff 4 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 91f99ba4b0ca55..a206af77a2f51f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -45,6 +45,19 @@ using namespace Fortran::lower::omp;
// Code generation helper functions
//===----------------------------------------------------------------------===//
+static void genOMPDispatch(lower::AbstractConverter &converter,
+ lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx,
+ lower::pft::Evaluation &eval, mlir::Location loc,
+ const ConstructQueue &queue,
+ ConstructQueue::const_iterator item);
+
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ lower::StatementContext &stmtCtx,
+ lower::pft::Evaluation &eval,
+ mlir::Location loc);
+
namespace {
/// Structure holding the information needed to create and bind entry block
/// arguments associated to a single clause.
@@ -63,6 +76,7 @@ struct EntryBlockArgsEntry {
/// Structure holding the information needed to create and bind entry block
/// arguments associated to all clauses that can define them.
struct EntryBlockArgs {
+ llvm::ArrayRef<mlir::Value> hostEvalVars;
EntryBlockArgsEntry inReduction;
EntryBlockArgsEntry map;
EntryBlockArgsEntry priv;
@@ -85,18 +99,146 @@ struct EntryBlockArgs {
auto getVars() const {
return llvm::concat<const mlir::Value>(
- inReduction.vars, map.vars, priv.vars, reduction.vars,
+ hostEvalVars, inReduction.vars, map.vars, priv.vars, reduction.vars,
taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
}
};
+
+/// Structure holding information that is needed to pass host-evaluated
+/// information to later lowering stages.
+class HostEvalInfo {
+public:
+ // Allow this function access to private members in order to initialize them.
+ friend void ::processHostEvalClauses(lower::AbstractConverter &,
+ semantics::SemanticsContext &,
+ lower::StatementContext &,
+ lower::pft::Evaluation &,
+ mlir::Location);
+
+ /// Fill \c vars with values stored in \c ops.
+ ///
+ /// The order in which values are stored matches the one expected by \see
+ /// bindOperands().
+ void collectValues(llvm::SmallVectorImpl<mlir::Value> &vars) const {
+ vars.append(ops.loopLowerBounds);
+ vars.append(ops.loopUpperBounds);
+ vars.append(ops.loopSteps);
+
+ if (ops.numTeamsLower)
+ vars.push_back(ops.numTeamsLower);
+
+ if (ops.numTeamsUpper)
+ vars.push_back(ops.numTeamsUpper);
+
+ if (ops.numThreads)
+ vars.push_back(ops.numThreads);
+
+ if (ops.threadLimit)
+ vars.push_back(ops.threadLimit);
+ }
+
+ /// Update \c ops, replacing all values with the corresponding block argument
+ /// in \c args.
+ ///
+ /// The order in which values are stored in \c args is the same as the one
+ /// used by \see collectValues().
+ void bindOperands(llvm::ArrayRef<mlir::BlockArgument> args) {
+ assert(args.size() ==
+ ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
+ ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
+ (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+ (ops.threadLimit ? 1 : 0) &&
+ "invalid block argument list");
+ int argIndex = 0;
+ for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
+ ops.loopLowerBounds[i] = args[argIndex++];
+
+ for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i)
+ ops.loopUpperBounds[i] = args[argIndex++];
+
+ for (size_t i = 0; i < ops.loopSteps.size(); ++i)
+ ops.loopSteps[i] = args[argIndex++];
+
+ if (ops.numTeamsLower)
+ ops.numTeamsLower = args[argIndex++];
+
+ if (ops.numTeamsUpper)
+ ops.numTeamsUpper = args[argIndex++];
+
+ if (ops.numThreads)
+ ops.numThreads = args[argIndex++];
+
+ if (ops.threadLimit)
+ ops.threadLimit = args[argIndex++];
+ }
+
+ /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
+ /// values and Fortran symbols, respectively, if they have already been
+ /// initialized but not yet applied.
+ ///
+ /// \returns whether an update was performed. If not, these clauses were not
+ /// evaluated in the host device.
+ bool apply(mlir::omp::LoopNestOperands &clauseOps,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &ivOut) {
+ if (iv.empty() || loopNestApplied) {
+ loopNestApplied = true;
+ return false;
+ }
+
+ loopNestApplied = true;
+ clauseOps.loopLowerBounds = ops.loopLowerBounds;
+ clauseOps.loopUpperBounds = ops.loopUpperBounds;
+ clauseOps.loopSteps = ops.loopSteps;
+ ivOut.append(iv);
+ return true;
+ }
+
+ /// Update \p clauseOps with the corresponding host-evaluated values if they
+ /// have already been initialized but not yet applied.
+ ///
+ /// \returns whether an update was performed. If not, these clauses were not
+ /// evaluated in the host device.
+ bool apply(mlir::omp::ParallelOperands &clauseOps) {
+ if (!ops.numThreads || parallelApplied) {
+ parallelApplied = true;
+ return false;
+ }
+
+ parallelApplied = true;
+ clauseOps.numThreads = ops.numThreads;
+ return true;
+ }
+
+ /// Update \p clauseOps with the corresponding host-evaluated values if they
+ /// have already been initialized.
+ ///
+ /// \returns whether an update was performed. If not, these clauses were not
+ /// evaluated in the host device.
+ bool apply(mlir::omp::TeamsOperands &clauseOps) {
+ if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit)
+ return false;
+
+ clauseOps.numTeamsLower = ops.numTeamsLower;
+ clauseOps.numTeamsUpper = ops.numTeamsUpper;
+ clauseOps.threadLimit = ops.threadLimit;
+ return true;
+ }
+
+private:
+ mlir::omp::HostEvaluatedOperands ops;
+ llvm::SmallVector<const semantics::Symbol *> iv;
+ bool loopNestApplied = false, parallelApplied = false;
+};
} // namespace
-static void genOMPDispatch(lower::AbstractConverter &converter,
- lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx,
- lower::pft::Evaluation &eval, mlir::Location loc,
- const ConstructQueue &queue,
- ConstructQueue::const_iterator item);
+/// Stack of \see HostEvalInfo to represent the current nest of \c omp.target
+/// operations being created.
+///
+/// The current implementation prevents nested 'target' regions from breaking
+/// the handling of the outer region by keeping a stack of information
+/// structures, but it will probably still require some further work to support
+/// reverse offloading.
+static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
/// Bind symbols to their corresponding entry block arguments.
///
@@ -219,6 +361,8 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
};
// Process in clause name alphabetical order to match block arguments order.
+ // Do not bind host_eval variables because they cannot be used inside of the
+ // corresponding region, except for very specific cases handled separately.
bindPrivateLike(args.inReduction.syms, args.inReduction.vars,
op.getInReductionBlockArgs());
bindMapLike(args.map.syms, op.getMapBlockArgs());
@@ -256,6 +400,246 @@ extractMappedBaseValues(llvm::ArrayRef<mlir::Value> vars,
});
}
+/// Get the directive enumeration value corresponding to the given OpenMP
+/// construct PFT node.
+llvm::omp::Directive
+extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
+ return common::visit(
+ common::visitors{
+ [](const parser::OpenMPAllocatorsConstruct &c) {
+ return llvm::omp::OMPD_allocators;
+ },
+ [](const parser::OpenMPAtomicConstruct &c) {
+ return llvm::omp::OMPD_atomic;
+ },
+ [](const parser::OpenMPBlockConstruct &c) {
+ return std::get<parser::OmpBlockDirective>(
+ std::get<parser::OmpBeginBlockDirective>(c.t).t)
+ .v;
+ },
+ [](const parser::OpenMPCriticalConstruct &c) {
+ return llvm::omp::OMPD_critical;
+ },
+ [](const parser::OpenMPDeclarativeAllocate &c) {
+ return llvm::omp::OMPD_allocate;
+ },
+ [](const parser::OpenMPExecutableAllocate &c) {
+ return llvm::omp::OMPD_allocate;
+ },
+ [](const parser::OpenMPLoopConstruct &c) {
+ return std::get<parser::OmpLoopDirective>(
+ std::get<parser::OmpBeginLoopDirective>(c.t).t)
+ .v;
+ },
+ [](const parser::OpenMPSectionConstruct &c) {
+ return llvm::omp::OMPD_section;
+ },
+ [](const parser::OpenMPSectionsConstruct &c) {
+ return std::get<parser::OmpSectionsDirective>(
+ std::get<parser::OmpBeginSectionsDirective>(c.t).t)
+ .v;
+ },
+ [](const parser::OpenMPStandaloneConstruct &c) {
+ return common::visit(
+ common::visitors{
+ [](const parser::OpenMPSimpleStandaloneConstruct &c) {
+ return std::get<parser::OmpSimpleStandaloneDirective>(c.t)
+ .v;
+ },
+ [](const parser::OpenMPFlushConstruct &c) {
+ return llvm::omp::OMPD_flush;
+ },
+ [](const parser::OpenMPCancelConstruct &c) {
+ return llvm::omp::OMPD_cancel;
+ },
+ [](const parser::OpenMPCancellationPointConstruct &c) {
+ return llvm::omp::OMPD_cancellation_point;
+ },
+ [](const parser::OpenMPDepobjConstruct &c) {
+ return llvm::omp::OMPD_depobj;
+ }},
+ c.u);
+ }},
+ ompConstruct.u);
+}
+
+/// Populate the global \see hostEvalInfo after processing clauses for the given
+/// \p eval OpenMP target construct, or nested constructs, if these must be
+/// evaluated outside of the target region per the spec.
+///
+/// In particular, this will ensure that in 'target teams' and equivalent nested
+/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated
+/// in the host. Additionally, loop bounds, steps and the \c num_threads clause
+/// will also be evaluated in the host if a target SPMD construct is detected
+/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting).
+///
+/// The result, stored as a global, is intended to be used to populate the \c
+/// host_eval operands of the associated \c omp.target operation, and also to be
+/// checked and used by later lowering steps to populate the corresponding
+/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest
+/// operations.
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ lower::StatementContext &stmtCtx,
+ lower::pft::Evaluation &eval,
+ mlir::Location loc) {
+ // Obtain the list of clauses of the given OpenMP block or loop construct
+ // evaluation. Other evaluations passed to this lambda keep `clauses`
+ // unchanged.
+ auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval,
+ List<Clause> &clauses) {
+ const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+ if (!ompEval)
+ return;
+
+ const parser::OmpClauseList *beginClauseList = nullptr;
+ const parser::OmpClauseList *endClauseList = nullptr;
+ common::visit(
+ common::visitors{
+ [&](const parser::OpenMPBlockConstruct &ompConstruct) {
+ const auto &beginDirective =
+ std::get<parser::OmpBeginBlockDirective>(ompConstruct.t);
+ beginClauseList =
+ &std::get<parser::OmpClauseList>(beginDirective.t);
+ endClauseList = &std::get<parser::OmpClauseList>(
+ std::get<parser::OmpEndBlockDirective>(ompConstruct.t).t);
+ },
+ [&](const parser::OpenMPLoopConstruct &ompConstruct) {
+ const auto &beginDirective =
+ std::get<parser::OmpBeginLoopDirective>(ompConstruct.t);
+ beginClauseList =
+ &std::get<parser::OmpClauseList>(beginDirective.t);
+
+ if (auto &endDirective =
+ std::get<std::optional<parser::OmpEndLoopDirective>>(
+ ompConstruct.t))
+ endClauseList =
+ &std::get<parser::OmpClauseList>(endDirective->t);
+ },
+ [&](const auto &) {}},
+ ompEval->u);
+
+ assert(beginClauseList && "expected begin directive");
+ clauses.append(makeClauses(*beginClauseList, semaCtx));
+
+ if (endClauseList)
+ clauses.append(makeClauses(*endClauseList, semaCtx));
+ };
+
+ // Return the directive that is immediately nested inside of the given
+ // `parent` evaluation, if it is its only non-end-statement nested evaluation
+ // and it represents an OpenMP construct.
+ auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent)
+ -> std::optional<llvm::omp::Directive> {
+ if (!parent.hasNestedEvaluations())
+ return std::nullopt;
+
+ llvm::omp::Directive dir;
+ auto &nested = parent.getFirstNestedEvaluation();
+ if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
+ dir = extractOmpDirective(*ompEval);
+ else
+ return std::nullopt;
+
+ for (auto &sibling : parent.getNestedEvaluations())
+ if (&sibling != &nested && !sibling.isEndStmt())
+ return std::nullopt;
+
+ return dir;
+ };
+
+ // Process the given evaluation assuming it's part of a 'target' construct or
+ // captured by one, and store results in the global `hostEvalInfo`.
+ std::function<void(lower::pft::Evaluation &, const List<Clause> &)>
+ processEval;
+ processEval = [&](lower::pft::Evaluation &eval, const List<Clause> &clauses) {
+ using namespace llvm::omp;
+ ClauseProcessor cp(converter, semaCtx, clauses);
+
+ // Call `processEval` recursively with the immediately nested evaluation and
+ // its corresponding clauses if there is a single nested evaluation
+ // representing an OpenMP directive that passes the given test.
+ auto processSingleNestedIf = [&](llvm::function_ref<bool(Directive)> test) {
+ std::optional<Directive> nestedDir = extractOnlyOmpNestedDir(eval);
+ if (!nestedDir || !test(*nestedDir))
+ return;
+
+ lower::pft::Evaluation &nestedEval = eval.getFirstNestedEvaluation();
+ List<lower::omp::Clause> nestedClauses;
+ extractClauses(nestedEval, nestedClauses);
+ processEval(nestedEval, nestedClauses);
+ };
+
+ const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+ if (!ompEval)
+ return;
+
+ HostEvalInfo &hostInfo = hostEvalInfo.back();
+
+ switch (extractOmpDirective(*ompEval)) {
+ // Cases where 'teams' and target SPMD clauses might be present.
+ case OMPD_teams_distribute_parallel_do:
+ case OMPD_teams_distribute_parallel_do_simd:
+ cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams_distribute_parallel_do:
+ case OMPD_target_teams_distribute_parallel_do_simd:
+ cp.processNumTeams(stmtCtx, hostInfo.ops);
+ [[fallthrough]];
+ case OMPD_distribute_parallel_do:
+ case OMPD_distribute_parallel_do_simd:
+ cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+ cp.processNumThreads(stmtCtx, hostInfo.ops);
+ break;
+
+ // Cases where 'teams' clauses might be present, and target SPMD is
+ // possible by looking at nested evaluations.
+ case OMPD_teams:
+ cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams:
+ cp.processNumTeams(stmtCtx, hostInfo.ops);
+ processSingleNestedIf([](Directive nestedDir) {
+ return nestedDir == OMPD_distribute_parallel_do ||
+ nestedDir == OMPD_distribute_parallel_do_simd;
+ });
+ break;
+
+ // Cases where only 'teams' host-evaluated clauses might be present.
+ case OMPD_teams_distribute:
+ case OMPD_teams_distribute_simd:
+ cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ [[fallthrough]];
+ case OMPD_target_teams_distribute:
+ case OMPD_target_teams_distribute_simd:
+ cp.processNumTeams(stmtCtx, hostInfo.ops);
+ break;
+
+ // Standalone 'target' case.
+ case OMPD_target: {
+ processSingleNestedIf(
+ [](Directive nestedDir) { return topTeamsSet.test(nestedDir); });
+ break;
+ }
+ default:
+ break;
+ }
+ };
+
+ assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
+ const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+ assert(ompEval &&
+ llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
+ "expected TARGET construct evaluation");
+
+ // Use the whole list of clauses passed to the construct here, rather than the
+ // ones only applied to omp.target.
+ List<lower::omp::Clause> clauses;
+ extractClauses(eval, clauses);
+ processEval(eval, clauses);
+}
+
static lower::pft::Evaluation *
getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) {
// Return the Evaluation of the innermost collapsed loop, or the current one
@@ -638,11 +1022,11 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Type> types;
llvm::SmallVector<mlir::Location> locs;
- unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
- args.priv.vars.size() + args.reduction.vars.size() +
- args.taskReduction.vars.size() +
- args.useDeviceAddr.vars.size() +
- args.useDevicePtr.vars.size();
+ unsigned numVars =
+ args.hostEvalVars.size() + args.inReduction.vars.size() +
+ args.map.vars.size() + args.priv.vars.size() +
+ args.reduction.vars.size() + args.taskReduction.vars.size() +
+ args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size();
types.reserve(numVars);
locs.reserve(numVars);
@@ -655,6 +1039,7 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
// Populate block arguments in clause name alphabetical order to match
// expected order by the BlockArgOpenMPOpInterface.
+ extractTypeLoc(args.hostEvalVars);
extractTypeLoc(args.inReduction.vars);
extractTypeLoc(args.map.vars);
extractTypeLoc(args.priv.vars);
@@ -991,12 +1376,15 @@ static void genBodyOfTargetOp(
mlir::omp::TargetOp &targetOp, const EntryBlockArgs &args,
const mlir::Location ¤tLocation, const ConstructQueue &queue,
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
+ assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);
mlir::Region ®ion = targetOp.getRegion();
mlir::Block *entryBlock = genEntryBlock(converter, args, region);
bindEntryBlockArgs(converter, targetOp, args);
+ hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
// Check if cloning the bounds introduced any dependency on the outer region.
// If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1172,7 +1560,10 @@ genLoopNestClauses(lower::AbstractConverter &converter,
mlir::Location loc, ...
[truncated]
|
8ff0d3b
to
b5571be
Compare
6020805
to
28770c7
Compare
b5571be
to
58bd5ff
Compare
This patch adds support for lowering OpenMP clauses and expressions attached to constructs nested inside of a target region that need to be evaluated in the host device. This is done through the use of the `OpenMP_HostEvalClause` `omp.target` set of operands and entry block arguments. When lowering clauses for a target construct, a more involved `processHostEvalClauses()` function is called, which looks at the current and potentially other nested constructs in order to find and lower clauses that need to be processed outside of the `omp.target` operation under construction. This populates an instance of a global structure with the resulting MLIR values. The resulting list of host-evaluated values is used to initialize the `host_eval` operands when constructing the `omp.target` operation, and then replaced with the corresponding block arguments after creating that operation's region. Afterwards, while lowering nested operations, those that might potentially be evaluated in the host (e.g. `num_teams`, `thread_limit`, `num_threads` and `collapse`) check first whether there is an active global host-evaluated information structure and whether it holds values referring to these clauses. If that is the case, the stored values (referring to `omp.target` entry block arguments at that stage) are used instead of lowering clauses again.
28770c7
to
5817462
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, albeit not being a true expert here. But the general idea looks fine to me.
This patch adds support for lowering OpenMP clauses and expressions attached to constructs nested inside of a target region that need to be evaluated in the host device. This is done through the use of the
OpenMP_HostEvalClause
omp.target
set of operands and entry block arguments.When lowering clauses for a target construct, a more involved
processHostEvalClauses()
function is called, which looks at the current and potentially other nested constructs in order to find and lower clauses that need to be processed outside of theomp.target
operation under construction. This populates an instance of a global structure with the resulting MLIR values.The resulting list of host-evaluated values is used to initialize the
host_eval
operands when constructing theomp.target
operation, and then replaced with the corresponding block arguments after creating that operation's region.Afterwards, while lowering nested operations, those that might potentially be evaluated in the host (e.g.
num_teams
,thread_limit
,num_threads
andcollapse
) check first whether there is an active global host-evaluated information structure and whether it holds values referring to these clauses. If that is the case, the stored values (referring toomp.target
entry block arguments at that stage) are used instead of lowering these clauses again.