diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 0a274373ce8a..19a7b3a08402 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -12,7 +12,6 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h" -#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" #include "iree/compiler/Codegen/Utils/VectorOpUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLForwardCompat.h" @@ -45,169 +44,34 @@ #define GET_ATTRDEF_CLASSES #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp.inc" -using LayoutDimension = mlir::iree_compiler::IREE::VectorExt::LayoutDimension; -using LayoutDimensionAttr = - mlir::iree_compiler::IREE::VectorExt::LayoutDimensionAttr; -using VectorLayoutInterface = - mlir::iree_compiler::IREE::VectorExt::VectorLayoutInterface; -using PerDimLayoutAttr = mlir::iree_compiler::IREE::VectorExt::PerDimLayoutAttr; -using LayoutAttr = mlir::iree_compiler::IREE::VectorExt::LayoutAttr; -using NestedLayoutAttr = mlir::iree_compiler::IREE::VectorExt::NestedLayoutAttr; - namespace mlir::iree_compiler::IREE::GPU { -namespace { -// Struct containing abstract MMA shape and type information. -struct OpaqueMmaLayout { - int64_t mSize = 0; - int64_t nSize = 0; - int64_t kSize = 0; - Type aType; - Type bType; - Type cType; -}; - -// Struct containing concrete MMA shape, type, and layout information. -struct ConcreteMmaLayout { - OpaqueMmaLayout base; - PerDimLayoutAttr aMLayout; - PerDimLayoutAttr aKLayout; - PerDimLayoutAttr bKLayout; - PerDimLayoutAttr bNLayout; - PerDimLayoutAttr cMLayout; - PerDimLayoutAttr cNLayout; -}; -} // namespace - //===----------------------------------------------------------------------===// -// #iree_gpu.mma_vector_layout +// MMA intrinsics semantics: shapes, layouts, operand element types. //===----------------------------------------------------------------------===// -static PerDimLayoutAttr getBatchedPerDimLayoutAttr(LayoutDimensionAttr batchDim, - PerDimLayoutAttr baseLayout, - int64_t problemSize, - int64_t fragmentDimSize) { - assert(problemSize % fragmentDimSize == 0 && - "invalid layout fragment for problem size"); - - SmallVector dimAttrs(baseLayout.getLabels()); - dimAttrs.insert(dimAttrs.begin(), batchDim); - - SmallVector shapes(baseLayout.getShapes()); - shapes.insert(shapes.begin(), problemSize / fragmentDimSize); - auto layout = - PerDimLayoutAttr::get(baseLayout.getContext(), dimAttrs, shapes); - return layout; +static int getBlockSize(MMAIntrinsic /*intrinsic*/) { + // Not supporting any block size other than 1 at the moment. + return 1; } -// Get the batched layout attributes for the given fragment layouts, indexing -// map, and problem shape. The canonical fragment map is used to compare against -// the problem map |indexingMap|. 
For example, for mma fragment B (RHS): -// -// indexingMap = affine_map<(d0, d1, d2) -> (d1, d2) # Transposed B -// fragmentMap = affine_map<(d0, d1, d2) -> (d2, d1) -// problemShape = [32, 64] -// fragmentSize = [16, 8] -// fragmentLayouts = [kLayout, nLayout] -// -// Gives batched layout -// -// Dim0 Layout = [BATCHX, nLayoutLabels], [8, nLayoutShape] -// Dim1 Layout = [BATCHY, kLayoutLabels], [2, kLayoutShape] -static LayoutAttr -getBatchedLayoutAttr(AffineMap indexingMap, AffineMap fragmentMap, - ArrayRef problemShape, - ArrayRef fragmentSize, - ArrayRef fragmentLayouts) { - // Current distribution to MFMA operations does not support batched - // contractions so that is reflected here. - assert(indexingMap.getNumResults() == 2 && - "invalid indexing map to non-batched simple contraction"); - - LayoutDimensionAttr batchX = LayoutDimensionAttr::get( - indexingMap.getContext(), LayoutDimension::BATCHX); - LayoutDimensionAttr batchY = LayoutDimensionAttr::get( - indexingMap.getContext(), LayoutDimension::BATCHY); - - SmallVector perDimAttrs; - for (auto [expr, batchType] : - llvm::zip_equal(indexingMap.getResults(), - SmallVector{batchX, batchY})) { - auto maybeResultPosition = fragmentMap.getResultPosition(expr); - assert(maybeResultPosition && "fragment map and problem map mismatch"); - int64_t idx = *maybeResultPosition; - perDimAttrs.push_back(getBatchedPerDimLayoutAttr( - batchType, fragmentLayouts[idx], problemShape[idx], fragmentSize[idx])); - } - - return LayoutAttr::get(indexingMap.getContext(), perDimAttrs); +static uint32_t getArchID(MMAIntrinsic intrinsic) { + return static_cast(intrinsic) & 0xFF00; } -static FailureOr> -getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) { - MLIRContext *context = contract.getContext(); - FailureOr maybeContractionDims = - linalg::inferContractionDims(contract.getIndexingMapsArray()); - if (failed(maybeContractionDims)) { - return failure(); - } - auto contractionDims = *maybeContractionDims; - // TODO: Relax this condition to strictly alignment requirements. - if (contractionDims.k.size() != 1 || contractionDims.m.size() != 1 || - contractionDims.n.size() != 1) { - return failure(); - } - // TODO: Support batched contractions. - if (contractionDims.batch.size() > 0) { - return failure(); - } - unsigned mDim = contractionDims.m[0]; - unsigned nDim = contractionDims.n[0]; - unsigned kDim = contractionDims.k[0]; - - SmallVector iterationBounds; - contract.getIterationBounds(iterationBounds); - - int64_t problemMSize = iterationBounds[mDim]; - int64_t problemNSize = iterationBounds[nDim]; - int64_t problemKSize = iterationBounds[kDim]; - - int64_t mSize = layout.base.mSize; - int64_t nSize = layout.base.nSize; - int64_t kSize = layout.base.kSize; - - // The problem size currently must be strictly aligned to the size of the mma. - // This is expected to succeed assuming the correct [masked] vector size was - // set at strategy configuration time (for this mma). 
- if (problemMSize % mSize != 0 || problemNSize % nSize || - problemKSize % kSize) { - return failure(); - } +static bool is_AMD_MFMA(MMAIntrinsic intrinsic) { + return getArchID(intrinsic) >= 0x1000 && getArchID(intrinsic) <= 0x17FF; +} - LayoutAttr aLayout = getBatchedLayoutAttr( - contract.getIndexingMapsArray()[0], - AffineMap::getMultiDimMapWithTargets(3, {mDim, kDim}, context), - {problemMSize, problemKSize}, {mSize, kSize}, - {layout.aMLayout, layout.aKLayout}); - LayoutAttr bLayout = getBatchedLayoutAttr( - contract.getIndexingMapsArray()[1], - AffineMap::getMultiDimMapWithTargets(3, {kDim, nDim}, context), - {problemKSize, problemNSize}, {kSize, nSize}, - {layout.bKLayout, layout.bNLayout}); - LayoutAttr cLayout = getBatchedLayoutAttr( - contract.getIndexingMapsArray()[2], - AffineMap::getMultiDimMapWithTargets(3, {mDim, nDim}, context), - {problemMSize, problemNSize}, {mSize, nSize}, - {layout.cMLayout, layout.cNLayout}); - - return std::make_tuple(aLayout, bLayout, cLayout); +static bool is_AMD_WMMA(MMAIntrinsic intrinsic) { + return getArchID(intrinsic) >= 0x1800 && getArchID(intrinsic) <= 0x1FFF; } -//===----------------------------------------------------------------------===// -// Layout Attribute Building Helpers -//===----------------------------------------------------------------------===// +static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) { + // Not using Wave64 at all at the moment, so the only place where the + // subgroup size is CDNA* architectures. + return is_AMD_MFMA(intrinsic) ? 64 : 32; +} static std::tuple getABCElementTypes(MLIRContext *context, MMAIntrinsic intrinsic) { @@ -263,233 +127,6 @@ static std::tuple getABCElementTypes(MLIRContext *context, return {}; } -template -static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context, - MMAIntrinsicType intrinsic) { - OpaqueMmaLayout o; - std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic); - auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs); - auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs); - o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0]; - o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1]; - o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1]; - return o; -} - -static std::tuple -getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) { - // Step 1: obtain the swizzled tile shape, but keeping track of the source - // dimension indices. - struct SrcIndexAndSwizzleDim { - size_t srcIndex; - TileSwizzle::Dim dim; - }; - SmallVector swizzledShape; - for (auto [i, e] : llvm::enumerate(swizzle.expandShape)) { - for (TileSwizzle::Dim d : e) { - swizzledShape.push_back(SrcIndexAndSwizzleDim{i, d}); - } - } - applyPermutationToVector(swizzledShape, swizzle.permutation); - - // Step 2: collect the appropriate labels to use for the swizzled dims. 
- LayoutDimension internalLabels[] = {LayoutDimension::VECTORZ, - LayoutDimension::VECTORY, - LayoutDimension::VECTORX}; - LayoutDimension crossThreadLabels[] = { - LayoutDimension::LANEZ, LayoutDimension::LANEY, LayoutDimension::LANEX}; - auto internalLabelIter = std::end(internalLabels); - auto crossThreadLabelIter = std::end(crossThreadLabels); - for (SrcIndexAndSwizzleDim d : swizzledShape) { - if (d.dim.kind == TileSwizzle::Dim::Kind::Internal) { - assert(internalLabelIter != std::begin(internalLabels)); - --internalLabelIter; - } else if (d.dim.kind == TileSwizzle::Dim::Kind::CrossThread) { - assert(crossThreadLabelIter != std::begin(crossThreadLabels)); - --crossThreadLabelIter; - } else { - assert(false && "unexpected dimension kind in intrinsic swizzle"); - } - } - - // Step 3: put together the result PerDimLayoutAttr'd for the two source dims. - SmallVector labels[2]; - SmallVector shape[2]; - for (SrcIndexAndSwizzleDim d : swizzledShape) { - shape[d.srcIndex].push_back(d.dim.size); - auto &labelIterRef = (d.dim.kind == TileSwizzle::Dim::Kind::Internal) - ? internalLabelIter - : crossThreadLabelIter; - labels[d.srcIndex].push_back(LayoutDimensionAttr::get( - context, static_cast(*labelIterRef++))); - } - return {PerDimLayoutAttr::get(context, labels[0], shape[0]), - PerDimLayoutAttr::get(context, labels[1], shape[1])}; -}; - -static ConcreteMmaLayout getConcreteMMALayout(MLIRContext *context, - MMAIntrinsic intrinsic) { - auto opaque = getOpaqueMMALayout(context, intrinsic); - ConcreteMmaLayout concreteLayout; - concreteLayout.base = opaque; - auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Lhs); - auto rhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Rhs); - auto accSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Acc); - std::tie(concreteLayout.aMLayout, concreteLayout.aKLayout) = - getPerDimLayoutAttrs(context, lhsSwizzle); - std::tie(concreteLayout.bNLayout, concreteLayout.bKLayout) = - getPerDimLayoutAttrs(context, rhsSwizzle); - std::tie(concreteLayout.cMLayout, concreteLayout.cNLayout) = - getPerDimLayoutAttrs(context, accSwizzle); - return concreteLayout; -} - -//===----------------------------------------------------------------------===// -// MmaInterface Attribute Helper Functions -//===----------------------------------------------------------------------===// - -MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) { - if (auto mmaAttr = dyn_cast(mmaKind)) { - return mmaAttr.getASingleSubgroupLayout(); - } else if (auto vmmaAttr = dyn_cast(mmaKind)) { - return vmmaAttr.getASingleSubgroupLayout(); - } else { - assert(false && "unhandled MMA Interface type."); - return {}; - } -} - -MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) { - if (auto mmaAttr = dyn_cast(mmaKind)) { - return mmaAttr.getBSingleSubgroupLayout(); - } else if (auto vmmaAttr = dyn_cast(mmaKind)) { - return vmmaAttr.getBSingleSubgroupLayout(); - } else { - assert(false && "unhandled MMA Interface type."); - return {}; - } -} - -MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) { - if (auto mmaAttr = dyn_cast(mmaKind)) { - return mmaAttr.getCSingleSubgroupLayout(); - } else if (auto vmmaAttr = dyn_cast(mmaKind)) { - return vmmaAttr.getCSingleSubgroupLayout(); - } else { - assert(false && "unhandled MMA Interface type."); - return {}; - } -} - -//===----------------------------------------------------------------------===// -// MFMA Attributes 
-//===----------------------------------------------------------------------===// - -Attribute MMAAttr::parse(AsmParser &p, Type type) { - if (failed(p.parseLess())) - return {}; - - FailureOr mmaIntrinsic = - FieldParser::parse(p); - if (failed(mmaIntrinsic)) { - p.emitError(p.getCurrentLocation(), "failed to parse mfma type identifier"); - return {}; - } - - if (failed(p.parseGreater())) - return {}; - - return get(p.getContext(), mmaIntrinsic->getValue()); -} - -void MMAAttr::print(AsmPrinter &p) const { - auto &os = p.getStream(); - os << "<"; - os << stringifyMMAIntrinsic(getIntrinsic().getValue()); - os << ">"; -} - -MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) { - auto layout = getOpaqueMMALayout(context, type); - return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize, - layout.nSize, layout.kSize, layout.aType, layout.bType, - layout.cType); -} - -std::tuple MMAAttr::getABCElementTypes() const { - return {getAType(), getBType(), getCType()}; -} - -std::tuple MMAAttr::getMNKShape() const { - return {getMSize(), getNSize(), getKSize()}; -} - -template -static VectorType getVectorType(MLIRContext *context, - MMAIntrinsicType intrinsic, - MMAFragment fragment) { - auto o = getOpaqueMMALayout(context, intrinsic); - auto s = getSingleSubgroupLayout(intrinsic, fragment); - Type elemType = (fragment == MMAFragment::Lhs) ? o.aType - : (fragment == MMAFragment::Rhs) ? o.bType - : o.cType; - return VectorType::get( - {s.element[0] * s.element[1] * s.outer[0] * s.outer[1]}, elemType); -} - -std::tuple -MMAAttr::getABCVectorTypes() const { - MLIRContext *context = getContext(); - MMAIntrinsic intrinsic = getIntrinsic().getValue(); - VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs); - VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs); - VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc); - return {aVecType, bVecType, cVecType}; -} - -FailureOr> -MMAAttr::getContractionLayout(vector::ContractionOp contract) const { - ConcreteMmaLayout layout = - getConcreteMMALayout(contract->getContext(), getIntrinsic().getValue()); - return IREE::GPU::getContractionLayout(contract, layout); -} - -static int getBlockSize(MMAIntrinsic /*intrinsic*/) { - // Not supporting any block size other than 1 at the moment. - return 1; -} - -int64_t MMAAttr::getBlockSize() const { - return IREE::GPU::getBlockSize(getIntrinsic().getValue()); -} - -static uint32_t getArchID(MMAIntrinsic intrinsic) { - return static_cast(intrinsic) & 0xFF00; -} - -static bool is_AMD_MFMA(MMAIntrinsic intrinsic) { - return getArchID(intrinsic) >= 0x1000 && getArchID(intrinsic) <= 0x17FF; -} - -static bool is_AMD_WMMA(MMAIntrinsic intrinsic) { - return getArchID(intrinsic) >= 0x1800 && getArchID(intrinsic) <= 0x1FFF; -} - -static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) { - // Not using Wave64 at all at the moment, so the only place where the - // subgroup size is CDNA* architectures. - return is_AMD_MFMA(intrinsic) ? 
64 : 32; -} - -int64_t MMAAttr::getSubgroupSize() const { - return getIntrinsicSubgroupSize(getIntrinsic().getValue()); -} - -FailureOr MMAAttr::getMmaScope() const { - return IREE::GPU::MMAScope::Subgroup; -} - MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, MMAFragment fragment) { switch (intrinsic) { @@ -637,6 +274,139 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, return {}; } +template +static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context, + MMAIntrinsicType intrinsic) { + OpaqueMmaLayout o; + std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic); + auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs); + auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs); + o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0]; + o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1]; + o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1]; + return o; +} + +OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context, + IREE::GPU::MMAIntrinsic intrinsic) { + return getOpaqueMMALayout(context, intrinsic); +} + +//===----------------------------------------------------------------------===// +// MmaInterface Attribute Helper Functions +//===----------------------------------------------------------------------===// + +MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) { + if (auto mmaAttr = dyn_cast(mmaKind)) { + return mmaAttr.getASingleSubgroupLayout(); + } + if (auto vmmaAttr = dyn_cast(mmaKind)) { + return vmmaAttr.getASingleSubgroupLayout(); + } + assert(false && "unhandled MMA Interface type."); + return {}; +} + +MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) { + if (auto mmaAttr = dyn_cast(mmaKind)) { + return mmaAttr.getBSingleSubgroupLayout(); + } + if (auto vmmaAttr = dyn_cast(mmaKind)) { + return vmmaAttr.getBSingleSubgroupLayout(); + } + assert(false && "unhandled MMA Interface type."); + return {}; +} + +MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) { + if (auto mmaAttr = dyn_cast(mmaKind)) { + return mmaAttr.getCSingleSubgroupLayout(); + } + if (auto vmmaAttr = dyn_cast(mmaKind)) { + return vmmaAttr.getCSingleSubgroupLayout(); + } + assert(false && "unhandled MMA Interface type."); + return {}; +} + +//===----------------------------------------------------------------------===// +// MMA Attributes +//===----------------------------------------------------------------------===// + +Attribute MMAAttr::parse(AsmParser &p, Type type) { + if (failed(p.parseLess())) + return {}; + + FailureOr mmaIntrinsic = + FieldParser::parse(p); + if (failed(mmaIntrinsic)) { + p.emitError(p.getCurrentLocation(), "failed to parse mfma type identifier"); + return {}; + } + + if (failed(p.parseGreater())) + return {}; + + return get(p.getContext(), mmaIntrinsic->getValue()); +} + +void MMAAttr::print(AsmPrinter &p) const { + auto &os = p.getStream(); + os << "<"; + os << stringifyMMAIntrinsic(getIntrinsic().getValue()); + os << ">"; +} + +MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) { + auto layout = getOpaqueMMALayout(context, type); + return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize, + layout.nSize, layout.kSize, layout.aType, layout.bType, + layout.cType); +} + +std::tuple MMAAttr::getABCElementTypes() const { + return {getAType(), getBType(), getCType()}; +} + +std::tuple MMAAttr::getMNKShape() const { + return {getMSize(), getNSize(), getKSize()}; +} + +template +static 
VectorType getVectorType(MLIRContext *context, + MMAIntrinsicType intrinsic, + MMAFragment fragment) { + auto o = getOpaqueMMALayout(context, intrinsic); + auto s = getSingleSubgroupLayout(intrinsic, fragment); + Type elemType = (fragment == MMAFragment::Lhs) ? o.aType + : (fragment == MMAFragment::Rhs) ? o.bType + : o.cType; + return VectorType::get( + {s.element[0] * s.element[1] * s.outer[0] * s.outer[1]}, elemType); +} + +std::tuple +MMAAttr::getABCVectorTypes() const { + MLIRContext *context = getContext(); + MMAIntrinsic intrinsic = getIntrinsic().getValue(); + VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs); + VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs); + VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc); + return {aVecType, bVecType, cVecType}; +} + +int64_t MMAAttr::getBlockSize() const { + return IREE::GPU::getBlockSize(getIntrinsic().getValue()); +} + +int64_t MMAAttr::getSubgroupSize() const { + return getIntrinsicSubgroupSize(getIntrinsic().getValue()); +} + +FailureOr MMAAttr::getMmaScope() const { + return IREE::GPU::MMAScope::Subgroup; +} + MMASingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const { return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs); } @@ -781,22 +551,8 @@ LogicalResult MMAAttr::populateOperandOffsetsSizesStrides( SmallVector &offsets, SmallVector &sizes, SmallVector &strides) const { - MMASingleSubgroupLayout subgroupLayout; - switch (fragment) { - case IREE::GPU::MMAFragment::Lhs: { - subgroupLayout = getASingleSubgroupLayout(); - break; - } - case IREE::GPU::MMAFragment::Rhs: { - subgroupLayout = getBSingleSubgroupLayout(); - break; - } - case IREE::GPU::MMAFragment::Acc: { - subgroupLayout = getCSingleSubgroupLayout(); - break; - } - } - + MMASingleSubgroupLayout subgroupLayout = + getSingleSubgroupLayout(getIntrinsic().getValue(), fragment); SmallVector canonicalOffsets; SmallVector canonicalSizes; if (failed(populateCanonicalOffsetsSizesAndStrides( @@ -810,82 +566,6 @@ LogicalResult MMAAttr::populateOperandOffsetsSizesStrides( return success(); } -LogicalResult MMAAttr::materializeOperandConcreteShape( - OpBuilder &builder, IREE::GPU::MMAFragment fragment, Value operand, - std::optional> permutation, - SmallVector &reassociations, - RankedTensorType &resultType) const { - - SmallVector outerSizes; - SmallVector opaqueSizes; - auto [m, n, k] = getMNKShape(); - switch (fragment) { - case IREE::GPU::MMAFragment::Lhs: { - outerSizes = getASingleSubgroupLayout().outer; - opaqueSizes.append({m, k}); - break; - } - case IREE::GPU::MMAFragment::Rhs: { - outerSizes = getBSingleSubgroupLayout().outer; - opaqueSizes.append({k, n}); - break; - } - case IREE::GPU::MMAFragment::Acc: { - outerSizes = getCSingleSubgroupLayout().outer; - opaqueSizes.append({m, n}); - break; - } - } - if (permutation.has_value()) { - if (permutation.value().size() != outerSizes.size()) { - return failure(); - } - applyPermutationToVector(opaqueSizes, permutation.value()); - applyPermutationToVector(outerSizes, permutation.value()); - } - - // Inner tile must have sizes matching the opaque layout. - auto operandType = llvm::cast(operand.getType()); - ArrayRef operandShape = operandType.getShape(); - SmallVector innerShape(operandShape.end() - opaqueSizes.size(), - operandShape.end()); - if (!llvm::equal(opaqueSizes, innerShape)) { - return failure(); - } - - // Expand the shape of the inner tile to reflect the MMA thread layout. 
- SmallVector resultShape(operandShape.begin(), - operandShape.end() - 2); - SmallVector reInds = - llvm::map_to_vector(llvm::seq(resultShape.size()), - [](int64_t idx) -> ReassociationIndices { - return ReassociationIndices({idx}); - }); - int idx = reInds.size(); - for (auto [outer, native] : llvm::zip_equal(outerSizes, opaqueSizes)) { - // Skip expansion if the outer dim is unit as the SingleSubgroupLayout gives - // a guarantee that the |element| counts are contiguous within the layout, - // and a unit outer implies a single offset and size for that dimension. - if (outer == 1) { - resultShape.push_back(native); - reInds.push_back(ReassociationIndices({idx++})); - continue; - } - - // Reshape to [outer, native / outer] == [outer, thread * element]. This - // corresponds to |outer| repetitions of the thread/element sublayout. - resultShape.push_back(outer); - assert(native % outer == 0 && "invalid mma layout"); - resultShape.push_back(native / outer); - reInds.push_back(ReassociationIndices{idx, idx + 1}); - idx += 2; - } - - reassociations = reInds; - resultType = operandType.clone(resultShape); - return success(); -} - //===----------------------------------------------------------------------===// // DataTiledMMA Attributes //===----------------------------------------------------------------------===// @@ -1265,22 +945,8 @@ LogicalResult VirtualMMAAttr::populateOperandOffsetsSizesStrides( SmallVector &offsets, SmallVector &sizes, SmallVector &strides) const { - MMASingleSubgroupLayout subgroupLayout; - switch (fragment) { - case IREE::GPU::MMAFragment::Lhs: { - subgroupLayout = getASingleSubgroupLayout(); - break; - } - case IREE::GPU::MMAFragment::Rhs: { - subgroupLayout = getBSingleSubgroupLayout(); - break; - } - case IREE::GPU::MMAFragment::Acc: { - subgroupLayout = getCSingleSubgroupLayout(); - break; - } - } - + MMASingleSubgroupLayout subgroupLayout = + getSingleSubgroupLayout(getIntrinsic().getValue(), fragment); SmallVector canonicalOffsets; SmallVector canonicalSizes; if (failed(populateCanonicalOffsetsSizesAndStrides( @@ -1444,348 +1110,6 @@ MMASingleSubgroupLayout VirtualMMAAttr::getCSingleSubgroupLayout() const { return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc); } -//===----------------------------------------------------------------------===// -// MMA Schedule Attributes -//===----------------------------------------------------------------------===// - -/// Gets a unit vector of the given rank, but fills in the given dimensions -/// from the 2 element array |counts|. |dim0| is the position in the returned -/// vector to put the first element of |counts|, and |dim1| is the position to -/// put the second element. For example, -/// -/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1 -/// returns [1, 5, 7] -SmallVector getUnitOfRankWithDims(int64_t rank, - ArrayRef counts, - int64_t dim0, int64_t dim1) { - assert(counts.size() == 2 && - "Unexpected non-rank 2 single subgroup dimension counts"); - SmallVector res(rank, 1); - res[dim0] = counts[0]; - res[dim1] = counts[1]; - return res; -} - -SmallVector getIdentityPerm(int64_t rank) { - return llvm::to_vector(llvm::seq(static_cast(0), rank)); -} - -/// Constructs an identity permutation with the given rank, except it applies -/// the given rank-2 |perm| to the two dimensions |dim0| and |dim1|, and then -/// swaps the positions of dim0 and dim1 in the final permutation. 
For example, -/// -/// rank = 3, perm = [1, 0], dim0 = 1, dim1 = 2 -/// returns [0, 1, 2] -/// -/// This is essentially just applying two rank-2 permutations to two particular -/// dimensions. First it applies |perm|, which corresponds to a permutation -/// needed by the underlying intrinsic, then it does another permutation based -/// on the order of actual dimensions for the MMA fragment. For example, for the -/// B matrix, dim0 = K and dim1 = N, so for the element order of an MFMA -/// 16x16x16, perm would be `[1, 0]`, however if the actual contraction is a -/// matmul_transpose_b, then the element order needs to be [0, 1]. -SmallVector getIdentityPermWithSwap(int64_t rank, - ArrayRef perm, - int64_t dim0, int64_t dim1) { - assert(perm.size() == 2 && - "Unexpected non-rank 2 single subgroup dimension order"); - SmallVector res = getIdentityPerm(rank); - if (perm[0] > perm[1]) { - std::swap(dim0, dim1); - } - if (dim0 > dim1) { - res[dim0] = dim1; - res[dim1] = dim0; - } - return res; -} - -/// Constructs the nested layout given the layout for a single subgroup and the -/// subgroup/batch counts and orders, as well as the dimensions along which to -/// distribute the intrinsic's layout. -/// -/// |outerDim| and |innerDim| refer to which dimensions are the outermost and -/// innermost for a canonical MK_KN_MN matrix multiply, for a particular -/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply, -/// we would have: -/// outerDim = 1 for the K dim -/// innerDim = 0 for the N dim -/// -/// For something like MK_NKN_MN with multiple N dims, it would typically be: -/// outerDim = 1 for K -/// innerDim = 2 for the second N dim -/// -/// Importantly these two dimensions always refer to the actual dimension -/// positions in the undistributed vector. For each fragment, this means: -/// A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim] -/// B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim] -/// C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim] -/// -/// And here inner most is referential to the iteration order, not the order -/// they appear per fragment (because there is no relationship between the -/// dimension order of M in A and in C, for example). 
-NestedLayoutAttr createNestedLayout(MLIRContext *context, int64_t rank, - int64_t outerDim, int64_t innerDim, - SmallVector subgroupSizes, - SmallVector subgroupStrides, - SmallVector batchCount, - MMASingleSubgroupLayout counts) { - - LLVM_DEBUG({ - llvm::errs() << "Creating Nested Layout for::"; - llvm::errs() << "\n outerDim = " << outerDim; - llvm::errs() << "\n innerDim = " << innerDim; - llvm::errs() << "\n subgroupSizes: "; - llvm::interleaveComma(subgroupSizes, llvm::errs()); - llvm::errs() << "\n subgroupStrides: "; - llvm::interleaveComma(subgroupStrides, llvm::errs()); - llvm::errs() << "\n batchCount: "; - llvm::interleaveComma(batchCount, llvm::errs()); - llvm::errs() << "\n counts.outer: "; - llvm::interleaveComma(counts.outer, llvm::errs()); - llvm::errs() << "\n counts.thread: "; - llvm::interleaveComma(counts.thread, llvm::errs()); - llvm::errs() << "\n counts.element: "; - llvm::interleaveComma(counts.element, llvm::errs()); - llvm::errs() << "\n counts.tstrides: "; - llvm::interleaveComma(counts.tstrides, llvm::errs()); - llvm::errs() << "\n"; - }); - - SmallVector outerCount = - getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim); - SmallVector threadCount = - getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim); - SmallVector threadStrides = - getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim); - SmallVector elementCount = - getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim); - - auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount, - outerCount, threadCount, elementCount, - subgroupStrides, threadStrides); - return layoutAttr; -} - -FailureOr> -MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, - linalg::LinalgOp contractOp) const { - LLVM_DEBUG({ - llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n"; - llvm::errs() << "For schedule: " << *this << "\n"; - }); - - int64_t rank = contractOp.getIteratorTypesArray().size(); - auto mmaAttr = llvm::cast(getIntrinsic()); - MLIRContext *context = getContext(); - - SmallVector bounds = contractOp.getStaticLoopRanges(); - if (llvm::any_of(bounds, - [](int64_t x) { return x == ShapedType::kDynamic; })) { - return failure(); - } - - if (!llvm::all_of(opInfo.getBatchDims(), - [&bounds](int64_t dim) { return bounds[dim] == 1; })) { - LLVM_DEBUG({ llvm::errs() << "non-unit batch dimension\n"; }); - return failure(); - } - - // Get the concrete nested layout for each matrix. Note that the struct - // MMASingleSubgroupLayout contains the partial layout for the - // canonical (M, K) x (K, N) -> (M, N) matmul form; while the specific - // contract op we are looking at right now may not be exactly in that form. - // So here we need to permute/transpose the canonical layout to match with - // the concrete contract op. - - // Note that no matter how we permute/transpose the input contraction - // problem, the way we view the hardware warps remain the same--that is, - // from the hardware's perspective, a single warp has the same warp ID no - // matter what part of the contraction it works on. Similarly here, we are - // delinearizing the linearized GPU hardware lane ID into a n-D concatenated - // logical warp+thread using the subgroup/thread basis, so the subgroup - // basis should remain the same for all A/B/C matrix. 
- - auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape(); - - SmallVector subgroupMBasis; - SmallVector batchMSizes; - int64_t currMCount = getSubgroupMCount(); - - auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize, - int64_t minDimSize) -> std::pair { - int64_t dividableDim = dimSize / minDimSize; - int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim); - dividableDim /= subgroupsUsed; - int64_t batchesUsed = dividableDim; - return {subgroupsUsed, batchesUsed}; - }; - - // Greedily break up the M subgroup and batch counts along the "M" iteration - // bounds. We distribute as many residual subgroups as possible per M dim, - // and then divide the remaining along batch dims. The inner most M dim is - // always the one used for the intrinsic, meaning for a valid schedule, the - // computed batch counts and subgroup basis will satisfy totalMSize / - // intrinsicM = product(batchMSizes) * product(subgroupMBasis) - for (auto dim : opInfo.getMDims()) { - // Get the number of subgroups and batches used for this dimension based - // on the intrinsic size and the bound size. - int64_t subgroupsUsed, batchesUsed; - if (dim == opInfo.getMDims().back()) { - std::tie(subgroupsUsed, batchesUsed) = - divideGreedily(currMCount, bounds[dim], intrinsicM); - } else { - std::tie(subgroupsUsed, batchesUsed) = - divideGreedily(currMCount, bounds[dim], 1); - } - subgroupMBasis.push_back(subgroupsUsed); - batchMSizes.push_back(batchesUsed); - // Update available subgroup count. - currMCount /= subgroupsUsed; - } - - SmallVector subgroupNBasis; - SmallVector batchNSizes; - int64_t currNCount = getSubgroupNCount(); - - // Do the same for N dims. - for (auto dim : opInfo.getNDims()) { - // Get the number of subgroups and batches used for this dimension based - // on the intrinsic size and the bound size. - int64_t subgroupsUsed, batchesUsed; - if (dim == opInfo.getNDims().back()) { - std::tie(subgroupsUsed, batchesUsed) = - divideGreedily(currNCount, bounds[dim], intrinsicN); - } else { - std::tie(subgroupsUsed, batchesUsed) = - divideGreedily(currNCount, bounds[dim], 1); - } - subgroupNBasis.push_back(subgroupsUsed); - batchNSizes.push_back(batchesUsed); - // Update available subgroup count. - currNCount /= subgroupsUsed; - } - - SmallVector subgroupMStrides(subgroupMBasis.size()); - SmallVector subgroupNStrides(subgroupNBasis.size()); - - auto mDimVec = opInfo.getMDims(); - llvm::SmallDenseSet mDims(mDimVec.begin(), mDimVec.end()); - auto nDimVec = opInfo.getNDims(); - llvm::SmallDenseSet nDims(nDimVec.begin(), nDimVec.end()); - // Because we currently require all batch dimensions to be unit, the - // subgroup basis can be constructed from the M and N bases. To keep things - // simple, the current heuristic is to distribute the loop dimensions from - // outer to inner. - int64_t currStride = 1; - int64_t currM = subgroupMStrides.size() - 1; - int64_t currN = subgroupNStrides.size() - 1; - for (int64_t dim : llvm::reverse(llvm::seq(rank))) { - if (mDims.contains(dim)) { - subgroupMStrides[currM] = currStride; - currStride *= subgroupMBasis[currM]; - currM--; - continue; - } - - if (nDims.contains(dim)) { - subgroupNStrides[currN] = currStride; - currStride *= subgroupNBasis[currN]; - currN--; - continue; - } - } - - // C matrix layout - auto [m, n] = opInfo.getResultMNIndex(); - int64_t cRank = opInfo.getCRank(); - - // Get the M and N dims w.r.t. the dimensions of the C matrix. 
cMDims and - // cNDims are the M and N dimensions of the C matrix in the order they are - // iterated over in the contraction. - SmallVector cMDims = opInfo.outMDims; - SmallVector cNDims = opInfo.outNDims; - SmallVector cBatchSizes(cRank, 1); - SmallVector cSubgroupSizes(cRank, 1); - SmallVector cSubgroupStrides(cRank, 0); - for (auto [i, dim] : llvm::enumerate(cMDims)) { - cBatchSizes[dim] = batchMSizes[i]; - cSubgroupSizes[dim] = subgroupMBasis[i]; - cSubgroupStrides[dim] = subgroupMStrides[i]; - } - for (auto [i, dim] : llvm::enumerate(cNDims)) { - cBatchSizes[dim] = batchNSizes[i]; - cSubgroupSizes[dim] = subgroupNBasis[i]; - cSubgroupStrides[dim] = subgroupNStrides[i]; - } - - auto cLayout = createNestedLayout(context, cRank, m, n, - /*subgroupCount=*/cSubgroupSizes, - /*subgroupStrides=*/cSubgroupStrides, - /*batchCount=*/cBatchSizes, - getCSingleSubgroupLayout(mmaAttr)); - LLVM_DEBUG({ llvm::errs() << "C layout: " << cLayout << "\n"; }); - - // A matrix layout - auto [afm, bfn] = opInfo.getOperandMNIndex(); - auto [afk, bfk] = opInfo.getOperandKIndex(); - - int64_t aRank = opInfo.getARank(); - - SmallVector aMDims = opInfo.lhsMDims; - SmallVector aBatchSizes(aRank, 1); - SmallVector aSubgroupSizes(aRank, 1); - SmallVector aSubgroupStrides(aRank, 0); - for (auto [i, dim] : llvm::enumerate(aMDims)) { - aBatchSizes[dim] = batchMSizes[i]; - aSubgroupSizes[dim] = subgroupMBasis[i]; - aSubgroupStrides[dim] = subgroupMStrides[i]; - } - for (auto [kDim, lhsKDim] : - llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) { - aBatchSizes[lhsKDim] = bounds[kDim]; - } - aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK; - - auto aLayout = createNestedLayout(context, aRank, afm, afk, - /*subgroupCount=*/aSubgroupSizes, - /*subgroupStrides=*/aSubgroupStrides, - /*batchCount=*/aBatchSizes, - getASingleSubgroupLayout(mmaAttr)); - LLVM_DEBUG({ llvm::errs() << "A layout: " << aLayout << "\n"; }); - - int64_t bRank = opInfo.getBRank(); - - SmallVector bNDims = opInfo.rhsNDims; - SmallVector bBatchSizes(bRank, 1); - SmallVector bSubgroupSizes(bRank, 1); - SmallVector bSubgroupStrides(bRank, 0); - for (auto [i, dim] : llvm::enumerate(bNDims)) { - bBatchSizes[dim] = batchNSizes[i]; - bSubgroupSizes[dim] = subgroupNBasis[i]; - bSubgroupStrides[dim] = subgroupNStrides[i]; - } - for (auto [kDim, rhsKDim] : - llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) { - bBatchSizes[rhsKDim] = bounds[kDim]; - } - bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK; - - auto bLayout = createNestedLayout(context, bRank, bfk, bfn, - /*subgroupCount=*/bSubgroupSizes, - /*subgroupStrides=*/bSubgroupStrides, - /*batchCount=*/bBatchSizes, - getBSingleSubgroupLayout(mmaAttr)); - LLVM_DEBUG({ llvm::errs() << "B layout: " << bLayout << "\n"; }); - - std::tuple - result = {aLayout, bLayout, cLayout}; - return result; -} - //===----------------------------------------------------------------------===// // Target Attributes //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h index 794206673a68..92ce4f493380 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h @@ -22,12 +22,40 @@ namespace mlir::iree_compiler::IREE::GPU { -// Partial nested layout for an MMA intrinsic's matrix input/output inside -// a single subgroup. 
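// Illustrative sketch for the MMASingleSubgroupLayout struct documented below
// (hypothetical values, not taken from any particular intrinsic): a 16x16
// accumulator fragment held by a 64-lane subgroup, 4 elements per lane, could
// be described as
//
//   MMASingleSubgroupLayout layout;
//   layout.outer    = {1, 1};   // no non-contiguous per-thread repetition
//   layout.thread   = {4, 16};  // 4 x 16 = 64 lanes cover the 16x16 tile
//   layout.tstrides = {16, 1};  // lane id = 16 * t0 + t1
//   layout.element  = {4, 1};   // each lane holds 4 elements along dim 0
//
// Per dimension, outer * thread * element = {1*4*4, 1*16*1} = {16, 16}, and
// the product of the thread sizes (64) equals the subgroup size, so each
// element is accessed by exactly one lane.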
+// Struct describing the detailed subgroup-level layout of a MMA intrinsic. +// Together with element type information and subgroup size, it completes the +// full description of the semantics of a MMA intrinsic. +// +// Note: It is not possible to infer subgroup size from the information in this +// struct. The product of the `thread` sizes here is often, but not always equal +// to subgroup size. When the product of the `thread` sizes (call that product +// `P`) is smaller than subgroup size, it must be a divisor of it, and the +// semantics in that case are that threads within the subgroup whose thread-ids +// differ by a multiple of `P`, are accessing the same elements. +// +// Example observed in RDNA3 WMMA Wave64 intrinsics: +// If the subgroup size is 64 but the product `P` of `thread` sizes is 32, that +// means that each element is being accessed by 2 threads (2 = 64/32), and the +// threads accessing the same element are those whose tids are exactly 32 apart. struct MMASingleSubgroupLayout { + // Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are + // outer-most in the layout. This happens when a MMA op, seen on a single + // thread, has an operand that consists of multiple elements, and these elems + // are NOT contiguous. + // This is not used by every MMA op; ops which don't use that simply have 1's. SmallVector outer; + // Cross-thread dimensions (as in TileSwizzle::Dim::Kind::CrossThread). + // This is the kind of dimension that is present in all GPU MMA ops, by + // definition of "SIMT". It is still possible for one of the `thread` dims to + // be 1, but not both. SmallVector thread; + // Strides corresponding to the cross-thread dimensions. SmallVector tstrides; + // Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are + // inner-most in the layout. This happens when a MMA op, seen on a single + // thread, has an operand that consists of multiple elements, and these elems + // are NOT contiguous. + // This is not used by every MMA op; ops which don't use that simply have 1's. SmallVector element; }; @@ -43,6 +71,22 @@ MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind); MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind); +// Struct describing the shape of a MMA operation, but not the detailed layout. +// TODO(bjacob): the only user outside of IREEGPUAttrs.cpp is +// LLVMGPU/TransformExtensions, so maybe make that internal again if/when that +// goes away. +struct OpaqueMmaLayout { + int64_t mSize = 0; + int64_t nSize = 0; + int64_t kSize = 0; + Type aType; + Type bType; + Type cType; +}; + +OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context, + IREE::GPU::MMAIntrinsic intrinsic); + } // namespace mlir::iree_compiler::IREE::GPU // clang-format off diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index bc6f2d3584f3..51674dbd1f2b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -148,7 +148,6 @@ class IREEGPU_MmaVectorLayoutAttr : DeclareAttrInterfaceMethods : "getCSingleSubgroupLayout", "buildMmaOperation", "populateOperandOffsetsSizesStrides", - "materializeOperandConcreteShape", ]> ]> { let cppNamespace = "::mlir::iree_compiler::IREE::GPU"; @@ -259,19 +257,11 @@ def IREEGPU_DataTiledMMAAttr : |intrinsic| field specifies which particular MMA intrinsic is targeted by the data-tiling. 
- The tile swizzling already happens, so the attribute does not need to - implement materializeOperandConcreteShape interface method. E.g., if the - target intrinsic is MFMA_F32_16x16x4_F32: - - The inner tile shape of LHS is 4x16. - - The inner tile shape of RHS is 4x16. - - The inner tile shape of ACC is 4x16x4. - - Furthermore, the unrolling and interleaving can be represented with the - attribute. In the concept of data-tiling, we always unroll the parallel - dimensions (i.e., M, N dimensions) to be outermost, and interleave the - unrolled K dimension. I.e., the unrolled K dimension becomes the innermost - dimension. The constraint can be relaxed based on data-tiling needs. The - additional information can be added to `parameters`. + The other fields default to one, and that default results in a single + intrinsic equivalent to MMAAttr, while values greater than one result in + wider "kernels" consisting of multiple intrinsics, with the data layout + already swizzled into a tile layout that allows each intrinsic to access + data at an offset that's as simple as possible a mapping from the thread ID. }]; let assemblyFormat = "`<` struct(params) `>`"; @@ -369,15 +359,6 @@ def IREEGPU_MmaScheduleAttr : AttrDef { ); let assemblyFormat = "`<` struct(params) `>`"; - - let extraClassDeclaration = [{ - // Returns the A/B/C matrix concrete layout targeting |contractOp|. - ::mlir::FailureOr<::std::tuple> - getContractionLayout(::mlir::iree_compiler::VectorContractOpInfo &opInfo, - ::mlir::linalg::LinalgOp contractOp) const; - }]; } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp index 94bcf3dbe593..6b2e83451edc 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp @@ -23,7 +23,80 @@ struct ConcretizeMmaShapesPass final using ConcretizeMmaShapesPassBase::ConcretizeMmaShapesPassBase; void runOnOperation() override; }; -} // namespace + +LogicalResult materializeOperandConcreteShape( + OpBuilder &builder, MMAAttr mma, IREE::GPU::MMAFragment fragment, + Value operand, std::optional> permutation, + SmallVector &reassociations, + RankedTensorType &resultType) { + + SmallVector outerSizes; + SmallVector opaqueSizes; + auto [m, n, k] = mma.getMNKShape(); + switch (fragment) { + case IREE::GPU::MMAFragment::Lhs: { + outerSizes = mma.getASingleSubgroupLayout().outer; + opaqueSizes.append({m, k}); + break; + } + case IREE::GPU::MMAFragment::Rhs: { + outerSizes = mma.getBSingleSubgroupLayout().outer; + opaqueSizes.append({k, n}); + break; + } + case IREE::GPU::MMAFragment::Acc: { + outerSizes = mma.getCSingleSubgroupLayout().outer; + opaqueSizes.append({m, n}); + break; + } + } + if (permutation.has_value()) { + if (permutation.value().size() != outerSizes.size()) { + return failure(); + } + applyPermutationToVector(opaqueSizes, permutation.value()); + applyPermutationToVector(outerSizes, permutation.value()); + } + + // Inner tile must have sizes matching the opaque layout. + auto operandType = llvm::cast(operand.getType()); + ArrayRef operandShape = operandType.getShape(); + if (opaqueSizes != operandShape.take_back(opaqueSizes.size())) { + return failure(); + } + + // Expand the shape of the inner tile to reflect the MMA thread layout. 
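// Illustrative sketch of the expansion below (hypothetical sizes, assuming a
// single leading batch dimension B): with operandShape = [B, 16, 16],
// outerSizes = [2, 1] and opaqueSizes = [16, 16], the loop produces
//   resultShape    = [B, 2, 8, 16]
//   reassociations = [[0], [1, 2], [3]]
// i.e. the first inner dim is split into outer = 2 times 16/2 = 8, while the
// second inner dim keeps its size because its outer count is 1.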
+ SmallVector resultShape(operandShape.begin(), + operandShape.end() - 2); + SmallVector reInds = + llvm::map_to_vector(llvm::seq(resultShape.size()), + [](int64_t idx) -> ReassociationIndices { + return ReassociationIndices({idx}); + }); + int idx = reInds.size(); + for (auto [outer, native] : llvm::zip_equal(outerSizes, opaqueSizes)) { + // Skip expansion if the outer dim is unit as the SingleSubgroupLayout gives + // a guarantee that the |element| counts are contiguous within the layout, + // and a unit outer implies a single offset and size for that dimension. + if (outer == 1) { + resultShape.push_back(native); + reInds.push_back(ReassociationIndices({idx++})); + continue; + } + + // Reshape to [outer, native / outer] == [outer, thread * element]. This + // corresponds to |outer| repetitions of the thread/element sublayout. + resultShape.push_back(outer); + assert(native % outer == 0 && "invalid mma layout"); + resultShape.push_back(native / outer); + reInds.push_back(ReassociationIndices{idx, idx + 1}); + idx += 2; + } + + reassociations = reInds; + resultType = operandType.clone(resultShape); + return success(); +} struct ConcretizeMmaOperandShape final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -56,12 +129,15 @@ struct ConcretizeMmaOperandShape final : OpRewritePattern { } // Get the reassociation indices and result type of the expand_shape op. - MmaInterfaceAttr kind = mmaOp.getKind(); + MMAAttr kind = dyn_cast(mmaOp.getKind()); + if (!kind) { + return failure(); + } SmallVector reassociations; RankedTensorType concreteType; - if (failed(kind.materializeOperandConcreteShape(rewriter, fragment, operand, - permutation, reassociations, - concreteType))) { + if (failed(materializeOperandConcreteShape(rewriter, kind, fragment, + operand, permutation, + reassociations, concreteType))) { return failure(); } @@ -140,6 +216,8 @@ struct ConcretizeMmaOperandShape final : OpRewritePattern { MMAFragment fragment; }; +} // namespace + void ConcretizeMmaShapesPass::runOnOperation() { MLIRContext *context = &getContext(); auto funcOp = getOperation(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp index 69333a1e2386..9df22c7e609c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp @@ -8,6 +8,7 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" #include "iree/compiler/Codegen/LLVMGPU/Passes.h" +#include "llvm/Support/Debug.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" @@ -71,6 +72,310 @@ static int64_t getSubgroupNCount(Operation *op) { return *config.getSubgroupNCount(); } +/// Gets a unit vector of the given rank, but fills in the given dimensions +/// from the 2 element array |counts|. |dim0| is the position in the returned +/// vector to put the first element of |counts|, and |dim1| is the position to +/// put the second element. 
For example, +/// +/// rank = 3, counts = [5, 7], dim0 = 2, dim1 = 1 +/// returns [1, 5, 7] +static SmallVector getUnitOfRankWithDims(int64_t rank, + ArrayRef counts, + int64_t dim0, int64_t dim1) { + assert(counts.size() == 2 && + "Unexpected non-rank 2 single subgroup dimension counts"); + SmallVector res(rank, 1); + res[dim0] = counts[0]; + res[dim1] = counts[1]; + return res; +} + +/// Constructs the nested layout given the layout for a single subgroup and the +/// subgroup/batch counts and orders, as well as the dimensions along which to +/// distribute the intrinsic's layout. +/// +/// |outerDim| and |innerDim| refer to which dimensions are the outermost and +/// innermost for a canonical MK_KN_MN matrix multiply, for a particular +/// fragment. For example, for the B matrix of an MK_NK_MN matrix multiply, +/// we would have: +/// outerDim = 1 for the K dim +/// innerDim = 0 for the N dim +/// +/// For something like MK_NKN_MN with multiple N dims, it would typically be: +/// outerDim = 1 for K +/// innerDim = 2 for the second N dim +/// +/// Importantly these two dimensions always refer to the actual dimension +/// positions in the undistributed vector. For each fragment, this means: +/// A: [outerDim, innerDim] = [innerMostMDim, innerMostKDim] +/// B: [outerDim, innerDim] = [innerMostKDim, innerMostNDim] +/// C: [outerDim, innerDim] = [innerMostMDim, innerMostNDim] +/// +/// And here inner most is referential to the iteration order, not the order +/// they appear per fragment (because there is no relationship between the +/// dimension order of M in A and in C, for example). +static NestedLayoutAttr createNestedLayout( + MLIRContext *context, int64_t rank, int64_t outerDim, int64_t innerDim, + ArrayRef subgroupSizes, ArrayRef subgroupStrides, + ArrayRef batchCount, IREE::GPU::MMASingleSubgroupLayout counts) { + + LLVM_DEBUG({ + llvm::dbgs() << "Creating Nested Layout for::"; + llvm::dbgs() << "\n outerDim = " << outerDim; + llvm::dbgs() << "\n innerDim = " << innerDim; + llvm::dbgs() << "\n subgroupSizes: "; + llvm::interleaveComma(subgroupSizes, llvm::dbgs()); + llvm::dbgs() << "\n subgroupStrides: "; + llvm::interleaveComma(subgroupStrides, llvm::dbgs()); + llvm::dbgs() << "\n batchCount: "; + llvm::interleaveComma(batchCount, llvm::dbgs()); + llvm::dbgs() << "\n counts.outer: "; + llvm::interleaveComma(counts.outer, llvm::dbgs()); + llvm::dbgs() << "\n counts.thread: "; + llvm::interleaveComma(counts.thread, llvm::dbgs()); + llvm::dbgs() << "\n counts.element: "; + llvm::interleaveComma(counts.element, llvm::dbgs()); + llvm::dbgs() << "\n counts.tstrides: "; + llvm::interleaveComma(counts.tstrides, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + SmallVector outerCount = + getUnitOfRankWithDims(rank, counts.outer, outerDim, innerDim); + SmallVector threadCount = + getUnitOfRankWithDims(rank, counts.thread, outerDim, innerDim); + SmallVector threadStrides = + getUnitOfRankWithDims(rank, counts.tstrides, outerDim, innerDim); + SmallVector elementCount = + getUnitOfRankWithDims(rank, counts.element, outerDim, innerDim); + + auto layoutAttr = NestedLayoutAttr::get(context, subgroupSizes, batchCount, + outerCount, threadCount, elementCount, + subgroupStrides, threadStrides); + return layoutAttr; +} + +static FailureOr> +getContractionLayout(IREE::GPU::MMAScheduleAttr schedule, + VectorContractOpInfo &opInfo, + linalg::LinalgOp contractOp) { + LLVM_DEBUG({ + llvm::dbgs() << "Getting mma layouts for:\n" << contractOp << "\n"; + llvm::dbgs() << "For schedule: " << schedule << "\n"; + 
}); + + int64_t rank = contractOp.getIteratorTypesArray().size(); + auto mmaAttr = + llvm::cast(schedule.getIntrinsic()); + MLIRContext *context = schedule.getContext(); + + SmallVector bounds = contractOp.getStaticLoopRanges(); + if (llvm::any_of(bounds, + [](int64_t x) { return x == ShapedType::kDynamic; })) { + return failure(); + } + + if (!llvm::all_of(opInfo.getBatchDims(), + [&bounds](int64_t dim) { return bounds[dim] == 1; })) { + LLVM_DEBUG({ llvm::dbgs() << "non-unit batch dimension\n"; }); + return failure(); + } + + // Get the concrete nested layout for each matrix. Note that the struct + // MMASingleSubgroupLayout contains the partial layout for the + // canonical (M, K) x (K, N) -> (M, N) matmul form; while the specific + // contract op we are looking at right now may not be exactly in that form. + // So here we need to permute/transpose the canonical layout to match with + // the concrete contract op. + + // Note that no matter how we permute/transpose the input contraction + // problem, the way we view the hardware warps remain the same--that is, + // from the hardware's perspective, a single warp has the same warp ID no + // matter what part of the contraction it works on. Similarly here, we are + // delinearizing the linearized GPU hardware lane ID into a n-D concatenated + // logical warp+thread using the subgroup/thread basis, so the subgroup + // basis should remain the same for all A/B/C matrix. + + auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape(); + + SmallVector subgroupMBasis; + SmallVector batchMSizes; + int64_t currMCount = schedule.getSubgroupMCount(); + + auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize, + int64_t minDimSize) -> std::pair { + int64_t dividableDim = dimSize / minDimSize; + int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim); + dividableDim /= subgroupsUsed; + int64_t batchesUsed = dividableDim; + return {subgroupsUsed, batchesUsed}; + }; + + // Greedily break up the M subgroup and batch counts along the "M" iteration + // bounds. We distribute as many residual subgroups as possible per M dim, + // and then divide the remaining along batch dims. The inner most M dim is + // always the one used for the intrinsic, meaning for a valid schedule, the + // computed batch counts and subgroup basis will satisfy totalMSize / + // intrinsicM = product(batchMSizes) * product(subgroupMBasis) + for (auto dim : opInfo.getMDims()) { + // Get the number of subgroups and batches used for this dimension based + // on the intrinsic size and the bound size. + int64_t subgroupsUsed, batchesUsed; + if (dim == opInfo.getMDims().back()) { + std::tie(subgroupsUsed, batchesUsed) = + divideGreedily(currMCount, bounds[dim], intrinsicM); + } else { + std::tie(subgroupsUsed, batchesUsed) = + divideGreedily(currMCount, bounds[dim], 1); + } + subgroupMBasis.push_back(subgroupsUsed); + batchMSizes.push_back(batchesUsed); + // Update available subgroup count. + currMCount /= subgroupsUsed; + } + + SmallVector subgroupNBasis; + SmallVector batchNSizes; + int64_t currNCount = schedule.getSubgroupNCount(); + + // Do the same for N dims. + for (auto dim : opInfo.getNDims()) { + // Get the number of subgroups and batches used for this dimension based + // on the intrinsic size and the bound size. 
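// Worked example of divideGreedily (hypothetical numbers): with
// availableSubgroups = 4, dimSize = 128 and minDimSize = intrinsicN = 16:
//   dividableDim  = 128 / 16  = 8
//   subgroupsUsed = gcd(4, 8) = 4
//   batchesUsed   = 8 / 4     = 2
// so this dimension is covered by 4 subgroups, each iterating over 2 batches
// of the 16-wide intrinsic.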
+ int64_t subgroupsUsed, batchesUsed; + if (dim == opInfo.getNDims().back()) { + std::tie(subgroupsUsed, batchesUsed) = + divideGreedily(currNCount, bounds[dim], intrinsicN); + } else { + std::tie(subgroupsUsed, batchesUsed) = + divideGreedily(currNCount, bounds[dim], 1); + } + subgroupNBasis.push_back(subgroupsUsed); + batchNSizes.push_back(batchesUsed); + // Update available subgroup count. + currNCount /= subgroupsUsed; + } + + SmallVector subgroupMStrides(subgroupMBasis.size()); + SmallVector subgroupNStrides(subgroupNBasis.size()); + + auto mDimVec = opInfo.getMDims(); + llvm::SmallDenseSet mDims(mDimVec.begin(), mDimVec.end()); + auto nDimVec = opInfo.getNDims(); + llvm::SmallDenseSet nDims(nDimVec.begin(), nDimVec.end()); + // Because we currently require all batch dimensions to be unit, the + // subgroup basis can be constructed from the M and N bases. To keep things + // simple, the current heuristic is to distribute the loop dimensions from + // outer to inner. + int64_t currStride = 1; + int64_t currM = subgroupMStrides.size() - 1; + int64_t currN = subgroupNStrides.size() - 1; + for (int64_t dim : llvm::reverse(llvm::seq(rank))) { + if (mDims.contains(dim)) { + subgroupMStrides[currM] = currStride; + currStride *= subgroupMBasis[currM]; + currM--; + continue; + } + + if (nDims.contains(dim)) { + subgroupNStrides[currN] = currStride; + currStride *= subgroupNBasis[currN]; + currN--; + continue; + } + } + + // C matrix layout + auto [m, n] = opInfo.getResultMNIndex(); + int64_t cRank = opInfo.getCRank(); + + // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and + // cNDims are the M and N dimensions of the C matrix in the order they are + // iterated over in the contraction. + SmallVector cMDims = opInfo.outMDims; + SmallVector cNDims = opInfo.outNDims; + SmallVector cBatchSizes(cRank, 1); + SmallVector cSubgroupSizes(cRank, 1); + SmallVector cSubgroupStrides(cRank, 0); + for (auto [i, dim] : llvm::enumerate(cMDims)) { + cBatchSizes[dim] = batchMSizes[i]; + cSubgroupSizes[dim] = subgroupMBasis[i]; + cSubgroupStrides[dim] = subgroupMStrides[i]; + } + for (auto [i, dim] : llvm::enumerate(cNDims)) { + cBatchSizes[dim] = batchNSizes[i]; + cSubgroupSizes[dim] = subgroupNBasis[i]; + cSubgroupStrides[dim] = subgroupNStrides[i]; + } + + auto cLayout = createNestedLayout(context, cRank, m, n, + /*subgroupCount=*/cSubgroupSizes, + /*subgroupStrides=*/cSubgroupStrides, + /*batchCount=*/cBatchSizes, + getCSingleSubgroupLayout(mmaAttr)); + LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; }); + + // A matrix layout + auto [afm, bfn] = opInfo.getOperandMNIndex(); + auto [afk, bfk] = opInfo.getOperandKIndex(); + + int64_t aRank = opInfo.getARank(); + + SmallVector aMDims = opInfo.lhsMDims; + SmallVector aBatchSizes(aRank, 1); + SmallVector aSubgroupSizes(aRank, 1); + SmallVector aSubgroupStrides(aRank, 0); + for (auto [i, dim] : llvm::enumerate(aMDims)) { + aBatchSizes[dim] = batchMSizes[i]; + aSubgroupSizes[dim] = subgroupMBasis[i]; + aSubgroupStrides[dim] = subgroupMStrides[i]; + } + for (auto [kDim, lhsKDim] : + llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) { + aBatchSizes[lhsKDim] = bounds[kDim]; + } + aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK; + + auto aLayout = createNestedLayout(context, aRank, afm, afk, + /*subgroupCount=*/aSubgroupSizes, + /*subgroupStrides=*/aSubgroupStrides, + /*batchCount=*/aBatchSizes, + getASingleSubgroupLayout(mmaAttr)); + LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; }); + + int64_t 
bRank = opInfo.getBRank(); + + SmallVector bNDims = opInfo.rhsNDims; + SmallVector bBatchSizes(bRank, 1); + SmallVector bSubgroupSizes(bRank, 1); + SmallVector bSubgroupStrides(bRank, 0); + for (auto [i, dim] : llvm::enumerate(bNDims)) { + bBatchSizes[dim] = batchNSizes[i]; + bSubgroupSizes[dim] = subgroupNBasis[i]; + bSubgroupStrides[dim] = subgroupNStrides[i]; + } + for (auto [kDim, rhsKDim] : + llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) { + bBatchSizes[rhsKDim] = bounds[kDim]; + } + bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK; + + auto bLayout = createNestedLayout(context, bRank, bfk, bfn, + /*subgroupCount=*/bSubgroupSizes, + /*subgroupStrides=*/bSubgroupStrides, + /*batchCount=*/bBatchSizes, + getBSingleSubgroupLayout(mmaAttr)); + LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; }); + + std::tuple + result = {aLayout, bLayout, cLayout}; + return result; +} + static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule, SmallVector promotedOperands, RewriterBase &rewriter, @@ -89,7 +394,7 @@ static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule, contract.getIndexingMapsArray()); assert(succeeded(opInfo) && "contraction should have been inferred"); - auto layouts = schedule.getContractionLayout(opInfo.value(), contract); + auto layouts = getContractionLayout(schedule, opInfo.value(), contract); if (failed(layouts)) { return contract->emitError("cannot get concrete layout for contraction"); } @@ -176,7 +481,7 @@ static LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule, assert(succeeded(opInfo) && "unit filter dim convolution should have been infered"); - auto layouts = schedule.getContractionLayout(opInfo.value(), conv); + auto layouts = getContractionLayout(schedule, opInfo.value(), conv); if (failed(layouts)) { return conv->emitError("cannot get concrete layout for convolution"); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 3dd0c128008e..a2c2187f3f14 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -11,6 +11,7 @@ #include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h" #include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" @@ -1599,6 +1600,219 @@ createReadLayout(MLIRContext *ctx, const VectorExt::LayoutAttr &layout) { return VectorExt::LayoutAttr::get(ctx, perDimLayouts); } +// Struct containing concrete MMA shape, type, and layout information. +struct ConcreteMmaLayout { + GPU::OpaqueMmaLayout base; + VectorExt::PerDimLayoutAttr aMLayout; + VectorExt::PerDimLayoutAttr aKLayout; + VectorExt::PerDimLayoutAttr bKLayout; + VectorExt::PerDimLayoutAttr bNLayout; + VectorExt::PerDimLayoutAttr cMLayout; + VectorExt::PerDimLayoutAttr cNLayout; +}; + +static std::tuple +getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) { + // Step 1: obtain the swizzled tile shape, but keeping track of the source + // dimension indices. 
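As an aside, the bookkeeping in Step 1 above is easier to see on a toy example. The following standalone sketch uses simplified stand-ins for TileSwizzle::Dim and a hypothetical expand-shape/permutation (it is not the swizzle of any particular intrinsic); it shows how each swizzled dimension remembers its source dimension so that sizes (and, in the real code, labels) can be regrouped per source dim afterwards.

// Illustrative sketch with simplified stand-in types; not the real IREE
// TileSwizzle API.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

namespace {
enum class Kind { Internal, CrossThread };
struct Dim {
  Kind kind;
  int64_t size;
};
struct SrcIndexAndDim {
  size_t srcIndex;
  Dim dim;
};
} // namespace

int main() {
  // Hypothetical swizzle: source dim 0 expands to one cross-thread dim of 16;
  // source dim 1 expands to a cross-thread dim of 4 and an internal dim of 4.
  std::vector<std::vector<Dim>> expandShape = {
      {{Kind::CrossThread, 16}},
      {{Kind::CrossThread, 4}, {Kind::Internal, 4}}};
  std::vector<int64_t> permutation = {1, 0, 2};

  // Flatten while remembering the source dim index (Step 1).
  std::vector<SrcIndexAndDim> flat;
  for (size_t i = 0; i < expandShape.size(); ++i)
    for (Dim d : expandShape[i])
      flat.push_back({i, d});

  // Apply the permutation, analogous to applyPermutationToVector.
  std::vector<SrcIndexAndDim> swizzled(flat.size());
  for (size_t i = 0; i < permutation.size(); ++i)
    swizzled[i] = flat[permutation[i]];

  // Regroup sizes by source dim, in swizzled order: these are the shapes that
  // end up in the two per-dim layouts.
  std::vector<int64_t> shape[2];
  for (const SrcIndexAndDim &d : swizzled)
    shape[d.srcIndex].push_back(d.dim.size);

  // Prints "dim0: 16" and "dim1: 4 4".
  for (int i = 0; i < 2; ++i) {
    std::printf("dim%d:", i);
    for (int64_t s : shape[i])
      std::printf(" %lld", (long long)s);
    std::printf("\n");
  }
  return 0;
}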
+  struct SrcIndexAndSwizzleDim {
+    size_t srcIndex;
+    TileSwizzle::Dim dim;
+  };
+  SmallVector<SrcIndexAndSwizzleDim> swizzledShape;
+  for (auto [i, e] : llvm::enumerate(swizzle.expandShape)) {
+    for (TileSwizzle::Dim d : e) {
+      swizzledShape.push_back(SrcIndexAndSwizzleDim{i, d});
+    }
+  }
+  applyPermutationToVector(swizzledShape, swizzle.permutation);
+
+  // Step 2: collect the appropriate labels to use for the swizzled dims.
+  VectorExt::LayoutDimension internalLabels[] = {
+      VectorExt::LayoutDimension::VECTORZ, VectorExt::LayoutDimension::VECTORY,
+      VectorExt::LayoutDimension::VECTORX};
+  VectorExt::LayoutDimension crossThreadLabels[] = {
+      VectorExt::LayoutDimension::LANEZ, VectorExt::LayoutDimension::LANEY,
+      VectorExt::LayoutDimension::LANEX};
+  auto internalLabelIter = std::end(internalLabels);
+  auto crossThreadLabelIter = std::end(crossThreadLabels);
+  for (SrcIndexAndSwizzleDim d : swizzledShape) {
+    if (d.dim.kind == TileSwizzle::Dim::Kind::Internal) {
+      assert(internalLabelIter != std::begin(internalLabels));
+      --internalLabelIter;
+    } else if (d.dim.kind == TileSwizzle::Dim::Kind::CrossThread) {
+      assert(crossThreadLabelIter != std::begin(crossThreadLabels));
+      --crossThreadLabelIter;
+    } else {
+      assert(false && "unexpected dimension kind in intrinsic swizzle");
+    }
+  }
+
+  // Step 3: put together the resulting PerDimLayoutAttrs for the two source
+  // dims.
+  SmallVector<VectorExt::LayoutDimensionAttr> labels[2];
+  SmallVector<int64_t> shape[2];
+  for (SrcIndexAndSwizzleDim d : swizzledShape) {
+    shape[d.srcIndex].push_back(d.dim.size);
+    auto &labelIterRef = (d.dim.kind == TileSwizzle::Dim::Kind::Internal)
+                             ? internalLabelIter
+                             : crossThreadLabelIter;
+    labels[d.srcIndex].push_back(VectorExt::LayoutDimensionAttr::get(
+        context, static_cast<VectorExt::LayoutDimension>(*labelIterRef++)));
+  }
+  return {VectorExt::PerDimLayoutAttr::get(context, labels[0], shape[0]),
+          VectorExt::PerDimLayoutAttr::get(context, labels[1], shape[1])};
+};
+
+static ConcreteMmaLayout getConcreteMMALayout(MLIRContext *context,
+                                              GPU::MMAIntrinsic intrinsic) {
+  auto opaque = GPU::getOpaqueMMALayout(context, intrinsic);
+  ConcreteMmaLayout concreteLayout;
+  concreteLayout.base = opaque;
+  auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, GPU::MMAFragment::Lhs);
+  auto rhsSwizzle = getIntrinsicSwizzle(intrinsic, GPU::MMAFragment::Rhs);
+  auto accSwizzle = getIntrinsicSwizzle(intrinsic, GPU::MMAFragment::Acc);
+  std::tie(concreteLayout.aMLayout, concreteLayout.aKLayout) =
+      getPerDimLayoutAttrs(context, lhsSwizzle);
+  std::tie(concreteLayout.bNLayout, concreteLayout.bKLayout) =
+      getPerDimLayoutAttrs(context, rhsSwizzle);
+  std::tie(concreteLayout.cMLayout, concreteLayout.cNLayout) =
+      getPerDimLayoutAttrs(context, accSwizzle);
+  return concreteLayout;
+}
+
+static VectorExt::PerDimLayoutAttr
+getBatchedPerDimLayoutAttr(VectorExt::LayoutDimensionAttr batchDim,
+                           VectorExt::PerDimLayoutAttr baseLayout,
+                           int64_t problemSize, int64_t fragmentDimSize) {
+  assert(problemSize % fragmentDimSize == 0 &&
+         "invalid layout fragment for problem size");
+
+  SmallVector<VectorExt::LayoutDimensionAttr> dimAttrs(
+      baseLayout.getLabels());
+  dimAttrs.insert(dimAttrs.begin(), batchDim);
+
+  SmallVector<int64_t> shapes(baseLayout.getShapes());
+  shapes.insert(shapes.begin(), problemSize / fragmentDimSize);
+  auto layout = VectorExt::PerDimLayoutAttr::get(baseLayout.getContext(),
+                                                 dimAttrs, shapes);
+  return layout;
+}
+
+// Get the batched layout attributes for the given fragment layouts, indexing
+// map, and problem shape. The canonical fragment map is used to compare against
+// the problem map |indexingMap|.
For example, for mma fragment B (RHS): +// +// indexingMap = affine_map<(d0, d1, d2) -> (d1, d2) # Transposed B +// fragmentMap = affine_map<(d0, d1, d2) -> (d2, d1) +// problemShape = [32, 64] +// fragmentSize = [16, 8] +// fragmentLayouts = [kLayout, nLayout] +// +// Gives batched layout +// +// Dim0 Layout = [BATCHX, nLayoutLabels], [8, nLayoutShape] +// Dim1 Layout = [BATCHY, kLayoutLabels], [2, kLayoutShape] +static VectorExt::LayoutAttr +getBatchedLayoutAttr(AffineMap indexingMap, AffineMap fragmentMap, + ArrayRef problemShape, + ArrayRef fragmentSize, + ArrayRef fragmentLayouts) { + // Current distribution to MFMA operations does not support batched + // contractions so that is reflected here. + assert(indexingMap.getNumResults() == 2 && + "invalid indexing map to non-batched simple contraction"); + + VectorExt::LayoutDimensionAttr batchX = VectorExt::LayoutDimensionAttr::get( + indexingMap.getContext(), VectorExt::LayoutDimension::BATCHX); + VectorExt::LayoutDimensionAttr batchY = VectorExt::LayoutDimensionAttr::get( + indexingMap.getContext(), VectorExt::LayoutDimension::BATCHY); + + SmallVector perDimAttrs; + for (auto [expr, batchType] : llvm::zip_equal( + indexingMap.getResults(), + SmallVector{batchX, batchY})) { + auto maybeResultPosition = fragmentMap.getResultPosition(expr); + assert(maybeResultPosition && "fragment map and problem map mismatch"); + int64_t idx = *maybeResultPosition; + perDimAttrs.push_back(getBatchedPerDimLayoutAttr( + batchType, fragmentLayouts[idx], problemShape[idx], fragmentSize[idx])); + } + + return VectorExt::LayoutAttr::get(indexingMap.getContext(), perDimAttrs); +} + +static FailureOr> +getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) { + MLIRContext *context = contract.getContext(); + FailureOr maybeContractionDims = + linalg::inferContractionDims(contract.getIndexingMapsArray()); + if (failed(maybeContractionDims)) { + return failure(); + } + auto contractionDims = *maybeContractionDims; + // TODO: Relax this condition to strictly alignment requirements. + if (contractionDims.k.size() != 1 || contractionDims.m.size() != 1 || + contractionDims.n.size() != 1) { + return failure(); + } + // TODO: Support batched contractions. + if (contractionDims.batch.size() > 0) { + return failure(); + } + unsigned mDim = contractionDims.m[0]; + unsigned nDim = contractionDims.n[0]; + unsigned kDim = contractionDims.k[0]; + + SmallVector iterationBounds; + contract.getIterationBounds(iterationBounds); + + int64_t problemMSize = iterationBounds[mDim]; + int64_t problemNSize = iterationBounds[nDim]; + int64_t problemKSize = iterationBounds[kDim]; + + int64_t mSize = layout.base.mSize; + int64_t nSize = layout.base.nSize; + int64_t kSize = layout.base.kSize; + + // The problem size currently must be strictly aligned to the size of the mma. + // This is expected to succeed assuming the correct [masked] vector size was + // set at strategy configuration time (for this mma). 
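The batch count that getBatchedPerDimLayoutAttr prepends is simply problemSize / fragmentDimSize, which is why the alignment check just below requires exact multiples. A minimal standalone sketch (not part of the patch), reusing the numbers from the fragment-B example in the comment above:

// Illustrative sketch only; not part of the patch.
#include <cassert>
#include <cstdint>

static int64_t batchCount(int64_t problemSize, int64_t fragmentDimSize) {
  assert(problemSize % fragmentDimSize == 0 &&
         "invalid layout fragment for problem size");
  return problemSize / fragmentDimSize;
}

int main() {
  // B (RHS) fragment example: problemShape = [32, 64], fragmentSize = [16, 8].
  // K dim: 32 / 16 = 2 batches (BATCHY); N dim: 64 / 8 = 8 batches (BATCHX).
  assert(batchCount(32, 16) == 2);
  assert(batchCount(64, 8) == 8);
  return 0;
}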
+ if (problemMSize % mSize != 0 || problemNSize % nSize || + problemKSize % kSize) { + return failure(); + } + + VectorExt::LayoutAttr aLayout = getBatchedLayoutAttr( + contract.getIndexingMapsArray()[0], + AffineMap::getMultiDimMapWithTargets(3, {mDim, kDim}, context), + {problemMSize, problemKSize}, {mSize, kSize}, + {layout.aMLayout, layout.aKLayout}); + VectorExt::LayoutAttr bLayout = getBatchedLayoutAttr( + contract.getIndexingMapsArray()[1], + AffineMap::getMultiDimMapWithTargets(3, {kDim, nDim}, context), + {problemKSize, problemNSize}, {kSize, nSize}, + {layout.bKLayout, layout.bNLayout}); + VectorExt::LayoutAttr cLayout = getBatchedLayoutAttr( + contract.getIndexingMapsArray()[2], + AffineMap::getMultiDimMapWithTargets(3, {mDim, nDim}, context), + {problemMSize, problemNSize}, {mSize, nSize}, + {layout.cMLayout, layout.cNLayout}); + + return std::make_tuple(aLayout, bLayout, cLayout); +} + +FailureOr> static getContractionLayout(GPU::MMAAttr mma, + vector::ContractionOp + contract) { + ConcreteMmaLayout layout = getConcreteMMALayout( + contract->getContext(), mma.getIntrinsic().getValue()); + return getContractionLayout(contract, layout); +} + DiagnosedSilenceableFailure transform_dialect::SetContractionLayoutAttributes::apply( transform::TransformRewriter &rewriter, @@ -1609,7 +1823,7 @@ transform_dialect::SetContractionLayoutAttributes::apply( return emitDefiniteFailure() << "invalid more than one attribute for contraction annotation"; } - auto mmaType = llvm::dyn_cast(typeList.front()); + auto mmaType = llvm::dyn_cast(typeList.front()); if (!mmaType) { return emitDefiniteFailure() << "invalid non-mma attribute for contraction annotation " @@ -1623,7 +1837,7 @@ transform_dialect::SetContractionLayoutAttributes::apply( << "invalid non-contraction annotation " << payload; } - auto maybeLayouts = mmaType.getContractionLayout(contract); + auto maybeLayouts = getContractionLayout(mmaType, contract); if (failed(maybeLayouts)) { return emitDefiniteFailure() << "invalid opaque mma layout for annotation " << mmaType;
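For reference, the canonical fragment maps built with AffineMap::getMultiDimMapWithTargets in getContractionLayout above look as follows for a plain (m, n, k) contraction. This is a standalone sketch that assumes an MLIR development setup; the dimension indices mDim = 0, nDim = 1, kDim = 2 are hypothetical and not taken from any particular op.

// Illustrative sketch only; not part of the patch.
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext context;
  unsigned mDim = 0, nDim = 1, kDim = 2;
  // A: (d0, d1, d2) -> (d0, d2), B: (d0, d1, d2) -> (d2, d1),
  // C: (d0, d1, d2) -> (d0, d1).
  auto aMap =
      mlir::AffineMap::getMultiDimMapWithTargets(3, {mDim, kDim}, &context);
  auto bMap =
      mlir::AffineMap::getMultiDimMapWithTargets(3, {kDim, nDim}, &context);
  auto cMap =
      mlir::AffineMap::getMultiDimMapWithTargets(3, {mDim, nDim}, &context);
  llvm::outs() << aMap << "\n" << bMap << "\n" << cMap << "\n";
  return 0;
}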