Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Proof-of-Concept][flang][OpenMP] Implicitely map allocatable record fields #117867

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ergawy
Copy link
Member

@ergawy ergawy commented Nov 27, 2024

This is a proof-of-concept to implicitely map allocatable record fields. This PR contains the following changes:

  1. Re-purposes some of the utils used in Lower/OpenMP.cpp so that these utils work on the mlir::Value level rather than the semantics::Symbol level. This takes one step towards to enabling MLIR passes to more easily do some lowering themselves (e.g. creating omp.map.bounds ops for implicitely caputured data like this PR does).
  2. Adds support for implicitely capturing and mapping allocatable fields in record types.

If the taken approach by this PR is acceptable, I will convert it to a proper PR and add more extensive testing.

…fields

This is a proof-of-concept to implicitely map allocatable record fields.
This PR contains the following changes:
1. Re-purposes some of the utils used in `Lower/OpenMP.cpp` so that
   these utils work on the `mlir::Value` level rather than the
   `semantics::Symbol` level. This takes one step towards to enabling
   MLIR passes to more easily do some lowering themselves (e.g. creating
   `omp.map.bounds` ops for implicitely caputured data like this PR
   does).
2. Adds support for implicitely capturing and mapping allocatable fields
   in record types.

If the taken approach by this PR is acceptable, I will convert it to a
proper PR and add more extensive testing.
@ergawy ergawy changed the title [Proof-of-Concept][flang][OpenMP] Implicitely map allocatable record … [Proof-of-Concept][flang][OpenMP] Implicitely map allocatable record fields Nov 27, 2024
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir offload labels Nov 27, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

This is a proof-of-concept to implicitely map allocatable record fields. This PR contains the following changes:

  1. Re-purposes some of the utils used in Lower/OpenMP.cpp so that these utils work on the mlir::Value level rather than the semantics::Symbol level. This takes one step towards to enabling MLIR passes to more easily do some lowering themselves (e.g. creating omp.map.bounds ops for implicitely caputured data like this PR does).
  2. Adds support for implicitely capturing and mapping allocatable fields in record types.

If the taken approach by this PR is acceptable, I will convert it to a proper PR and add more extensive testing.


Full diff: https://github.com/llvm/llvm-project/pull/117867.diff

5 Files Affected:

  • (modified) flang/lib/Lower/DirectivesCommon.h (+33-17)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+7-14)
  • (modified) flang/lib/Optimizer/OpenMP/CMakeLists.txt (+2)
  • (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+154)
  • (added) offload/test/offloading/fortran/implicit-record-field-mapping.f90 (+63)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
   }
 }
 
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
-                       fir::FirOpBuilder &builder,
-                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
-  mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+                                                mlir::Value symAddr,
+                                                bool isOptional,
+                                                mlir::Location loc) {
   mlir::Value rawInput = symAddr;
   if (auto declareOp =
           mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     rawInput = declareOp.getResults()[1];
   }
 
-  // TODO: Might need revisiting to handle for non-shared clauses
-  if (!symAddr) {
-    if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
-      symAddr = converter.getSymbolAddress(details->symbol());
-      rawInput = symAddr;
-    }
-  }
-
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
   mlir::Value isPresent;
-  if (Fortran::semantics::IsOptional(sym))
+  if (isOptional)
     isPresent =
         builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
 
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // all address/dimension retrievals. For Fortran optional though, leave
     // the load generation for later so it can be done in the appropriate
     // if branches.
-    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
-        !Fortran::semantics::IsOptional(sym)) {
+    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
       return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
     }
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
   return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+                       fir::FirOpBuilder &builder,
+                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
+  return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+                                Fortran::semantics::IsOptional(sym), loc);
+}
+
 template <typename BoundsOp, typename BoundsType>
 llvm::SmallVector<mlir::Value>
 gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
 
   return info;
 }
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+                     fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+                     mlir::Location loc) {
+  llvm::SmallVector<mlir::Value> bounds;
+
+  mlir::Value baseOp = info.rawInput;
+  if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+    bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                              dataExv, info);
+  if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+    bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+        builder, loc, dataExv, dataExvIsAssumedSize);
+  }
+
+  return bounds;
+}
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ea2d1eb66bfea5..ddf15ce7f3c3dd 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1817,32 +1817,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       if (const auto *details =
               sym.template detailsIf<semantics::HostAssocDetails>())
         converter.copySymbolBinding(details->symbol(), sym);
-      llvm::SmallVector<mlir::Value> bounds;
       std::stringstream name;
       fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
       name << sym.name().ToString();
 
       lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
           converter, firOpBuilder, sym, converter.getCurrentLocation());
-      mlir::Value baseOp = info.rawInput;
-      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
-        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                            mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
-        bool dataExvIsAssumedSize =
-            semantics::IsAssumedSizeArray(sym.GetUltimate());
-        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                         mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv,
-            dataExvIsAssumedSize);
-      }
+      llvm::SmallVector<mlir::Value> bounds =
+          lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                      mlir::omp::MapBoundsType>(
+              firOpBuilder, info, dataExv,
+              semantics::IsAssumedSizeArray(sym.GetUltimate()),
+              converter.getCurrentLocation());
 
       llvm::omp::OpenMPOffloadMappingFlags mapFlag =
           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
       mlir::omp::VariableCaptureKind captureKind =
           mlir::omp::VariableCaptureKind::ByRef;
 
+      mlir::Value baseOp = info.rawInput;
       mlir::Type eleType = baseOp.getType();
       if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
         eleType = refType.getElementType();
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index b1e0dbf6e707e5..0b78773a252cd2 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -11,6 +11,7 @@ add_flang_library(FlangOpenMPTransforms
   FIRDialect
   HLFIROpsIncGen
   FlangOpenMPPassesIncGen
+  ${dialect_libs}
 
   LINK_LIBS
   FIRAnalysis
@@ -25,4 +26,5 @@ add_flang_library(FlangOpenMPTransforms
   HLFIRDialect
   MLIRIR
   MLIRPass
+  ${dialect_libs}
 )
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..717d0dcd3fdfba 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -25,9 +25,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -43,6 +46,10 @@
 #include <iterator>
 #include <numeric>
 
+// Move to `flang/include/flang/Common/OpenMP-utils.h` once
+// https://github.com/llvm/llvm-project/pull/115443 is merge.
+#include "../../Lower/DirectivesCommon.h"
+
 namespace flangomp {
 #define GEN_PASS_DEF_MAPINFOFINALIZATIONPASS
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
@@ -485,6 +492,153 @@ class MapInfoFinalizationPass
       // clear all local allocations we made for any boxes in any prior
       // iterations from previous function scopes.
       localBoxAllocas.clear();
+      func->walk([&](mlir::omp::MapInfoOp op) {
+        mlir::Type underlyingType =
+            fir::unwrapRefType(op.getVarPtr().getType());
+
+        if (!fir::isRecordWithAllocatableMember(underlyingType))
+          return mlir::WalkResult::advance();
+
+        mlir::omp::TargetOp target =
+            mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+                getFirstTargetUser(op));
+
+        if (!target)
+          return mlir::WalkResult::advance();
+
+        auto mapClauseOwner =
+            llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+        // TODO Add as a method to MapClauseOwningOpInterface.
+        unsigned mapVarIdx = 0;
+        for (auto [idx, mapOp] : llvm::enumerate(mapClauseOwner.getMapVars())) {
+          if (mapOp == op) {
+            mapVarIdx = idx;
+            break;
+          }
+        }
+
+        auto argIface =
+            llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+        mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+        llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+        mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+        mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+          // TODO Support coordinate_of ops.
+          //
+          // TODO Support call ops by recursively examining the forward slice of
+          // the corresponding paramemter to the field.
+          return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+        });
+
+        auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+        llvm::SmallVector<mlir::Value> newMapOpsForFields;
+        llvm::SmallVector<int64_t> fieldIdices;
+
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          bool shouldMapField =
+              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+                if (!fir::isAllocatableType(memTy))
+                  return false;
+
+                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+                if (!designateOp)
+                  return false;
+
+                return designateOp.getComponent() &&
+                       designateOp.getComponent()->strref() == field;
+              }) != mapVarForwardSlice.end();
+
+          // TODO Handle recursive record types.
+
+          if (!shouldMapField)
+            continue;
+
+          int64_t fieldIdx = recordType.getFieldIndex(field);
+          bool alreadyMapped = false;
+
+          if (op.getMembersIndexAttr())
+            for (auto indexList : op.getMembersIndexAttr()) {
+              auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+              if (indexListAttr.size() == 1 &&
+                  mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+                      fieldIdx)
+                alreadyMapped = true;
+            }
+
+          if (alreadyMapped)
+            continue;
+
+          builder.setInsertionPoint(op);
+          mlir::Value fieldIdxVal = builder.createIntegerConstant(
+              op.getLoc(), mlir::IndexType::get(builder.getContext()),
+              fieldIdx);
+          auto fieldCoord = builder.create<fir::CoordinateOp>(
+              op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              fieldIdxVal);
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::getDataOperandBaseAddr(
+                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+          llvm::SmallVector<mlir::Value> bounds =
+              Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                   mlir::omp::MapBoundsType>(
+                  builder, info,
+                  hlfir::translateToExtendedValue(op.getLoc(), builder,
+                                                  hlfir::Entity{fieldCoord})
+                      .first,
+                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+          mlir::omp::MapInfoOp fieldMapOp =
+              builder.create<mlir::omp::MapInfoOp>(
+                  op.getLoc(), fieldCoord.getResult().getType(),
+                  fieldCoord.getResult(),
+                  mlir::TypeAttr::get(
+                      fir::unwrapRefType(fieldCoord.getResult().getType())),
+                  /*varPtrPtr=*/mlir::Value{},
+                  /*members=*/mlir::ValueRange{},
+                  /*members_index=*/mlir::ArrayAttr{},
+                  /*bounds=*/bounds, op.getMapTypeAttr(),
+                  builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+                      mlir::omp::VariableCaptureKind::ByRef),
+                  builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                        field + ".implicit_map"),
+                  /*partial_map=*/builder.getBoolAttr(false));
+          newMapOpsForFields.emplace_back(fieldMapOp);
+          fieldIdices.emplace_back(fieldIdx);
+        }
+
+        if (!newMapOpsForFields.empty()) {
+          op.getMembersMutable().append(newMapOpsForFields);
+          llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+          mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+          if (oldMembersIdxAttr)
+            for (mlir::Attribute indexList : oldMembersIdxAttr) {
+              llvm::SmallVector<int64_t> listVec;
+
+              for (mlir::Attribute index :
+                   mlir::cast<mlir::ArrayAttr>(indexList))
+                listVec.push_back(
+                    mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+              newMemberIndices.emplace_back(std::move(listVec));
+            }
+
+          for (int64_t newFieldIdx : fieldIdices) {
+            newMemberIndices.emplace_back(
+                llvm::SmallVector<int64_t>(1, newFieldIdx));
+          }
+
+          op.setMembersIndexAttr(
+              builder.create2DI64ArrayAttr(newMemberIndices));
+        }
+
+        return mlir::WalkResult::advance();
+      });
 
       func->walk([&](mlir::omp::MapInfoOp op) {
         // TODO: Currently only supports a single user for the MapInfoOp. This
diff --git a/offload/test/offloading/fortran/implicit-record-field-mapping.f90 b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
new file mode 100644
index 00000000000000..8bde033d4ef4c3
--- /dev/null
+++ b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
@@ -0,0 +1,63 @@
+! Test implicit mapping of alloctable record fields.
+
+! REQUIRES: flang, amdgpu
+
+! This fails only because it needs the Fortran runtime built for device. If this
+! is avaialbe, this test succeedds when run.
+! XFAIL: *
+
+! RUN: %libomptarget-compile-fortran-generic
+! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
+module record_m
+  implicit none
+
+  private
+  public :: recored_t
+
+  type recored_t
+    real, allocatable :: not_to_implicitly_map(:)
+    real, allocatable :: to_implicitly_map(:)
+  end type
+
+  interface recored_t
+  end interface
+end module record_m
+
+program concurrent_inferences
+  use record_m, only : recored_t
+  implicit none
+
+  type(recored_t) :: dst_record
+  real :: src_array(10)
+  real :: dst_sum, src_sum
+  integer :: i
+
+  call random_number(src_array)
+  dst_sum = 0
+  src_sum = 0
+
+  do i=1,10
+    src_sum = src_sum + src_array(i)
+  end do
+  print *, "src_sum=", src_sum
+
+  !$omp target map(from: dst_sum)
+    dst_record%to_implicitly_map = src_array
+    dst_sum = 0
+
+    do i=1,10
+      dst_sum = dst_sum + dst_record%to_implicitly_map(i)
+    end do
+  !$omp end target
+
+  print *, "dst_sum=", dst_sum
+
+  if (src_sum == dst_sum) then
+    print *, "Test succeeded!"
+  else
+    print *, "Test failed!", " dst_sum=", dst_sum, "vs. src_sum=", src_sum
+  endif
+end program
+
+! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
+! CHECK: Test succeeded!

@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2024

@llvm/pr-subscribers-offload

Author: Kareem Ergawy (ergawy)

Changes

This is a proof-of-concept to implicitely map allocatable record fields. This PR contains the following changes:

  1. Re-purposes some of the utils used in Lower/OpenMP.cpp so that these utils work on the mlir::Value level rather than the semantics::Symbol level. This takes one step towards to enabling MLIR passes to more easily do some lowering themselves (e.g. creating omp.map.bounds ops for implicitely caputured data like this PR does).
  2. Adds support for implicitely capturing and mapping allocatable fields in record types.

If the taken approach by this PR is acceptable, I will convert it to a proper PR and add more extensive testing.


Full diff: https://github.com/llvm/llvm-project/pull/117867.diff

5 Files Affected:

  • (modified) flang/lib/Lower/DirectivesCommon.h (+33-17)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+7-14)
  • (modified) flang/lib/Optimizer/OpenMP/CMakeLists.txt (+2)
  • (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+154)
  • (added) offload/test/offloading/fortran/implicit-record-field-mapping.f90 (+63)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
   }
 }
 
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
-                       fir::FirOpBuilder &builder,
-                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
-  mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+                                                mlir::Value symAddr,
+                                                bool isOptional,
+                                                mlir::Location loc) {
   mlir::Value rawInput = symAddr;
   if (auto declareOp =
           mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     rawInput = declareOp.getResults()[1];
   }
 
-  // TODO: Might need revisiting to handle for non-shared clauses
-  if (!symAddr) {
-    if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
-      symAddr = converter.getSymbolAddress(details->symbol());
-      rawInput = symAddr;
-    }
-  }
-
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
   mlir::Value isPresent;
-  if (Fortran::semantics::IsOptional(sym))
+  if (isOptional)
     isPresent =
         builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
 
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // all address/dimension retrievals. For Fortran optional though, leave
     // the load generation for later so it can be done in the appropriate
     // if branches.
-    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
-        !Fortran::semantics::IsOptional(sym)) {
+    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
       return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
     }
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
   return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+                       fir::FirOpBuilder &builder,
+                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
+  return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+                                Fortran::semantics::IsOptional(sym), loc);
+}
+
 template <typename BoundsOp, typename BoundsType>
 llvm::SmallVector<mlir::Value>
 gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
 
   return info;
 }
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+                     fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+                     mlir::Location loc) {
+  llvm::SmallVector<mlir::Value> bounds;
+
+  mlir::Value baseOp = info.rawInput;
+  if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+    bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                              dataExv, info);
+  if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+    bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+        builder, loc, dataExv, dataExvIsAssumedSize);
+  }
+
+  return bounds;
+}
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ea2d1eb66bfea5..ddf15ce7f3c3dd 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1817,32 +1817,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       if (const auto *details =
               sym.template detailsIf<semantics::HostAssocDetails>())
         converter.copySymbolBinding(details->symbol(), sym);
-      llvm::SmallVector<mlir::Value> bounds;
       std::stringstream name;
       fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
       name << sym.name().ToString();
 
       lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
           converter, firOpBuilder, sym, converter.getCurrentLocation());
-      mlir::Value baseOp = info.rawInput;
-      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
-        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                            mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
-        bool dataExvIsAssumedSize =
-            semantics::IsAssumedSizeArray(sym.GetUltimate());
-        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                         mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv,
-            dataExvIsAssumedSize);
-      }
+      llvm::SmallVector<mlir::Value> bounds =
+          lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                      mlir::omp::MapBoundsType>(
+              firOpBuilder, info, dataExv,
+              semantics::IsAssumedSizeArray(sym.GetUltimate()),
+              converter.getCurrentLocation());
 
       llvm::omp::OpenMPOffloadMappingFlags mapFlag =
           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
       mlir::omp::VariableCaptureKind captureKind =
           mlir::omp::VariableCaptureKind::ByRef;
 
+      mlir::Value baseOp = info.rawInput;
       mlir::Type eleType = baseOp.getType();
       if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
         eleType = refType.getElementType();
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index b1e0dbf6e707e5..0b78773a252cd2 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -11,6 +11,7 @@ add_flang_library(FlangOpenMPTransforms
   FIRDialect
   HLFIROpsIncGen
   FlangOpenMPPassesIncGen
+  ${dialect_libs}
 
   LINK_LIBS
   FIRAnalysis
@@ -25,4 +26,5 @@ add_flang_library(FlangOpenMPTransforms
   HLFIRDialect
   MLIRIR
   MLIRPass
+  ${dialect_libs}
 )
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..717d0dcd3fdfba 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -25,9 +25,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -43,6 +46,10 @@
 #include <iterator>
 #include <numeric>
 
+// Move to `flang/include/flang/Common/OpenMP-utils.h` once
+// https://github.com/llvm/llvm-project/pull/115443 is merge.
+#include "../../Lower/DirectivesCommon.h"
+
 namespace flangomp {
 #define GEN_PASS_DEF_MAPINFOFINALIZATIONPASS
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
@@ -485,6 +492,153 @@ class MapInfoFinalizationPass
       // clear all local allocations we made for any boxes in any prior
       // iterations from previous function scopes.
       localBoxAllocas.clear();
+      func->walk([&](mlir::omp::MapInfoOp op) {
+        mlir::Type underlyingType =
+            fir::unwrapRefType(op.getVarPtr().getType());
+
+        if (!fir::isRecordWithAllocatableMember(underlyingType))
+          return mlir::WalkResult::advance();
+
+        mlir::omp::TargetOp target =
+            mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+                getFirstTargetUser(op));
+
+        if (!target)
+          return mlir::WalkResult::advance();
+
+        auto mapClauseOwner =
+            llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+        // TODO Add as a method to MapClauseOwningOpInterface.
+        unsigned mapVarIdx = 0;
+        for (auto [idx, mapOp] : llvm::enumerate(mapClauseOwner.getMapVars())) {
+          if (mapOp == op) {
+            mapVarIdx = idx;
+            break;
+          }
+        }
+
+        auto argIface =
+            llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+        mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+        llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+        mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+        mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+          // TODO Support coordinate_of ops.
+          //
+          // TODO Support call ops by recursively examining the forward slice of
+          // the corresponding paramemter to the field.
+          return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+        });
+
+        auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+        llvm::SmallVector<mlir::Value> newMapOpsForFields;
+        llvm::SmallVector<int64_t> fieldIdices;
+
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          bool shouldMapField =
+              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+                if (!fir::isAllocatableType(memTy))
+                  return false;
+
+                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+                if (!designateOp)
+                  return false;
+
+                return designateOp.getComponent() &&
+                       designateOp.getComponent()->strref() == field;
+              }) != mapVarForwardSlice.end();
+
+          // TODO Handle recursive record types.
+
+          if (!shouldMapField)
+            continue;
+
+          int64_t fieldIdx = recordType.getFieldIndex(field);
+          bool alreadyMapped = false;
+
+          if (op.getMembersIndexAttr())
+            for (auto indexList : op.getMembersIndexAttr()) {
+              auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+              if (indexListAttr.size() == 1 &&
+                  mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+                      fieldIdx)
+                alreadyMapped = true;
+            }
+
+          if (alreadyMapped)
+            continue;
+
+          builder.setInsertionPoint(op);
+          mlir::Value fieldIdxVal = builder.createIntegerConstant(
+              op.getLoc(), mlir::IndexType::get(builder.getContext()),
+              fieldIdx);
+          auto fieldCoord = builder.create<fir::CoordinateOp>(
+              op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              fieldIdxVal);
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::getDataOperandBaseAddr(
+                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+          llvm::SmallVector<mlir::Value> bounds =
+              Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                   mlir::omp::MapBoundsType>(
+                  builder, info,
+                  hlfir::translateToExtendedValue(op.getLoc(), builder,
+                                                  hlfir::Entity{fieldCoord})
+                      .first,
+                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+          mlir::omp::MapInfoOp fieldMapOp =
+              builder.create<mlir::omp::MapInfoOp>(
+                  op.getLoc(), fieldCoord.getResult().getType(),
+                  fieldCoord.getResult(),
+                  mlir::TypeAttr::get(
+                      fir::unwrapRefType(fieldCoord.getResult().getType())),
+                  /*varPtrPtr=*/mlir::Value{},
+                  /*members=*/mlir::ValueRange{},
+                  /*members_index=*/mlir::ArrayAttr{},
+                  /*bounds=*/bounds, op.getMapTypeAttr(),
+                  builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+                      mlir::omp::VariableCaptureKind::ByRef),
+                  builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                        field + ".implicit_map"),
+                  /*partial_map=*/builder.getBoolAttr(false));
+          newMapOpsForFields.emplace_back(fieldMapOp);
+          fieldIdices.emplace_back(fieldIdx);
+        }
+
+        if (!newMapOpsForFields.empty()) {
+          op.getMembersMutable().append(newMapOpsForFields);
+          llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+          mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+          if (oldMembersIdxAttr)
+            for (mlir::Attribute indexList : oldMembersIdxAttr) {
+              llvm::SmallVector<int64_t> listVec;
+
+              for (mlir::Attribute index :
+                   mlir::cast<mlir::ArrayAttr>(indexList))
+                listVec.push_back(
+                    mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+              newMemberIndices.emplace_back(std::move(listVec));
+            }
+
+          for (int64_t newFieldIdx : fieldIdices) {
+            newMemberIndices.emplace_back(
+                llvm::SmallVector<int64_t>(1, newFieldIdx));
+          }
+
+          op.setMembersIndexAttr(
+              builder.create2DI64ArrayAttr(newMemberIndices));
+        }
+
+        return mlir::WalkResult::advance();
+      });
 
       func->walk([&](mlir::omp::MapInfoOp op) {
         // TODO: Currently only supports a single user for the MapInfoOp. This
diff --git a/offload/test/offloading/fortran/implicit-record-field-mapping.f90 b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
new file mode 100644
index 00000000000000..8bde033d4ef4c3
--- /dev/null
+++ b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
@@ -0,0 +1,63 @@
+! Test implicit mapping of alloctable record fields.
+
+! REQUIRES: flang, amdgpu
+
+! This fails only because it needs the Fortran runtime built for device. If this
+! is avaialbe, this test succeedds when run.
+! XFAIL: *
+
+! RUN: %libomptarget-compile-fortran-generic
+! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
+module record_m
+  implicit none
+
+  private
+  public :: recored_t
+
+  type recored_t
+    real, allocatable :: not_to_implicitly_map(:)
+    real, allocatable :: to_implicitly_map(:)
+  end type
+
+  interface recored_t
+  end interface
+end module record_m
+
+program concurrent_inferences
+  use record_m, only : recored_t
+  implicit none
+
+  type(recored_t) :: dst_record
+  real :: src_array(10)
+  real :: dst_sum, src_sum
+  integer :: i
+
+  call random_number(src_array)
+  dst_sum = 0
+  src_sum = 0
+
+  do i=1,10
+    src_sum = src_sum + src_array(i)
+  end do
+  print *, "src_sum=", src_sum
+
+  !$omp target map(from: dst_sum)
+    dst_record%to_implicitly_map = src_array
+    dst_sum = 0
+
+    do i=1,10
+      dst_sum = dst_sum + dst_record%to_implicitly_map(i)
+    end do
+  !$omp end target
+
+  print *, "dst_sum=", dst_sum
+
+  if (src_sum == dst_sum) then
+    print *, "Test succeeded!"
+  else
+    print *, "Test failed!", " dst_sum=", dst_sum, "vs. src_sum=", src_sum
+  endif
+end program
+
+! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
+! CHECK: Test succeeded!

Comment on lines -624 to -632
// TODO: Might need revisiting to handle for non-shared clauses
if (!symAddr) {
if (const auto *details =
sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
symAddr = converter.getSymbolAddress(details->symbol());
rawInput = symAddr;
}
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a dead condition since we call symAddr.getDefiningOp() above. I think we should assert instread.

@ergawy ergawy marked this pull request as draft November 27, 2024 11:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category offload
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants