Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OMPIRBuilder] Introduce struct to hold default kernel teams/threads #116050

Open
wants to merge 1 commit into
base: users/skatrak/host-eval-02-mlir
Choose a base branch
from

Conversation

skatrak
Copy link
Contributor

@skatrak skatrak commented Nov 13, 2024

This patch introduces the OpenMPIRBuilder::TargetKernelDefaultAttrs structure used to simplify passing default and constant values for number of teams and threads, and possibly other target kernel-related information in the future.

This is used to forward values passed to createTarget to createTargetInit, which previously used a default unrelated set of values.

This patch introduces the `OpenMPIRBuilder::TargetKernelDefaultAttrs` structure
used to simplify passing default and constant values for number of teams and
threads, and possibly other target kernel-related information in the future.

This is used to forward values passed to `createTarget` to `createTargetInit`,
which previously used a default unrelated set of values.
@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-clang

Author: Sergio Afonso (skatrak)

Changes

This patch introduces the OpenMPIRBuilder::TargetKernelDefaultAttrs structure used to simplify passing default and constant values for number of teams and threads, and possibly other target kernel-related information in the future.

This is used to forward values passed to createTarget to createTargetInit, which previously used a default unrelated set of values.


Patch is 21.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116050.diff

8 Files Affected:

  • (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+8-5)
  • (modified) clang/lib/CodeGen/CGOpenMPRuntime.h (+3-6)
  • (modified) clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp (+3-6)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+25-14)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+40-31)
  • (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+16-13)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-5)
  • (modified) mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir (+1-1)
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index d714af035d21a2..0f7a1166227476 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -5880,10 +5880,13 @@ void CGOpenMPRuntime::emitUsesAllocatorsFini(CodeGenFunction &CGF,
 
 void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
     const OMPExecutableDirective &D, CodeGenFunction &CGF,
-    int32_t &MinThreadsVal, int32_t &MaxThreadsVal, int32_t &MinTeamsVal,
-    int32_t &MaxTeamsVal) {
+    llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+  assert(Attrs.MaxTeams.size() == 1 && Attrs.MaxThreads.size() == 1 &&
+         "invalid default attrs structure");
+  int32_t &MaxTeamsVal = Attrs.MaxTeams.front();
+  int32_t &MaxThreadsVal = Attrs.MaxThreads.front();
 
-  getNumTeamsExprForTargetDirective(CGF, D, MinTeamsVal, MaxTeamsVal);
+  getNumTeamsExprForTargetDirective(CGF, D, Attrs.MinTeams, MaxTeamsVal);
   getNumThreadsExprForTargetDirective(CGF, D, MaxThreadsVal,
                                       /*UpperBoundOnly=*/true);
 
@@ -5901,12 +5904,12 @@ void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
       else
         continue;
 
-      MinThreadsVal = std::max(MinThreadsVal, AttrMinThreadsVal);
+      Attrs.MinThreads = std::max(Attrs.MinThreads, AttrMinThreadsVal);
       if (AttrMaxThreadsVal > 0)
         MaxThreadsVal = MaxThreadsVal > 0
                             ? std::min(MaxThreadsVal, AttrMaxThreadsVal)
                             : AttrMaxThreadsVal;
-      MinTeamsVal = std::max(MinTeamsVal, AttrMinBlocksVal);
+      Attrs.MinTeams = std::max(Attrs.MinTeams, AttrMinBlocksVal);
       if (AttrMaxBlocksVal > 0)
         MaxTeamsVal = MaxTeamsVal > 0 ? std::min(MaxTeamsVal, AttrMaxBlocksVal)
                                       : AttrMaxBlocksVal;
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index 5e7715743afb58..003395e7f17ded 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -312,12 +312,9 @@ class CGOpenMPRuntime {
   llvm::OpenMPIRBuilder OMPBuilder;
 
   /// Helper to determine the min/max number of threads/teams for \p D.
-  void computeMinAndMaxThreadsAndTeams(const OMPExecutableDirective &D,
-                                       CodeGenFunction &CGF,
-                                       int32_t &MinThreadsVal,
-                                       int32_t &MaxThreadsVal,
-                                       int32_t &MinTeamsVal,
-                                       int32_t &MaxTeamsVal);
+  void computeMinAndMaxThreadsAndTeams(
+      const OMPExecutableDirective &D, CodeGenFunction &CGF,
+      llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
 
   /// Helper to emit outlined function for 'target' directive.
   /// \param D Directive to emit.
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 43dc0e62284602..96f8d6c5c08e56 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -745,14 +745,11 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
 void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
                                         CodeGenFunction &CGF,
                                         EntryFunctionState &EST, bool IsSPMD) {
-  int32_t MinThreadsVal = 1, MaxThreadsVal = -1, MinTeamsVal = 1,
-          MaxTeamsVal = -1;
-  computeMinAndMaxThreadsAndTeams(D, CGF, MinThreadsVal, MaxThreadsVal,
-                                  MinTeamsVal, MaxTeamsVal);
+  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
+  computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
 
   CGBuilderTy &Bld = CGF.Builder;
-  Bld.restoreIP(OMPBuilder.createTargetInit(
-      Bld, IsSPMD, MinThreadsVal, MaxThreadsVal, MinTeamsVal, MaxTeamsVal));
+  Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, Attrs));
   if (!IsSPMD)
     emitGenericVarsProlog(CGF, EST.Loc);
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3afb9d84278e81..da450ef5adbc14 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2223,6 +2223,20 @@ class OpenMPIRBuilder {
           MapNamesArray(MapNamesArray) {}
   };
 
+  /// Container to pass the default attributes with which a kernel must be
+  /// launched, used to set kernel attributes and populate associated static
+  /// structures.
+  ///
+  /// For max values, < 0 means unset, == 0 means set but unknown at compile
+  /// time. The number of max values will be 1 except for the case where
+  /// ompx_bare is set.
+  struct TargetKernelDefaultAttrs {
+    SmallVector<int32_t, 3> MaxTeams = {-1};
+    int32_t MinTeams = 1;
+    SmallVector<int32_t, 3> MaxThreads = {-1};
+    int32_t MinThreads = 1;
+  };
+
   /// Data structure that contains the needed information to construct the
   /// kernel args vector.
   struct TargetKernelArgs {
@@ -2726,15 +2740,11 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc The insert and source location description.
   /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
-  /// \param MinThreads Minimal number of threads, or 0.
-  /// \param MaxThreads Maximal number of threads, or 0.
-  /// \param MinTeams Minimal number of teams, or 0.
-  /// \param MaxTeams Maximal number of teams, or 0.
-  InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD,
-                                 int32_t MinThreadsVal = 0,
-                                 int32_t MaxThreadsVal = 0,
-                                 int32_t MinTeamsVal = 0,
-                                 int32_t MaxTeamsVal = 0);
+  /// \param Attrs Structure containing the default numbers of threads and teams
+  ///        to launch the kernel with.
+  InsertPointTy createTargetInit(
+      const LocationDescription &Loc, bool IsSPMD,
+      const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
 
   /// Create a runtime call for kmpc_target_deinit
   ///
@@ -2898,8 +2908,8 @@ class OpenMPIRBuilder {
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
   /// \param EntryInfo The entry information about the function.
-  /// \param NumTeams Number of teams specified in the num_teams clause.
-  /// \param NumThreads Number of teams specified in the thread_limit clause.
+  /// \param DefaultAttrs Structure containing the default numbers of threads
+  ///        and teams to launch the kernel with.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
@@ -2912,9 +2922,10 @@ class OpenMPIRBuilder {
       const LocationDescription &Loc, bool IsOffloadEntry,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
-      TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
-      ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
-      GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
+      TargetRegionEntryInfo &EntryInfo,
+      const TargetKernelDefaultAttrs &DefaultAttrs,
+      SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
+      TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
       SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
 
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d2e4dc1c85dfd2..302d363965c940 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6109,10 +6109,12 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
   return Builder.CreateCall(Fn, Args);
 }
 
-OpenMPIRBuilder::InsertPointTy
-OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
-                                  int32_t MinThreadsVal, int32_t MaxThreadsVal,
-                                  int32_t MinTeamsVal, int32_t MaxTeamsVal) {
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
+    const LocationDescription &Loc, bool IsSPMD,
+    const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+  assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
+         "expected num_threads and num_teams to be specified");
+
   if (!updateToLocation(Loc))
     return Loc.IP;
 
@@ -6139,21 +6141,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
 
   // Manifest the launch configuration in the metadata matching the kernel
   // environment.
-  if (MinTeamsVal > 1 || MaxTeamsVal > 0)
-    writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
+  if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
+    writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());
 
-  // For max values, < 0 means unset, == 0 means set but unknown.
+  // If MaxThreads not set, select the maximum between the default workgroup
+  // size and the MinThreads value.
+  int32_t MaxThreadsVal = Attrs.MaxThreads.front();
   if (MaxThreadsVal < 0)
     MaxThreadsVal = std::max(
-        int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
+        int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
 
   if (MaxThreadsVal > 0)
-    writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
+    writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
 
-  Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
+  Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
   Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
-  Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
-  Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
+  Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
+  Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
   Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
   Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
 
@@ -6724,8 +6728,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
 }
 
 static Expected<Function *> createOutlinedFunction(
-    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
-    SmallVectorImpl<Value *> &Inputs,
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+    StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
   SmallVector<Type *> ParameterTypes;
@@ -6792,7 +6797,8 @@ static Expected<Function *> createOutlinedFunction(
 
   // Insert target init call in the device compilation pass.
   if (OMPBuilder.Config.isTargetDevice())
-    Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
+    Builder.restoreIP(
+        OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
 
   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
 
@@ -6989,16 +6995,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
 
 static Error emitTargetOutlinedFunction(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
-    TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
-    Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
+    TargetRegionEntryInfo &EntryInfo,
+    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+    Function *&OutlinedFn, Constant *&OutlinedFnID,
+    SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
-      [&OMPBuilder, &Builder, &Inputs, &CBFunc,
-       &ArgAccessorFuncCB](StringRef EntryFnName) {
-        return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
-                                      CBFunc, ArgAccessorFuncCB);
+      [&](StringRef EntryFnName) {
+        return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+                                      EntryFnName, Inputs, CBFunc,
+                                      ArgAccessorFuncCB);
       };
 
   return OMPBuilder.emitTargetRegionFunction(
@@ -7294,9 +7302,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
 
 static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-               OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
-               Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
-               ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+               OpenMPIRBuilder::InsertPointTy AllocaIP,
+               const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+               Function *OutlinedFn, Constant *OutlinedFnID,
+               SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
                SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
                bool HasNoWait = false) {
@@ -7377,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 
   SmallVector<Value *, 3> NumTeamsC;
   SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : NumTeams)
+  for (auto V : DefaultAttrs.MaxTeams)
     NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : NumThreads)
+  for (auto V : DefaultAttrs.MaxThreads)
     NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
@@ -7420,7 +7429,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
-    ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
+    const TargetKernelDefaultAttrs &DefaultAttrs,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7437,16 +7446,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // the target region itself is generated using the callbacks CBFunc
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
-          *this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
-          Args, CBFunc, ArgAccessorFuncCB))
+          *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
+          OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
-                   NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
+                   OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 630cd03c688012..b0688d6215e42d 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6182,9 +6182,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
 
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
-      OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(),
-      EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+      OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
+                              Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
+                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
   OMPBuilder.finalize();
@@ -6292,11 +6295,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
 
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
-                              BodyGenCB, SimpleArgAccessorCB);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
@@ -6443,11 +6446,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
 
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
-                              BodyGenCB, SimpleArgAccessorCB);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index cbcbeea4ab9225..d3c3839accb7e7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3902,9 +3902,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
     return failure();
 
-  int32_t defaultValTeams = -1;
-  int32_t defaultValThreads = 0;
-
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
 
@@ -3939,6 +3936,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                                         allocaIP, codeGenIP);
   };
 
+  // TODO: Populate default attributes based on the construct and clauses.
+  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+
   llvm::SmallVector<llvm::Value *, 4> kernelInput;
   for (size_t i = 0; i < mapVars.size(); ++i) {
     // declare target arguments are not passed to kernels as arguments
@@ -3957,8 +3958,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
           ompLoc, isOffloadEntry, alloca...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch introduces the OpenMPIRBuilder::TargetKernelDefaultAttrs structure used to simplify passing default and constant values for number of teams and threads, and possibly other target kernel-related information in the future.

This is used to forward values passed to createTarget to createTargetInit, which previously used a default unrelated set of values.


Patch is 21.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116050.diff

8 Files Affected:

  • (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+8-5)
  • (modified) clang/lib/CodeGen/CGOpenMPRuntime.h (+3-6)
  • (modified) clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp (+3-6)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+25-14)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+40-31)
  • (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+16-13)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-5)
  • (modified) mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir (+1-1)
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index d714af035d21a2..0f7a1166227476 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -5880,10 +5880,13 @@ void CGOpenMPRuntime::emitUsesAllocatorsFini(CodeGenFunction &CGF,
 
 void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
     const OMPExecutableDirective &D, CodeGenFunction &CGF,
-    int32_t &MinThreadsVal, int32_t &MaxThreadsVal, int32_t &MinTeamsVal,
-    int32_t &MaxTeamsVal) {
+    llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+  assert(Attrs.MaxTeams.size() == 1 && Attrs.MaxThreads.size() == 1 &&
+         "invalid default attrs structure");
+  int32_t &MaxTeamsVal = Attrs.MaxTeams.front();
+  int32_t &MaxThreadsVal = Attrs.MaxThreads.front();
 
-  getNumTeamsExprForTargetDirective(CGF, D, MinTeamsVal, MaxTeamsVal);
+  getNumTeamsExprForTargetDirective(CGF, D, Attrs.MinTeams, MaxTeamsVal);
   getNumThreadsExprForTargetDirective(CGF, D, MaxThreadsVal,
                                       /*UpperBoundOnly=*/true);
 
@@ -5901,12 +5904,12 @@ void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
       else
         continue;
 
-      MinThreadsVal = std::max(MinThreadsVal, AttrMinThreadsVal);
+      Attrs.MinThreads = std::max(Attrs.MinThreads, AttrMinThreadsVal);
       if (AttrMaxThreadsVal > 0)
         MaxThreadsVal = MaxThreadsVal > 0
                             ? std::min(MaxThreadsVal, AttrMaxThreadsVal)
                             : AttrMaxThreadsVal;
-      MinTeamsVal = std::max(MinTeamsVal, AttrMinBlocksVal);
+      Attrs.MinTeams = std::max(Attrs.MinTeams, AttrMinBlocksVal);
       if (AttrMaxBlocksVal > 0)
         MaxTeamsVal = MaxTeamsVal > 0 ? std::min(MaxTeamsVal, AttrMaxBlocksVal)
                                       : AttrMaxBlocksVal;
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index 5e7715743afb58..003395e7f17ded 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -312,12 +312,9 @@ class CGOpenMPRuntime {
   llvm::OpenMPIRBuilder OMPBuilder;
 
   /// Helper to determine the min/max number of threads/teams for \p D.
-  void computeMinAndMaxThreadsAndTeams(const OMPExecutableDirective &D,
-                                       CodeGenFunction &CGF,
-                                       int32_t &MinThreadsVal,
-                                       int32_t &MaxThreadsVal,
-                                       int32_t &MinTeamsVal,
-                                       int32_t &MaxTeamsVal);
+  void computeMinAndMaxThreadsAndTeams(
+      const OMPExecutableDirective &D, CodeGenFunction &CGF,
+      llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
 
   /// Helper to emit outlined function for 'target' directive.
   /// \param D Directive to emit.
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 43dc0e62284602..96f8d6c5c08e56 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -745,14 +745,11 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
 void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
                                         CodeGenFunction &CGF,
                                         EntryFunctionState &EST, bool IsSPMD) {
-  int32_t MinThreadsVal = 1, MaxThreadsVal = -1, MinTeamsVal = 1,
-          MaxTeamsVal = -1;
-  computeMinAndMaxThreadsAndTeams(D, CGF, MinThreadsVal, MaxThreadsVal,
-                                  MinTeamsVal, MaxTeamsVal);
+  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
+  computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
 
   CGBuilderTy &Bld = CGF.Builder;
-  Bld.restoreIP(OMPBuilder.createTargetInit(
-      Bld, IsSPMD, MinThreadsVal, MaxThreadsVal, MinTeamsVal, MaxTeamsVal));
+  Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, Attrs));
   if (!IsSPMD)
     emitGenericVarsProlog(CGF, EST.Loc);
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3afb9d84278e81..da450ef5adbc14 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2223,6 +2223,20 @@ class OpenMPIRBuilder {
           MapNamesArray(MapNamesArray) {}
   };
 
+  /// Container to pass the default attributes with which a kernel must be
+  /// launched, used to set kernel attributes and populate associated static
+  /// structures.
+  ///
+  /// For max values, < 0 means unset, == 0 means set but unknown at compile
+  /// time. The number of max values will be 1 except for the case where
+  /// ompx_bare is set.
+  struct TargetKernelDefaultAttrs {
+    SmallVector<int32_t, 3> MaxTeams = {-1};
+    int32_t MinTeams = 1;
+    SmallVector<int32_t, 3> MaxThreads = {-1};
+    int32_t MinThreads = 1;
+  };
+
   /// Data structure that contains the needed information to construct the
   /// kernel args vector.
   struct TargetKernelArgs {
@@ -2726,15 +2740,11 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc The insert and source location description.
   /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
-  /// \param MinThreads Minimal number of threads, or 0.
-  /// \param MaxThreads Maximal number of threads, or 0.
-  /// \param MinTeams Minimal number of teams, or 0.
-  /// \param MaxTeams Maximal number of teams, or 0.
-  InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD,
-                                 int32_t MinThreadsVal = 0,
-                                 int32_t MaxThreadsVal = 0,
-                                 int32_t MinTeamsVal = 0,
-                                 int32_t MaxTeamsVal = 0);
+  /// \param Attrs Structure containing the default numbers of threads and teams
+  ///        to launch the kernel with.
+  InsertPointTy createTargetInit(
+      const LocationDescription &Loc, bool IsSPMD,
+      const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
 
   /// Create a runtime call for kmpc_target_deinit
   ///
@@ -2898,8 +2908,8 @@ class OpenMPIRBuilder {
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
   /// \param EntryInfo The entry information about the function.
-  /// \param NumTeams Number of teams specified in the num_teams clause.
-  /// \param NumThreads Number of teams specified in the thread_limit clause.
+  /// \param DefaultAttrs Structure containing the default numbers of threads
+  ///        and teams to launch the kernel with.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
@@ -2912,9 +2922,10 @@ class OpenMPIRBuilder {
       const LocationDescription &Loc, bool IsOffloadEntry,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
-      TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
-      ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
-      GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
+      TargetRegionEntryInfo &EntryInfo,
+      const TargetKernelDefaultAttrs &DefaultAttrs,
+      SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
+      TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
       SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
 
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d2e4dc1c85dfd2..302d363965c940 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6109,10 +6109,12 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
   return Builder.CreateCall(Fn, Args);
 }
 
-OpenMPIRBuilder::InsertPointTy
-OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
-                                  int32_t MinThreadsVal, int32_t MaxThreadsVal,
-                                  int32_t MinTeamsVal, int32_t MaxTeamsVal) {
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
+    const LocationDescription &Loc, bool IsSPMD,
+    const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+  assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
+         "expected num_threads and num_teams to be specified");
+
   if (!updateToLocation(Loc))
     return Loc.IP;
 
@@ -6139,21 +6141,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
 
   // Manifest the launch configuration in the metadata matching the kernel
   // environment.
-  if (MinTeamsVal > 1 || MaxTeamsVal > 0)
-    writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
+  if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
+    writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());
 
-  // For max values, < 0 means unset, == 0 means set but unknown.
+  // If MaxThreads not set, select the maximum between the default workgroup
+  // size and the MinThreads value.
+  int32_t MaxThreadsVal = Attrs.MaxThreads.front();
   if (MaxThreadsVal < 0)
     MaxThreadsVal = std::max(
-        int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
+        int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
 
   if (MaxThreadsVal > 0)
-    writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
+    writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
 
-  Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
+  Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
   Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
-  Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
-  Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
+  Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
+  Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
   Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
   Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
 
@@ -6724,8 +6728,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
 }
 
 static Expected<Function *> createOutlinedFunction(
-    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
-    SmallVectorImpl<Value *> &Inputs,
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+    StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
   SmallVector<Type *> ParameterTypes;
@@ -6792,7 +6797,8 @@ static Expected<Function *> createOutlinedFunction(
 
   // Insert target init call in the device compilation pass.
   if (OMPBuilder.Config.isTargetDevice())
-    Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
+    Builder.restoreIP(
+        OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
 
   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
 
@@ -6989,16 +6995,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
 
 static Error emitTargetOutlinedFunction(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
-    TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
-    Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
+    TargetRegionEntryInfo &EntryInfo,
+    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+    Function *&OutlinedFn, Constant *&OutlinedFnID,
+    SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
-      [&OMPBuilder, &Builder, &Inputs, &CBFunc,
-       &ArgAccessorFuncCB](StringRef EntryFnName) {
-        return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
-                                      CBFunc, ArgAccessorFuncCB);
+      [&](StringRef EntryFnName) {
+        return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+                                      EntryFnName, Inputs, CBFunc,
+                                      ArgAccessorFuncCB);
       };
 
   return OMPBuilder.emitTargetRegionFunction(
@@ -7294,9 +7302,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
 
 static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-               OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
-               Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
-               ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+               OpenMPIRBuilder::InsertPointTy AllocaIP,
+               const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+               Function *OutlinedFn, Constant *OutlinedFnID,
+               SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
                SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
                bool HasNoWait = false) {
@@ -7377,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 
   SmallVector<Value *, 3> NumTeamsC;
   SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : NumTeams)
+  for (auto V : DefaultAttrs.MaxTeams)
     NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : NumThreads)
+  for (auto V : DefaultAttrs.MaxThreads)
     NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
@@ -7420,7 +7429,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
-    ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
+    const TargetKernelDefaultAttrs &DefaultAttrs,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7437,16 +7446,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // the target region itself is generated using the callbacks CBFunc
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
-          *this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
-          Args, CBFunc, ArgAccessorFuncCB))
+          *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
+          OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
-                   NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
+                   OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 630cd03c688012..b0688d6215e42d 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6182,9 +6182,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
 
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
-      OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(),
-      EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+      OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
+                              Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
+                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
   OMPBuilder.finalize();
@@ -6292,11 +6295,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
 
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
-                              BodyGenCB, SimpleArgAccessorCB);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
@@ -6443,11 +6446,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
 
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
-                              BodyGenCB, SimpleArgAccessorCB);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index cbcbeea4ab9225..d3c3839accb7e7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3902,9 +3902,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
     return failure();
 
-  int32_t defaultValTeams = -1;
-  int32_t defaultValThreads = 0;
-
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
 
@@ -3939,6 +3936,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                                         allocaIP, codeGenIP);
   };
 
+  // TODO: Populate default attributes based on the construct and clauses.
+  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+
   llvm::SmallVector<llvm::Value *, 4> kernelInput;
   for (size_t i = 0; i < mapVars.size(); ++i) {
     // declare target arguments are not passed to kernels as arguments
@@ -3957,8 +3958,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
           ompLoc, isOffloadEntry, alloca...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

Changes

This patch introduces the OpenMPIRBuilder::TargetKernelDefaultAttrs structure used to simplify passing default and constant values for number of teams and threads, and possibly other target kernel-related information in the future.

This is used to forward values passed to createTarget to createTargetInit, which previously used a default unrelated set of values.


Patch is 21.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116050.diff

8 Files Affected:

  • (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+8-5)
  • (modified) clang/lib/CodeGen/CGOpenMPRuntime.h (+3-6)
  • (modified) clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp (+3-6)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+25-14)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+40-31)
  • (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+16-13)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-5)
  • (modified) mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir (+1-1)
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index d714af035d21a2..0f7a1166227476 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -5880,10 +5880,13 @@ void CGOpenMPRuntime::emitUsesAllocatorsFini(CodeGenFunction &CGF,
 
 void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
     const OMPExecutableDirective &D, CodeGenFunction &CGF,
-    int32_t &MinThreadsVal, int32_t &MaxThreadsVal, int32_t &MinTeamsVal,
-    int32_t &MaxTeamsVal) {
+    llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+  assert(Attrs.MaxTeams.size() == 1 && Attrs.MaxThreads.size() == 1 &&
+         "invalid default attrs structure");
+  int32_t &MaxTeamsVal = Attrs.MaxTeams.front();
+  int32_t &MaxThreadsVal = Attrs.MaxThreads.front();
 
-  getNumTeamsExprForTargetDirective(CGF, D, MinTeamsVal, MaxTeamsVal);
+  getNumTeamsExprForTargetDirective(CGF, D, Attrs.MinTeams, MaxTeamsVal);
   getNumThreadsExprForTargetDirective(CGF, D, MaxThreadsVal,
                                       /*UpperBoundOnly=*/true);
 
@@ -5901,12 +5904,12 @@ void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
       else
         continue;
 
-      MinThreadsVal = std::max(MinThreadsVal, AttrMinThreadsVal);
+      Attrs.MinThreads = std::max(Attrs.MinThreads, AttrMinThreadsVal);
       if (AttrMaxThreadsVal > 0)
         MaxThreadsVal = MaxThreadsVal > 0
                             ? std::min(MaxThreadsVal, AttrMaxThreadsVal)
                             : AttrMaxThreadsVal;
-      MinTeamsVal = std::max(MinTeamsVal, AttrMinBlocksVal);
+      Attrs.MinTeams = std::max(Attrs.MinTeams, AttrMinBlocksVal);
       if (AttrMaxBlocksVal > 0)
         MaxTeamsVal = MaxTeamsVal > 0 ? std::min(MaxTeamsVal, AttrMaxBlocksVal)
                                       : AttrMaxBlocksVal;
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index 5e7715743afb58..003395e7f17ded 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -312,12 +312,9 @@ class CGOpenMPRuntime {
   llvm::OpenMPIRBuilder OMPBuilder;
 
   /// Helper to determine the min/max number of threads/teams for \p D.
-  void computeMinAndMaxThreadsAndTeams(const OMPExecutableDirective &D,
-                                       CodeGenFunction &CGF,
-                                       int32_t &MinThreadsVal,
-                                       int32_t &MaxThreadsVal,
-                                       int32_t &MinTeamsVal,
-                                       int32_t &MaxTeamsVal);
+  void computeMinAndMaxThreadsAndTeams(
+      const OMPExecutableDirective &D, CodeGenFunction &CGF,
+      llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
 
   /// Helper to emit outlined function for 'target' directive.
   /// \param D Directive to emit.
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 43dc0e62284602..96f8d6c5c08e56 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -745,14 +745,11 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
 void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
                                         CodeGenFunction &CGF,
                                         EntryFunctionState &EST, bool IsSPMD) {
-  int32_t MinThreadsVal = 1, MaxThreadsVal = -1, MinTeamsVal = 1,
-          MaxTeamsVal = -1;
-  computeMinAndMaxThreadsAndTeams(D, CGF, MinThreadsVal, MaxThreadsVal,
-                                  MinTeamsVal, MaxTeamsVal);
+  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
+  computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
 
   CGBuilderTy &Bld = CGF.Builder;
-  Bld.restoreIP(OMPBuilder.createTargetInit(
-      Bld, IsSPMD, MinThreadsVal, MaxThreadsVal, MinTeamsVal, MaxTeamsVal));
+  Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, Attrs));
   if (!IsSPMD)
     emitGenericVarsProlog(CGF, EST.Loc);
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3afb9d84278e81..da450ef5adbc14 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2223,6 +2223,20 @@ class OpenMPIRBuilder {
           MapNamesArray(MapNamesArray) {}
   };
 
+  /// Container to pass the default attributes with which a kernel must be
+  /// launched, used to set kernel attributes and populate associated static
+  /// structures.
+  ///
+  /// For max values, < 0 means unset, == 0 means set but unknown at compile
+  /// time. The number of max values will be 1 except for the case where
+  /// ompx_bare is set.
+  struct TargetKernelDefaultAttrs {
+    SmallVector<int32_t, 3> MaxTeams = {-1};
+    int32_t MinTeams = 1;
+    SmallVector<int32_t, 3> MaxThreads = {-1};
+    int32_t MinThreads = 1;
+  };
+
   /// Data structure that contains the needed information to construct the
   /// kernel args vector.
   struct TargetKernelArgs {
@@ -2726,15 +2740,11 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc The insert and source location description.
   /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
-  /// \param MinThreads Minimal number of threads, or 0.
-  /// \param MaxThreads Maximal number of threads, or 0.
-  /// \param MinTeams Minimal number of teams, or 0.
-  /// \param MaxTeams Maximal number of teams, or 0.
-  InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD,
-                                 int32_t MinThreadsVal = 0,
-                                 int32_t MaxThreadsVal = 0,
-                                 int32_t MinTeamsVal = 0,
-                                 int32_t MaxTeamsVal = 0);
+  /// \param Attrs Structure containing the default numbers of threads and teams
+  ///        to launch the kernel with.
+  InsertPointTy createTargetInit(
+      const LocationDescription &Loc, bool IsSPMD,
+      const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
 
   /// Create a runtime call for kmpc_target_deinit
   ///
@@ -2898,8 +2908,8 @@ class OpenMPIRBuilder {
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
   /// \param EntryInfo The entry information about the function.
-  /// \param NumTeams Number of teams specified in the num_teams clause.
-  /// \param NumThreads Number of teams specified in the thread_limit clause.
+  /// \param DefaultAttrs Structure containing the default numbers of threads
+  ///        and teams to launch the kernel with.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
@@ -2912,9 +2922,10 @@ class OpenMPIRBuilder {
       const LocationDescription &Loc, bool IsOffloadEntry,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
-      TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
-      ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
-      GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
+      TargetRegionEntryInfo &EntryInfo,
+      const TargetKernelDefaultAttrs &DefaultAttrs,
+      SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
+      TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
       SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
 
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d2e4dc1c85dfd2..302d363965c940 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6109,10 +6109,12 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
   return Builder.CreateCall(Fn, Args);
 }
 
-OpenMPIRBuilder::InsertPointTy
-OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
-                                  int32_t MinThreadsVal, int32_t MaxThreadsVal,
-                                  int32_t MinTeamsVal, int32_t MaxTeamsVal) {
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
+    const LocationDescription &Loc, bool IsSPMD,
+    const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+  assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
+         "expected num_threads and num_teams to be specified");
+
   if (!updateToLocation(Loc))
     return Loc.IP;
 
@@ -6139,21 +6141,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
 
   // Manifest the launch configuration in the metadata matching the kernel
   // environment.
-  if (MinTeamsVal > 1 || MaxTeamsVal > 0)
-    writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
+  if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
+    writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());
 
-  // For max values, < 0 means unset, == 0 means set but unknown.
+  // If MaxThreads not set, select the maximum between the default workgroup
+  // size and the MinThreads value.
+  int32_t MaxThreadsVal = Attrs.MaxThreads.front();
   if (MaxThreadsVal < 0)
     MaxThreadsVal = std::max(
-        int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
+        int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
 
   if (MaxThreadsVal > 0)
-    writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
+    writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
 
-  Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
+  Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
   Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
-  Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
-  Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
+  Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
+  Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
   Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
   Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
 
@@ -6724,8 +6728,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
 }
 
 static Expected<Function *> createOutlinedFunction(
-    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
-    SmallVectorImpl<Value *> &Inputs,
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+    StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
   SmallVector<Type *> ParameterTypes;
@@ -6792,7 +6797,8 @@ static Expected<Function *> createOutlinedFunction(
 
   // Insert target init call in the device compilation pass.
   if (OMPBuilder.Config.isTargetDevice())
-    Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
+    Builder.restoreIP(
+        OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
 
   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
 
@@ -6989,16 +6995,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
 
 static Error emitTargetOutlinedFunction(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
-    TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
-    Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
+    TargetRegionEntryInfo &EntryInfo,
+    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+    Function *&OutlinedFn, Constant *&OutlinedFnID,
+    SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
-      [&OMPBuilder, &Builder, &Inputs, &CBFunc,
-       &ArgAccessorFuncCB](StringRef EntryFnName) {
-        return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
-                                      CBFunc, ArgAccessorFuncCB);
+      [&](StringRef EntryFnName) {
+        return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+                                      EntryFnName, Inputs, CBFunc,
+                                      ArgAccessorFuncCB);
       };
 
   return OMPBuilder.emitTargetRegionFunction(
@@ -7294,9 +7302,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
 
 static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-               OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
-               Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
-               ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+               OpenMPIRBuilder::InsertPointTy AllocaIP,
+               const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+               Function *OutlinedFn, Constant *OutlinedFnID,
+               SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
                SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
                bool HasNoWait = false) {
@@ -7377,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 
   SmallVector<Value *, 3> NumTeamsC;
   SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : NumTeams)
+  for (auto V : DefaultAttrs.MaxTeams)
     NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : NumThreads)
+  for (auto V : DefaultAttrs.MaxThreads)
     NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
@@ -7420,7 +7429,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
-    ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
+    const TargetKernelDefaultAttrs &DefaultAttrs,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7437,16 +7446,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // the target region itself is generated using the callbacks CBFunc
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
-          *this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
-          Args, CBFunc, ArgAccessorFuncCB))
+          *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
+          OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
-                   NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
+                   OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 630cd03c688012..b0688d6215e42d 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6182,9 +6182,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
 
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
-      OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(),
-      EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+      OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
+                              Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
+                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
   OMPBuilder.finalize();
@@ -6292,11 +6295,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
 
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
-                              BodyGenCB, SimpleArgAccessorCB);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
@@ -6443,11 +6446,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
 
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
-                              BodyGenCB, SimpleArgAccessorCB);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index cbcbeea4ab9225..d3c3839accb7e7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3902,9 +3902,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
     return failure();
 
-  int32_t defaultValTeams = -1;
-  int32_t defaultValThreads = 0;
-
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
 
@@ -3939,6 +3936,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                                         allocaIP, codeGenIP);
   };
 
+  // TODO: Populate default attributes based on the construct and clauses.
+  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+
   llvm::SmallVector<llvm::Value *, 4> kernelInput;
   for (size_t i = 0; i < mapVars.size(); ++i) {
     // declare target arguments are not passed to kernels as arguments
@@ -3957,8 +3958,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
           ompLoc, isOffloadEntry, alloca...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen clang:openmp OpenMP related changes to Clang clang Clang issues not falling into any other category flang:openmp mlir:llvm mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants