Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flang][OpenMP] Lowering of host-evaluated clauses #116219

Open
wants to merge 1 commit into
base: users/skatrak/host-eval-05-mlir-llvmir-generic
Choose a base branch
from

Conversation

skatrak
Copy link
Contributor

@skatrak skatrak commented Nov 14, 2024

This patch adds support for lowering OpenMP clauses and expressions attached to constructs nested inside of a target region that need to be evaluated in the host device. This is done through the use of the OpenMP_HostEvalClause omp.target set of operands and entry block arguments.

When lowering clauses for a target construct, a more involved processHostEvalClauses() function is called, which looks at the current and potentially other nested constructs in order to find and lower clauses that need to be processed outside of the omp.target operation under construction. This populates an instance of a global structure with the resulting MLIR values.

The resulting list of host-evaluated values is used to initialize the host_eval operands when constructing the omp.target operation, and then replaced with the corresponding block arguments after creating that operation's region.

Afterwards, while lowering nested operations, those that might potentially be evaluated in the host (e.g. num_teams, thread_limit, num_threads and collapse) check first whether there is an active global host-evaluated information structure and whether it holds values referring to these clauses. If that is the case, the stored values (referring to omp.target entry block arguments at that stage) are used instead of lowering these clauses again.

@llvmbot
Copy link
Member

llvmbot commented Nov 14, 2024

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

Changes

This patch adds support for lowering OpenMP clauses and expressions attached to constructs nested inside of a target region that need to be evaluated in the host device. This is done through the use of the OpenMP_HostEvalClause omp.target set of operands and entry block arguments.

When lowering clauses for a target construct, a more involved processHostEvalClauses() function is called, which looks at the current and potentially other nested constructs in order to find and lower clauses that need to be processed outside of the omp.target operation under construction. This populates an instance of a global structure with the resulting MLIR values.

The resulting list of host-evaluated values is used to initialize the host_eval operands when constructing the omp.target operation, and then replaced with the corresponding block arguments after creating that operation's region.

Afterwards, while lowering nested operations, those that might potentially be evaluated in the host (e.g. num_teams, thread_limit, num_threads and collapse) check first whether there is an active global host-evaluated information structure and whether it holds values referring to these clauses. If that is the case, the stored values (referring to omp.target entry block arguments at that stage) are used instead of lowering these clauses again.


Patch is 34.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116219.diff

4 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+429-29)
  • (added) flang/test/Lower/OpenMP/host-eval.f90 (+138)
  • (added) flang/test/Lower/OpenMP/target-spmd.f90 (+191)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h (+6)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 91f99ba4b0ca55..a206af77a2f51f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -45,6 +45,19 @@ using namespace Fortran::lower::omp;
 // Code generation helper functions
 //===----------------------------------------------------------------------===//
 
+static void genOMPDispatch(lower::AbstractConverter &converter,
+                           lower::SymMap &symTable,
+                           semantics::SemanticsContext &semaCtx,
+                           lower::pft::Evaluation &eval, mlir::Location loc,
+                           const ConstructQueue &queue,
+                           ConstructQueue::const_iterator item);
+
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+                                   semantics::SemanticsContext &semaCtx,
+                                   lower::StatementContext &stmtCtx,
+                                   lower::pft::Evaluation &eval,
+                                   mlir::Location loc);
+
 namespace {
 /// Structure holding the information needed to create and bind entry block
 /// arguments associated to a single clause.
@@ -63,6 +76,7 @@ struct EntryBlockArgsEntry {
 /// Structure holding the information needed to create and bind entry block
 /// arguments associated to all clauses that can define them.
 struct EntryBlockArgs {
+  llvm::ArrayRef<mlir::Value> hostEvalVars;
   EntryBlockArgsEntry inReduction;
   EntryBlockArgsEntry map;
   EntryBlockArgsEntry priv;
@@ -85,18 +99,146 @@ struct EntryBlockArgs {
 
   auto getVars() const {
     return llvm::concat<const mlir::Value>(
-        inReduction.vars, map.vars, priv.vars, reduction.vars,
+        hostEvalVars, inReduction.vars, map.vars, priv.vars, reduction.vars,
         taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
   }
 };
+
+/// Structure holding information that is needed to pass host-evaluated
+/// information to later lowering stages.
+class HostEvalInfo {
+public:
+  // Allow this function access to private members in order to initialize them.
+  friend void ::processHostEvalClauses(lower::AbstractConverter &,
+                                       semantics::SemanticsContext &,
+                                       lower::StatementContext &,
+                                       lower::pft::Evaluation &,
+                                       mlir::Location);
+
+  /// Fill \c vars with values stored in \c ops.
+  ///
+  /// The order in which values are stored matches the one expected by \see
+  /// bindOperands().
+  void collectValues(llvm::SmallVectorImpl<mlir::Value> &vars) const {
+    vars.append(ops.loopLowerBounds);
+    vars.append(ops.loopUpperBounds);
+    vars.append(ops.loopSteps);
+
+    if (ops.numTeamsLower)
+      vars.push_back(ops.numTeamsLower);
+
+    if (ops.numTeamsUpper)
+      vars.push_back(ops.numTeamsUpper);
+
+    if (ops.numThreads)
+      vars.push_back(ops.numThreads);
+
+    if (ops.threadLimit)
+      vars.push_back(ops.threadLimit);
+  }
+
+  /// Update \c ops, replacing all values with the corresponding block argument
+  /// in \c args.
+  ///
+  /// The order in which values are stored in \c args is the same as the one
+  /// used by \see collectValues().
+  void bindOperands(llvm::ArrayRef<mlir::BlockArgument> args) {
+    assert(args.size() ==
+               ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
+                   ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
+                   (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+                   (ops.threadLimit ? 1 : 0) &&
+           "invalid block argument list");
+    int argIndex = 0;
+    for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
+      ops.loopLowerBounds[i] = args[argIndex++];
+
+    for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i)
+      ops.loopUpperBounds[i] = args[argIndex++];
+
+    for (size_t i = 0; i < ops.loopSteps.size(); ++i)
+      ops.loopSteps[i] = args[argIndex++];
+
+    if (ops.numTeamsLower)
+      ops.numTeamsLower = args[argIndex++];
+
+    if (ops.numTeamsUpper)
+      ops.numTeamsUpper = args[argIndex++];
+
+    if (ops.numThreads)
+      ops.numThreads = args[argIndex++];
+
+    if (ops.threadLimit)
+      ops.threadLimit = args[argIndex++];
+  }
+
+  /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
+  /// values and Fortran symbols, respectively, if they have already been
+  /// initialized but not yet applied.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  ///          evaluated in the host device.
+  bool apply(mlir::omp::LoopNestOperands &clauseOps,
+             llvm::SmallVectorImpl<const semantics::Symbol *> &ivOut) {
+    if (iv.empty() || loopNestApplied) {
+      loopNestApplied = true;
+      return false;
+    }
+
+    loopNestApplied = true;
+    clauseOps.loopLowerBounds = ops.loopLowerBounds;
+    clauseOps.loopUpperBounds = ops.loopUpperBounds;
+    clauseOps.loopSteps = ops.loopSteps;
+    ivOut.append(iv);
+    return true;
+  }
+
+  /// Update \p clauseOps with the corresponding host-evaluated values if they
+  /// have already been initialized but not yet applied.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  ///          evaluated in the host device.
+  bool apply(mlir::omp::ParallelOperands &clauseOps) {
+    if (!ops.numThreads || parallelApplied) {
+      parallelApplied = true;
+      return false;
+    }
+
+    parallelApplied = true;
+    clauseOps.numThreads = ops.numThreads;
+    return true;
+  }
+
+  /// Update \p clauseOps with the corresponding host-evaluated values if they
+  /// have already been initialized.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  ///          evaluated in the host device.
+  bool apply(mlir::omp::TeamsOperands &clauseOps) {
+    if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit)
+      return false;
+
+    clauseOps.numTeamsLower = ops.numTeamsLower;
+    clauseOps.numTeamsUpper = ops.numTeamsUpper;
+    clauseOps.threadLimit = ops.threadLimit;
+    return true;
+  }
+
+private:
+  mlir::omp::HostEvaluatedOperands ops;
+  llvm::SmallVector<const semantics::Symbol *> iv;
+  bool loopNestApplied = false, parallelApplied = false;
+};
 } // namespace
 
-static void genOMPDispatch(lower::AbstractConverter &converter,
-                           lower::SymMap &symTable,
-                           semantics::SemanticsContext &semaCtx,
-                           lower::pft::Evaluation &eval, mlir::Location loc,
-                           const ConstructQueue &queue,
-                           ConstructQueue::const_iterator item);
+/// Stack of \see HostEvalInfo to represent the current nest of \c omp.target
+/// operations being created.
+///
+/// The current implementation prevents nested 'target' regions from breaking
+/// the handling of the outer region by keeping a stack of information
+/// structures, but it will probably still require some further work to support
+/// reverse offloading.
+static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
 
 /// Bind symbols to their corresponding entry block arguments.
 ///
@@ -219,6 +361,8 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
   };
 
   // Process in clause name alphabetical order to match block arguments order.
+  // Do not bind host_eval variables because they cannot be used inside of the
+  // corresponding region, except for very specific cases handled separately.
   bindPrivateLike(args.inReduction.syms, args.inReduction.vars,
                   op.getInReductionBlockArgs());
   bindMapLike(args.map.syms, op.getMapBlockArgs());
@@ -256,6 +400,246 @@ extractMappedBaseValues(llvm::ArrayRef<mlir::Value> vars,
   });
 }
 
+/// Get the directive enumeration value corresponding to the given OpenMP
+/// construct PFT node.
+llvm::omp::Directive
+extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
+  return common::visit(
+      common::visitors{
+          [](const parser::OpenMPAllocatorsConstruct &c) {
+            return llvm::omp::OMPD_allocators;
+          },
+          [](const parser::OpenMPAtomicConstruct &c) {
+            return llvm::omp::OMPD_atomic;
+          },
+          [](const parser::OpenMPBlockConstruct &c) {
+            return std::get<parser::OmpBlockDirective>(
+                       std::get<parser::OmpBeginBlockDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPCriticalConstruct &c) {
+            return llvm::omp::OMPD_critical;
+          },
+          [](const parser::OpenMPDeclarativeAllocate &c) {
+            return llvm::omp::OMPD_allocate;
+          },
+          [](const parser::OpenMPExecutableAllocate &c) {
+            return llvm::omp::OMPD_allocate;
+          },
+          [](const parser::OpenMPLoopConstruct &c) {
+            return std::get<parser::OmpLoopDirective>(
+                       std::get<parser::OmpBeginLoopDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPSectionConstruct &c) {
+            return llvm::omp::OMPD_section;
+          },
+          [](const parser::OpenMPSectionsConstruct &c) {
+            return std::get<parser::OmpSectionsDirective>(
+                       std::get<parser::OmpBeginSectionsDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPStandaloneConstruct &c) {
+            return common::visit(
+                common::visitors{
+                    [](const parser::OpenMPSimpleStandaloneConstruct &c) {
+                      return std::get<parser::OmpSimpleStandaloneDirective>(c.t)
+                          .v;
+                    },
+                    [](const parser::OpenMPFlushConstruct &c) {
+                      return llvm::omp::OMPD_flush;
+                    },
+                    [](const parser::OpenMPCancelConstruct &c) {
+                      return llvm::omp::OMPD_cancel;
+                    },
+                    [](const parser::OpenMPCancellationPointConstruct &c) {
+                      return llvm::omp::OMPD_cancellation_point;
+                    },
+                    [](const parser::OpenMPDepobjConstruct &c) {
+                      return llvm::omp::OMPD_depobj;
+                    }},
+                c.u);
+          }},
+      ompConstruct.u);
+}
+
+/// Populate the global \see hostEvalInfo after processing clauses for the given
+/// \p eval OpenMP target construct, or nested constructs, if these must be
+/// evaluated outside of the target region per the spec.
+///
+/// In particular, this will ensure that in 'target teams' and equivalent nested
+/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated
+/// in the host. Additionally, loop bounds, steps and the \c num_threads clause
+/// will also be evaluated in the host if a target SPMD construct is detected
+/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting).
+///
+/// The result, stored as a global, is intended to be used to populate the \c
+/// host_eval operands of the associated \c omp.target operation, and also to be
+/// checked and used by later lowering steps to populate the corresponding
+/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest
+/// operations.
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+                                   semantics::SemanticsContext &semaCtx,
+                                   lower::StatementContext &stmtCtx,
+                                   lower::pft::Evaluation &eval,
+                                   mlir::Location loc) {
+  // Obtain the list of clauses of the given OpenMP block or loop construct
+  // evaluation. Other evaluations passed to this lambda keep `clauses`
+  // unchanged.
+  auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval,
+                                   List<Clause> &clauses) {
+    const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+    if (!ompEval)
+      return;
+
+    const parser::OmpClauseList *beginClauseList = nullptr;
+    const parser::OmpClauseList *endClauseList = nullptr;
+    common::visit(
+        common::visitors{
+            [&](const parser::OpenMPBlockConstruct &ompConstruct) {
+              const auto &beginDirective =
+                  std::get<parser::OmpBeginBlockDirective>(ompConstruct.t);
+              beginClauseList =
+                  &std::get<parser::OmpClauseList>(beginDirective.t);
+              endClauseList = &std::get<parser::OmpClauseList>(
+                  std::get<parser::OmpEndBlockDirective>(ompConstruct.t).t);
+            },
+            [&](const parser::OpenMPLoopConstruct &ompConstruct) {
+              const auto &beginDirective =
+                  std::get<parser::OmpBeginLoopDirective>(ompConstruct.t);
+              beginClauseList =
+                  &std::get<parser::OmpClauseList>(beginDirective.t);
+
+              if (auto &endDirective =
+                      std::get<std::optional<parser::OmpEndLoopDirective>>(
+                          ompConstruct.t))
+                endClauseList =
+                    &std::get<parser::OmpClauseList>(endDirective->t);
+            },
+            [&](const auto &) {}},
+        ompEval->u);
+
+    assert(beginClauseList && "expected begin directive");
+    clauses.append(makeClauses(*beginClauseList, semaCtx));
+
+    if (endClauseList)
+      clauses.append(makeClauses(*endClauseList, semaCtx));
+  };
+
+  // Return the directive that is immediately nested inside of the given
+  // `parent` evaluation, if it is its only non-end-statement nested evaluation
+  // and it represents an OpenMP construct.
+  auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent)
+      -> std::optional<llvm::omp::Directive> {
+    if (!parent.hasNestedEvaluations())
+      return std::nullopt;
+
+    llvm::omp::Directive dir;
+    auto &nested = parent.getFirstNestedEvaluation();
+    if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
+      dir = extractOmpDirective(*ompEval);
+    else
+      return std::nullopt;
+
+    for (auto &sibling : parent.getNestedEvaluations())
+      if (&sibling != &nested && !sibling.isEndStmt())
+        return std::nullopt;
+
+    return dir;
+  };
+
+  // Process the given evaluation assuming it's part of a 'target' construct or
+  // captured by one, and store results in the global `hostEvalInfo`.
+  std::function<void(lower::pft::Evaluation &, const List<Clause> &)>
+      processEval;
+  processEval = [&](lower::pft::Evaluation &eval, const List<Clause> &clauses) {
+    using namespace llvm::omp;
+    ClauseProcessor cp(converter, semaCtx, clauses);
+
+    // Call `processEval` recursively with the immediately nested evaluation and
+    // its corresponding clauses if there is a single nested evaluation
+    // representing an OpenMP directive that passes the given test.
+    auto processSingleNestedIf = [&](llvm::function_ref<bool(Directive)> test) {
+      std::optional<Directive> nestedDir = extractOnlyOmpNestedDir(eval);
+      if (!nestedDir || !test(*nestedDir))
+        return;
+
+      lower::pft::Evaluation &nestedEval = eval.getFirstNestedEvaluation();
+      List<lower::omp::Clause> nestedClauses;
+      extractClauses(nestedEval, nestedClauses);
+      processEval(nestedEval, nestedClauses);
+    };
+
+    const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+    if (!ompEval)
+      return;
+
+    HostEvalInfo &hostInfo = hostEvalInfo.back();
+
+    switch (extractOmpDirective(*ompEval)) {
+    // Cases where 'teams' and target SPMD clauses might be present.
+    case OMPD_teams_distribute_parallel_do:
+    case OMPD_teams_distribute_parallel_do_simd:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_distribute_parallel_do:
+    case OMPD_target_teams_distribute_parallel_do_simd:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_distribute_parallel_do:
+    case OMPD_distribute_parallel_do_simd:
+      cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+      cp.processNumThreads(stmtCtx, hostInfo.ops);
+      break;
+
+    // Cases where 'teams' clauses might be present, and target SPMD is
+    // possible by looking at nested evaluations.
+    case OMPD_teams:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      processSingleNestedIf([](Directive nestedDir) {
+        return nestedDir == OMPD_distribute_parallel_do ||
+               nestedDir == OMPD_distribute_parallel_do_simd;
+      });
+      break;
+
+    // Cases where only 'teams' host-evaluated clauses might be present.
+    case OMPD_teams_distribute:
+    case OMPD_teams_distribute_simd:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_distribute:
+    case OMPD_target_teams_distribute_simd:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      break;
+
+    // Standalone 'target' case.
+    case OMPD_target: {
+      processSingleNestedIf(
+          [](Directive nestedDir) { return topTeamsSet.test(nestedDir); });
+      break;
+    }
+    default:
+      break;
+    }
+  };
+
+  assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
+  const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+  assert(ompEval &&
+         llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
+         "expected TARGET construct evaluation");
+
+  // Use the whole list of clauses passed to the construct here, rather than the
+  // ones only applied to omp.target.
+  List<lower::omp::Clause> clauses;
+  extractClauses(eval, clauses);
+  processEval(eval, clauses);
+}
+
 static lower::pft::Evaluation *
 getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) {
   // Return the Evaluation of the innermost collapsed loop, or the current one
@@ -638,11 +1022,11 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
 
   llvm::SmallVector<mlir::Type> types;
   llvm::SmallVector<mlir::Location> locs;
-  unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
-                     args.priv.vars.size() + args.reduction.vars.size() +
-                     args.taskReduction.vars.size() +
-                     args.useDeviceAddr.vars.size() +
-                     args.useDevicePtr.vars.size();
+  unsigned numVars =
+      args.hostEvalVars.size() + args.inReduction.vars.size() +
+      args.map.vars.size() + args.priv.vars.size() +
+      args.reduction.vars.size() + args.taskReduction.vars.size() +
+      args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size();
   types.reserve(numVars);
   locs.reserve(numVars);
 
@@ -655,6 +1039,7 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
 
   // Populate block arguments in clause name alphabetical order to match
   // expected order by the BlockArgOpenMPOpInterface.
+  extractTypeLoc(args.hostEvalVars);
   extractTypeLoc(args.inReduction.vars);
   extractTypeLoc(args.map.vars);
   extractTypeLoc(args.priv.vars);
@@ -991,12 +1376,15 @@ static void genBodyOfTargetOp(
     mlir::omp::TargetOp &targetOp, const EntryBlockArgs &args,
     const mlir::Location &currentLocation, const ConstructQueue &queue,
     ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
+  assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);
 
   mlir::Region &region = targetOp.getRegion();
   mlir::Block *entryBlock = genEntryBlock(converter, args, region);
   bindEntryBlockArgs(converter, targetOp, args);
+  hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
 
   // Check if cloning the bounds introduced any dependency on the outer region.
   // If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1172,7 +1560,10 @@ genLoopNestClauses(lower::AbstractConverter &converter,
                    mlir::Location loc, ...
[truncated]

@skatrak skatrak force-pushed the users/skatrak/host-eval-05-mlir-llvmir-generic branch from 8ff0d3b to b5571be Compare November 15, 2024 16:27
@skatrak skatrak force-pushed the users/skatrak/host-eval-06-flang branch from 6020805 to 28770c7 Compare November 15, 2024 16:48
@skatrak skatrak force-pushed the users/skatrak/host-eval-05-mlir-llvmir-generic branch from b5571be to 58bd5ff Compare November 27, 2024 12:31
This patch adds support for lowering OpenMP clauses and expressions attached to
constructs nested inside of a target region that need to be evaluated in the
host device. This is done through the use of the `OpenMP_HostEvalClause`
`omp.target` set of operands and entry block arguments.

When lowering clauses for a target construct, a more involved
`processHostEvalClauses()` function is called, which looks at the current and
potentially other nested constructs in order to find and lower clauses that
need to be processed outside of the `omp.target` operation under construction.
This populates an instance of a global structure with the resulting MLIR
values.

The resulting list of host-evaluated values is used to initialize the
`host_eval` operands when constructing the `omp.target` operation, and then
replaced with the corresponding block arguments after creating that operation's
region.

Afterwards, while lowering nested operations, those that might potentially be
evaluated in the host (e.g. `num_teams`, `thread_limit`, `num_threads` and
`collapse`) check first whether there is an active global host-evaluated
information structure and whether it holds values referring to these clauses.
If that is the case, the stored values (referring to `omp.target` entry block
arguments at that stage) are used instead of lowering clauses again.
@skatrak skatrak force-pushed the users/skatrak/host-eval-06-flang branch from 28770c7 to 5817462 Compare November 27, 2024 12:31
Copy link
Contributor

@mjklemm mjklemm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, albeit not being a true expert here. But the general idea looks fine to me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants