Skip to content

Commit

Permalink
[MLIR][OpenMP] LLVM IR translation of host_eval
Browse files Browse the repository at this point in the history
This patch adds support for processing the `host_eval` clause of `omp.target`
to populate default and runtime kernel launch attributes. Specifically, these
are related to the `num_teams`, `thread_limit` and `num_threads` clauses attached
to operations nested inside of `omp.target`. As a result, the `thread_limit`
clause of `omp.target` is also supported.

The implementation of `initTargetDefaultAttrs()` is intended to reflect clang's
own processing of multiple constructs and clauses in order to define a default
number of teams and threads to be used as kernel attributes and to populate
global variables in the target device module.

One side effect of this change is that it is no longer possible to translate
target device MLIR modules to LLVM IR unless they have a supported target triple.
This is because the local `getGridValue()` function in the `OpenMPIRBuilder`
only works for certain architectures, and it is called whenever the maximum
number of threads has not been explicitly defined. This limitation also matches
clang.

Evaluating the collapsed loop trip count of target SPMD kernels remains
unsupported.
  • Loading branch information
skatrak committed Nov 27, 2024
1 parent c7ca41f commit 58bd5ff
Show file tree
Hide file tree
Showing 18 changed files with 361 additions and 60 deletions.
2 changes: 1 addition & 1 deletion flang/test/Integration/OpenMP/target-filtering.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
!===----------------------------------------------------------------------===!

!RUN: %flang_fc1 -emit-llvm -fopenmp %s -o - | FileCheck %s --check-prefixes HOST,ALL
!RUN: %flang_fc1 -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes DEVICE,ALL
!RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes DEVICE,ALL

!HOST: define {{.*}}@{{.*}}before{{.*}}(
!DEVICE-NOT: define {{.*}}@before{{.*}}(
Expand Down
6 changes: 3 additions & 3 deletions flang/test/Lower/OpenMP/function-filtering-2.f90
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
! RUN: %flang_fc1 -fopenmp -fopenmp-version=52 -flang-experimental-hlfir -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM,LLVM-HOST %s
! RUN: %flang_fc1 -fopenmp -fopenmp-version=52 -emit-hlfir %s -o - | FileCheck --check-prefix=MLIR %s
! RUN: %flang_fc1 -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -flang-experimental-hlfir -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM,LLVM-DEVICE %s
! RUN: %flang_fc1 -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefix=MLIR %s
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -flang-experimental-hlfir -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM,LLVM-DEVICE %s
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefix=MLIR %s
! RUN: bbc -fopenmp -fopenmp-version=52 -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-HOST,MLIR-ALL %s
! RUN: bbc -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-DEVICE,MLIR-ALL %s
! RUN: bbc -target amdgcn-amd-amdhsa -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-DEVICE,MLIR-ALL %s

! MLIR: func.func @{{.*}}implicit_invocation() attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>}
! MLIR: return
Expand Down
6 changes: 3 additions & 3 deletions flang/test/Lower/OpenMP/function-filtering-3.f90
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
! RUN: %flang_fc1 -fopenmp -flang-experimental-hlfir -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM-HOST,LLVM-ALL %s
! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-HOST,MLIR-ALL %s
! RUN: %flang_fc1 -fopenmp -fopenmp-is-target-device -flang-experimental-hlfir -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM-DEVICE,LLVM-ALL %s
! RUN: %flang_fc1 -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-DEVICE,MLIR-ALL %s
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -fopenmp -fopenmp-is-target-device -flang-experimental-hlfir -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM-DEVICE,LLVM-ALL %s
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-DEVICE,MLIR-ALL %s
! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-HOST,MLIR-ALL %s
! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-DEVICE,MLIR-ALL %s
! RUN: bbc -target amdgcn-amd-amdhsa -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-DEVICE,MLIR-ALL %s

! Check that the correct LLVM IR functions are kept for the host and device
! after running the whole set of translation and transformation passes from
Expand Down
6 changes: 3 additions & 3 deletions flang/test/Lower/OpenMP/function-filtering.f90
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
! RUN: %flang_fc1 -fopenmp -fopenmp-version=52 -flang-experimental-hlfir -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM-HOST,LLVM-ALL %s
! RUN: %flang_fc1 -fopenmp -fopenmp-version=52 -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-HOST,MLIR-ALL %s
! RUN: %flang_fc1 -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -flang-experimental-hlfir -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM-DEVICE,LLVM-ALL %s
! RUN: %flang_fc1 -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-DEVICE,MLIR-ALL %s
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -flang-experimental-hlfir -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM-DEVICE,LLVM-ALL %s
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-DEVICE,MLIR-ALL %s
! RUN: bbc -fopenmp -fopenmp-version=52 -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-HOST,MLIR-ALL %s
! RUN: bbc -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-DEVICE,MLIR-ALL %s
! RUN: bbc -target amdgcn-amd-amdhsa -fopenmp -fopenmp-version=52 -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck --check-prefixes=MLIR-DEVICE,MLIR-ALL %s

! Check that the correct LLVM IR functions are kept for the host and device
! after running the whole set of translation and transformation passes from
Expand Down
262 changes: 244 additions & 18 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.getHint())
op.emitWarning("hint clause discarded");
};
auto checkHostEval = [&todo](auto op, LogicalResult &result) {
if (!op.getHostEvalVars().empty())
result = todo("host_eval");
};
auto checkIf = [&todo](auto op, LogicalResult &result) {
if (op.getIfExpr())
result = todo("if");
Expand Down Expand Up @@ -228,10 +224,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
op.getReductionSyms())
result = todo("reduction");
};
auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
if (op.getThreadLimit())
result = todo("thread_limit");
};
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
op.getTaskReductionSyms())
Expand Down Expand Up @@ -295,7 +287,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkAllocate(op, result);
checkDevice(op, result);
checkHasDeviceAddr(op, result);
checkHostEval(op, result);

// Host evaluated clauses are supported, except for target SPMD loop
// bounds.
for (BlockArgument arg :
cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
for (Operation *user : arg.getUsers())
if (isa<omp::LoopNestOp>(user))
result = op.emitError("not yet implemented: host evaluation of "
"loop bounds in omp.target operation");

checkIf(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
Expand All @@ -316,7 +317,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
"structures in omp.target operation");
}
}
checkThreadLimit(op, result);
})
.Default([](Operation &) {
// Assume all clauses for an operation can be translated unless they are
Expand Down Expand Up @@ -3800,6 +3800,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
return builder.saveIP();
}

/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
/// operation and populate output variables with their corresponding host value
/// (i.e. operand evaluated outside of the target region), based on their uses
/// inside of the target region.
///
/// \param targetOp      the `omp.target` operation whose `host_eval` entries
///                      are inspected.
/// \param numThreads    set to the host value used as the `num_threads` clause
///                      of an `omp.parallel` user.
/// \param numTeamsLower set to the host value used as the `num_teams` lower
///                      bound of an `omp.teams` user.
/// \param numTeamsUpper set to the host value used as the `num_teams` upper
///                      bound of an `omp.teams` user.
/// \param threadLimit   set to the host value used as the `thread_limit` clause
///                      of an `omp.teams` user.
///
/// Output variables for which no matching use is found are left untouched.
/// Uses by `omp.loop_nest` are accepted, but loop bounds and steps are not
/// extracted yet. Any other kind of use is treated as unreachable.
static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
                                   Value &numTeamsLower, Value &numTeamsUpper,
                                   Value &threadLimit) {
  auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
  for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                   blockArgIface.getHostEvalBlockArgs())) {
    // Plain `std::get` is used instead of structured bindings because the
    // values are captured by the lambdas below.
    Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);

    for (Operation *user : blockArg.getUsers()) {
      llvm::TypeSwitch<Operation *>(user)
          .Case([&](omp::TeamsOp teamsOp) {
            // Check every clause independently: the same block argument may
            // feed more than one clause of this operation (e.g. both
            // `num_teams` bounds), and each of them must be recorded.
            bool matched = false;
            if (teamsOp.getNumTeamsLower() == blockArg) {
              numTeamsLower = hostEvalVar;
              matched = true;
            }
            if (teamsOp.getNumTeamsUpper() == blockArg) {
              numTeamsUpper = hostEvalVar;
              matched = true;
            }
            if (teamsOp.getThreadLimit() == blockArg) {
              threadLimit = hostEvalVar;
              matched = true;
            }
            if (!matched)
              llvm_unreachable("unsupported host_eval use");
          })
          .Case([&](omp::ParallelOp parallelOp) {
            if (parallelOp.getNumThreads() == blockArg)
              numThreads = hostEvalVar;
            else
              llvm_unreachable("unsupported host_eval use");
          })
          .Case([&](omp::LoopNestOp loopOp) {
            // TODO: Extract bounds and step values.
          })
          .Default([](Operation *) {
            llvm_unreachable("unsupported host_eval use");
          });
    }
  }
}

/// Find an operation of type \c OpTy starting at \p op.
///
/// The search first tries \p op itself. If that does not match, it looks at
/// the parents of \p op: only the direct parent when \p immediateParent is
/// true, or the closest enclosing operation of the requested type otherwise.
///
/// \returns the matching operation casted to \c OpTy, or a \c null operation
/// if \p op is \c null or neither it nor the inspected parent(s) are of the
/// requested type.
template <typename OpTy>
static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
  // Nothing to search from.
  if (!op)
    return OpTy();

  // The operation itself takes precedence over any of its parents.
  if (OpTy self = dyn_cast<OpTy>(op))
    return self;

  // The direct parent may not exist, hence the `_if_present` variant.
  return immediateParent ? dyn_cast_if_present<OpTy>(op->getParentOp())
                         : op->getParentOfType<OpTy>();
}

/// Initialize the `MinTeams`, `MaxTeams`, `MinThreads` and `MaxThreads` fields
/// of \p attrs from the clauses attached to \p targetOp and to the operations
/// it captures, whenever those clauses hold constant values.
///
/// These default values must be set before the creation of the outlined LLVM
/// function for the target region, so that they can be used to initialize the
/// corresponding global `ConfigurationEnvironmentTy` structure.
///
/// \param targetOp       the `omp.target` operation being translated.
/// \param attrs          kernel default attributes to update. Only the first
///                       element of `MaxTeams` and `MaxThreads` is written.
/// \param isTargetDevice whether the translation is for the target device; it
///                       selects where clause values are read from.
static void
initTargetDefaultAttrs(omp::TargetOp targetOp,
                       llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
                       bool isTargetDevice) {
  // TODO: Handle constant 'if' clauses.
  Operation *capturedOp = targetOp.getInnermostCapturedOmpOp();

  // Constructs found at or around the innermost captured operation. The
  // "direct" variants only consider the operation itself and its immediate
  // parent.
  auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp);
  auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp);
  const bool directParallel =
      static_cast<bool>(castOrGetParentOfType<omp::ParallelOp>(
          capturedOp, /*immediateParent=*/true));
  const bool directSimd = static_cast<bool>(castOrGetParentOfType<omp::SimdOp>(
      capturedOp, /*immediateParent=*/true));

  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
  if (isTargetDevice) {
    // In the target device, values for these clauses are not passed as
    // host_eval, but instead evaluated prior to entry to the region. This
    // ensures values are mapped and available inside of the target region.
    if (teamsOp) {
      numTeamsLower = teamsOp.getNumTeamsLower();
      numTeamsUpper = teamsOp.getNumTeamsUpper();
      threadLimit = teamsOp.getThreadLimit();
    }

    if (parallelOp)
      numThreads = parallelOp.getNumThreads();
  } else {
    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
                           threadLimit);
  }

  // Return the integer held by `value` if it is defined by an
  // `LLVM::ConstantOp` carrying an integer attribute.
  auto constantValueOf = [](Value value) -> std::optional<int64_t> {
    auto constOp = dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp());
    if (!constOp)
      return std::nullopt;

    auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue());
    if (!intAttr)
      return std::nullopt;

    return intAttr.getInt();
  };

  // Handle clauses impacting the number of teams.

  int32_t minTeamsVal = 1, maxTeamsVal = -1;
  if (teamsOp) {
    // TODO: Use `numTeamsLower` to initialize `minTeamsVal`. For now, match
    // clang and set min and max to the same value.
    if (!numTeamsUpper) {
      minTeamsVal = maxTeamsVal = 0;
    } else if (std::optional<int64_t> numTeams =
                   constantValueOf(numTeamsUpper)) {
      minTeamsVal = maxTeamsVal = *numTeams;
    }
  } else if (directParallel || directSimd) {
    minTeamsVal = maxTeamsVal = 1;
  } else {
    minTeamsVal = maxTeamsVal = -1;
  }

  // Handle clauses impacting the number of threads.

  // Encode a clause as an upper bound: -1 if the clause is absent, 0 if it is
  // present but not a (non-negative) constant, or the constant otherwise.
  auto maxValueFromClause = [&constantValueOf](Value clauseValue) -> int32_t {
    if (!clauseValue)
      return -1;

    std::optional<int64_t> constant = constantValueOf(clauseValue);
    if (!constant)
      return 0;

    int32_t bound = *constant;
    return bound < 0 ? 0 : bound;
  };

  // Extract 'thread_limit' clause from 'target' and 'teams' directives.
  const int32_t targetThreadLimitVal =
      maxValueFromClause(targetOp.getThreadLimit());
  const int32_t teamsThreadLimitVal = maxValueFromClause(threadLimit);

  // Extract 'num_threads' clause from 'parallel' or set to 1 if it's SIMD.
  int32_t maxThreadsVal = -1;
  if (parallelOp)
    maxThreadsVal = maxValueFromClause(numThreads);
  else if (directSimd)
    maxThreadsVal = 1;

  // For max values, < 0 means unset, == 0 means set but unknown. Starting from
  // the 'thread_limit' of 'target', replace the running value with each
  // following candidate if the running value is unset or the candidate is set
  // and smaller.
  auto selectBound = [](int32_t current, int32_t candidate) -> int32_t {
    const bool takeCandidate =
        current < 0 || (candidate >= 0 && candidate < current);
    return takeCandidate ? candidate : current;
  };
  const int32_t combinedMaxThreadsVal = selectBound(
      selectBound(targetThreadLimitVal, teamsThreadLimitVal), maxThreadsVal);

  // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
  attrs.MinTeams = minTeamsVal;
  attrs.MaxTeams.front() = maxTeamsVal;
  attrs.MinThreads = 1;
  attrs.MaxThreads.front() = combinedMaxThreadsVal;
}

/// Collect the LLVM values of all host-evaluated clauses that are passed to
/// the kernel invocation and store them in \p attrs.
///
/// This function must be called only when compiling for the host. Also, it will
/// only provide correct results if it's called after the body of \c targetOp
/// has been fully generated.
///
/// \param builder           LLVM IR builder (currently not used here).
/// \param moduleTranslation used to look up the LLVM value of each MLIR value.
/// \param targetOp          the `omp.target` operation being translated.
/// \param attrs             runtime attributes to update. Fields whose clause
///                          is not present are left untouched.
static void
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation,
                       omp::TargetOp targetOp,
                       llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
  Value hostNumThreads, hostNumTeamsLower, hostNumTeamsUpper,
      hostTeamsThreadLimit;
  extractHostEvalClauses(targetOp, hostNumThreads, hostNumTeamsLower,
                         hostNumTeamsUpper, hostTeamsThreadLimit);

  // Shorthand to get the LLVM value an MLIR value was translated to.
  auto toLLVM = [&moduleTranslation](Value mlirValue) -> llvm::Value * {
    return moduleTranslation.lookupValue(mlirValue);
  };

  // TODO: Handle constant 'if' clauses.
  Value hostTargetThreadLimit = targetOp.getThreadLimit();
  if (hostTargetThreadLimit)
    attrs.TargetThreadLimit.front() = toLLVM(hostTargetThreadLimit);

  if (hostNumTeamsLower)
    attrs.MinTeams = toLLVM(hostNumTeamsLower);

  if (hostNumTeamsUpper)
    attrs.MaxTeams.front() = toLLVM(hostNumTeamsUpper);

  if (hostTeamsThreadLimit)
    attrs.TeamsThreadLimit.front() = toLLVM(hostTeamsThreadLimit);

  if (hostNumThreads)
    attrs.MaxThreads = toLLVM(hostNumThreads);

  // TODO: Populate attrs.LoopTripCount if it is target SPMD.
}

static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
Expand All @@ -3809,12 +4018,13 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,

llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
bool isTargetDevice = ompBuilder->Config.isTargetDevice();

auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
auto blockIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
auto &targetRegion = targetOp.getRegion();
DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
SmallVector<Value> mapVars = targetOp.getMapVars();
ArrayRef<BlockArgument> mapBlockArgs =
cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs();
ArrayRef<BlockArgument> mapBlockArgs = blockIface.getMapBlockArgs();
llvm::Function *llvmOutlinedFn = nullptr;

// TODO: It can also be false if a compile-time constant `false` IF clause is
Expand Down Expand Up @@ -3857,7 +4067,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
OperandRange privateVars = targetOp.getPrivateVars();
std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
MutableArrayRef<BlockArgument> privateBlockArgs =
cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs();
blockIface.getPrivateBlockArgs();

for (auto [privVar, privatizerNameAttr, privBlockArg] :
llvm::zip_equal(privateVars, *privateSyms, privateBlockArgs)) {
Expand Down Expand Up @@ -3936,13 +4146,29 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
allocaIP, codeGenIP);
};

// TODO: Populate default and runtime attributes based on the construct and
// clauses.
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
llvm::SmallVector<llvm::Value *, 4> kernelInput;
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
initTargetDefaultAttrs(targetOp, defaultAttrs, isTargetDevice);

// Collect host-evaluated values needed to properly launch the kernel from the
// host.
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
if (!isTargetDevice)
initTargetRuntimeAttrs(builder, moduleTranslation, targetOp, runtimeAttrs);

// Pass host-evaluated values as parameters to the kernel / host fallback,
// except if they are constants. In any case, map the MLIR block argument to
// the corresponding LLVM values.
SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
ArrayRef<BlockArgument> hostEvalBlockArgs = blockIface.getHostEvalBlockArgs();
for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
llvm::Value *value = moduleTranslation.lookupValue(var);
moduleTranslation.mapValue(arg, value);

if (!llvm::isa<llvm::Constant>(value))
kernelInput.push_back(value);
}

llvm::SmallVector<llvm::Value *, 4> kernelInput;
for (size_t i = 0; i < mapVars.size(); ++i) {
// declare target arguments are not passed to kernels as arguments
// TODO: We currently do not handle cases where a member is explicitly
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

module attributes {omp.is_target_device = true} {
module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true} {
llvm.func @_QQmain() attributes {fir.bindc_name = "main"} {
%0 = llvm.mlir.addressof @_QFEi : !llvm.ptr
%1 = llvm.mlir.addressof @_QFEsp : !llvm.ptr
Expand All @@ -23,7 +23,7 @@ module attributes {omp.is_target_device = true} {
}
}

// CHECK: define {{.*}} void @__omp_offloading_{{.*}}_{{.*}}__QQmain_l{{.*}}(ptr %[[DYN_PTR:.*]], ptr %[[ARG_BYREF:.*]], ptr %[[ARG_BYCOPY:.*]]) {
// CHECK: define {{.*}} void @__omp_offloading_{{.*}}_{{.*}}__QQmain_l{{.*}}(ptr %[[DYN_PTR:.*]], ptr %[[ARG_BYREF:.*]], ptr %[[ARG_BYCOPY:.*]]) #{{[0-9]+}} {

// CHECK: entry:
// CHECK: %[[ALLOCA_BYREF:.*]] = alloca ptr, align 8
Expand Down
Loading

0 comments on commit 58bd5ff

Please sign in to comment.