Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AMDGPU: Define v_mfma_f32_{16x16x128|32x32x64}_f8f6f4 instructions #116723

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion clang/include/clang/Basic/BuiltinsAMDGPU.def
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,11 @@ TARGET_BUILTIN(__builtin_amdgcn_cvt_sr_fp8_f32, "ifiiIi", "nc", "fp8-conversion-
//===----------------------------------------------------------------------===//
// GFX950 only builtins.
//===----------------------------------------------------------------------===//
TARGET_BUILTIN(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4, "V4fV8ZiV8ZiV4fIiIiIiiIii", "nc", "gfx950-insts")
TARGET_BUILTIN(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4, "V16fV8ZiV8ZiV16fIiIiIiiIii", "nc", "gfx950-insts")

TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_16x16x32_f16, "V4fV8hV8hV4fIiIiIi", "nc", "gfx950-insts")
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_32x32x16_f16, "V16fV8hV8hV16fIiIiIi", "nc", "gfx950-insts")

TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_32x32x16_bf16, "V16fV8yV8yV16fIiIiIi", "nc", "gfx950-insts")

//===----------------------------------------------------------------------===//
Expand Down
15 changes: 14 additions & 1 deletion clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19729,7 +19729,20 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
(uint64_t)0);
return Builder.CreateInsertElement(I0, A, 1);
}

case AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
case AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
llvm::FixedVectorType *VT = FixedVectorType::get(Builder.getInt32Ty(), 8);
Function *F = CGM.getIntrinsic(
BuiltinID == AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4
? Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4
: Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4,
{VT, VT});

SmallVector<Value *, 9> Args;
for (unsigned I = 0, N = E->getNumArgs(); I != N; ++I)
Args.push_back(EmitScalarExpr(E->getArg(I)));
return Builder.CreateCall(F, Args);
}
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
Expand Down
15 changes: 15 additions & 0 deletions clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ typedef half v16h __attribute__((ext_vector_type(16)));
typedef half v32h __attribute__((ext_vector_type(32)));
typedef int v2i __attribute__((ext_vector_type(2)));
typedef int v4i __attribute__((ext_vector_type(4)));
typedef int v8i __attribute__((ext_vector_type(8)));
typedef int v16i __attribute__((ext_vector_type(16)));
typedef int v32i __attribute__((ext_vector_type(32)));
typedef short v2s __attribute__((ext_vector_type(2)));
Expand Down Expand Up @@ -431,4 +432,18 @@ v16f test_mfma_f32_32x32x16_bf16(v8bf16 a, v8bf16 b, v16f c) {
return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 1, 2, 3);
}

// CHECK-GFX950-LABEL: @test_mfma_scale_f32_16x16x128_f8f6f4
// CHECK-GFX950: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <4 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
void test_mfma_scale_f32_16x16x128_f8f6f4(global v4f* out, v8i a, v8i b, v4f c, int scale_a, int scale_b)
{
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
}

// CHECK-GFX950-LABEL: @test_mfma_scale_f32_32x32x64_f8f6f4
// CHECK-GFX950: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <16 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
void test_mfma_scale_f32_32x32x64_f8f6f4(global v16f* out, v8i a, v8i b, v16f c, int scale_a, int scale_b)
{
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
}

#endif
15 changes: 15 additions & 0 deletions clang/test/SemaOpenCL/builtins-amdgcn-error-gfx950-param.cl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ typedef float float4 __attribute__((ext_vector_type(4)));
typedef float float16 __attribute__((ext_vector_type(16)));
typedef half half8 __attribute__((ext_vector_type(8)));
typedef __bf16 bfloat8 __attribute__((ext_vector_type(8)));
typedef int int8 __attribute__((ext_vector_type(8)));


void test_mfma_f32_16x16x32_f16(__global float4* out, half8 a, half8 b, float4 c, int X) {
Expand All @@ -26,3 +27,17 @@ void test_mfma_f32_32x32x16_bf16(__global float16* out, bfloat8 a, bfloat8 b, fl
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 0, X, 0); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 0, 0, X); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
}

void test_mfma_scale_f32_16x16x128_f8f6f4(__global float4* out, int8 a, int8 b, float4 c, int X, int Y) {
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, X, 0, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, X, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, 0, X, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, 0, 0, Y, X, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
}

void test_mfma_scale_f32_32x32x64_f8f6f4(__global float16* out, int8 a, int8 b, float16 c, int X, int Y) {
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, X, 0, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, X, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, 0, X, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, 0, 0, Y, X, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
}
23 changes: 22 additions & 1 deletion clang/test/SemaOpenCL/builtins-amdgcn-error-gfx950.cl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,33 @@
typedef float float4 __attribute__((ext_vector_type(4)));
typedef float float16 __attribute__((ext_vector_type(16)));
typedef half half8 __attribute__((ext_vector_type(8)));
typedef half half16 __attribute__((ext_vector_type(16)));
typedef __bf16 bfloat8 __attribute__((ext_vector_type(8)));
typedef __bf16 bfloat16 __attribute__((ext_vector_type(16)));
typedef unsigned int uint2 __attribute__((ext_vector_type(2)));
typedef int int4 __attribute__((ext_vector_type(4)));
typedef int int8 __attribute__((ext_vector_type(8)));
typedef int int16 __attribute__((ext_vector_type(16)));

void test(__global float4* out0, half8 a0, half8 b0, float4 c0,
__global float16* out1, half8 a1, half8 b1, float16 c1,
__global float16* out2, bfloat8 a2, bfloat8 b2, float16 c2) {
__global float16* out2, bfloat8 a2, bfloat8 b2, float16 c2,
__global int4* out3, int4 a3, int4 b3, int4 c3,
__global int16* out4, int4 a4, int4 b4, int16 c4,
__global float4* out5, bfloat8 a5, bfloat8 b5, float4 c5,
__global float4* out6, half8 a6, half16 b6, float4 c6,
__global float16* out7, half8 a7, half16 b7, float16 c7,
__global float4* out8, bfloat8 a8, bfloat16 b8, float4 c8,
__global float16* out9, bfloat8 a9, bfloat16 b9, float16 c9,
__global int4* out10, int4 a10, int8 b10, int4 c10,
__global int16* out11, int4 a11, int8 b11, int16 c11,
__global float4* out12, int4 a12, int8 b12, float4 c12,
__global float16* out13, int4 a13, int8 b13, float16 c13,
__global float4* out14, int8 a14, int8 b14, float4 c14, int d14, int e14,
__global float16* out15, int8 a15, int8 b15, float16 c15, int d15, int e15) {
*out0 = __builtin_amdgcn_mfma_f32_16x16x32_f16(a0, b0, c0, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_16x16x32_f16' needs target feature gfx950-insts}}
*out1 = __builtin_amdgcn_mfma_f32_32x32x16_f16(a1, b1, c1, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_32x32x16_f16' needs target feature gfx950-insts}}
*out2 = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a2, b2, c2, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_32x32x16_bf16' needs target feature gfx950-insts}}
*out14 = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a14, b14, c14, 0, 0, 0, d14, 0, e14); // expected-error{{'__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' needs target feature gfx950-insts}}
*out15 = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a15, b15, c15, 0, 0, 0, d15, 0, e15); // expected-error{{'__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' needs target feature gfx950-insts}}
}
10 changes: 10 additions & 0 deletions llvm/docs/AMDGPUUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,16 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
used by hardware to control active lanes when used in EXEC register.
For example, ballot(i1 true) return EXEC mask.

llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4 Emit `v_mfma_scale_f32_16x16x128_f8f6f4` to set the scale factor. The
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reminds me that we probably didn't add other gfx950 intrinsics to the document.

last 4 operands correspond to the scale inputs.

- 2-bit byte index to use for each lane for matrix A
- Matrix A scale values
- 2-bit byte index to use for each lane for matrix B
- Matrix B scale values

llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4 Emit `v_mfma_scale_f32_32x32x64_f8f6f4`

============================================== ==========================================================

.. TODO::
Expand Down
31 changes: 31 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsAMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -2968,6 +2968,35 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
[IntrConvergent, IntrNoMem,
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]>;


// srcA's format is determined by cbsz. srcB's format is determined by
// blgp.
//
// These should be <8 x i32> for f8 formats, <6 x i32> for f6 formats,
// and <4 x i32> for f4 formats. If the format control bits imply a
// smaller type than used, the high elements will be truncated.
//
// If the format control bits imply a larger type than used, the high
// elements are padded with undef.

class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy> :
DefaultAttrsIntrinsic<[DestTy],
[llvm_anyvector_ty, llvm_anyvector_ty, DestTy,
llvm_i32_ty, // cbsz
llvm_i32_ty, // blgp
// llvm_i1_ty, // TODO: neg_src2
// llvm_i1_ty, // TODO: abs_src2
// llvm_i1_ty, // TODO: clamp
llvm_i32_ty, // op_sel (A matrix scale, 2-bits) // TODO: Make i2?
llvm_i32_ty, // v_mfma_ld_scale_b32 src0 (A matrix scale)
llvm_i32_ty, // op_sel (B matrix scale, 2-bits) // TODO: Make i2?
llvm_i32_ty // v_mfma_ld_scale_b32 src1 (B matrix scale)
],
[IntrConvergent, IntrNoMem,
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>,
ImmArg<ArgIndex<5>>, ImmArg<ArgIndex<7>>
]>;

defset list<Intrinsic> AMDGPUMFMAIntrinsics908 = {
def int_amdgcn_mfma_f32_32x32x1f32 : AMDGPUMfmaIntrinsic<llvm_v32f32_ty, llvm_float_ty>;
def int_amdgcn_mfma_f32_16x16x1f32 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_float_ty>;
Expand Down Expand Up @@ -3119,6 +3148,8 @@ def int_amdgcn_mfma_f32_16x16x32_f16 : AMDGPUMfmaIntrinsic<llvm_v4f32_ty, llvm_v
def int_amdgcn_mfma_f32_32x32x16_f16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8f16_ty>;

def int_amdgcn_mfma_f32_32x32x16_bf16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8bf16_ty>;
def int_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v4f32_ty>;
def int_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v16f32_ty>;
}

//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUGISel.td
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,6 @@ def gi_fp_pow2_to_exponent : GICustomOperandRenderer<"renderFPPow2ToExponent">,

def gi_as_hw_round_mode : GICustomOperandRenderer<"renderRoundMode">,
GISDNodeXFormEquiv<as_hw_round_mode>;

def gi_MFMALdScaleModifierOp : GICustomOperandRenderer<"renderScaledMAIIntrinsicOperand">,
GISDNodeXFormEquiv<MFMALdScaleXForm>;
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,7 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
if (isa<UndefValue>(Src)) {
return IC.replaceInstUsesWith(II, Src);
}
return std::nullopt;
}
}
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5742,6 +5742,18 @@ void AMDGPUInstructionSelector::renderRoundMode(MachineInstrBuilder &MIB,
MIB.addImm((MI.getOperand(OpIdx).getImm() + 3) % 4);
}

/// Convert from 2-bit value to enum values used for op_sel* source modifiers.
void AMDGPUInstructionSelector::renderScaledMAIIntrinsicOperand(
MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const {
unsigned Val = MI.getOperand(OpIdx).getImm();
unsigned New = 0;
if (Val & 0x1)
New |= SISrcMods::OP_SEL_0;
if (Val & 0x2)
New |= SISrcMods::OP_SEL_1;
MIB.addImm(New);
}

bool AMDGPUInstructionSelector::isInlineImmediate(const APInt &Imm) const {
return TII.isInlineConstant(Imm);
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ class AMDGPUInstructionSelector final : public InstructionSelector {

void renderRoundMode(MachineInstrBuilder &MIB, const MachineInstr &MI,
int OpIdx) const;
void renderScaledMAIIntrinsicOperand(MachineInstrBuilder &MIB,
const MachineInstr &MI, int OpIdx) const;

bool isInlineImmediate(const APInt &Imm) const;
bool isInlineImmediate(const APFloat &Imm) const;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/AMDGPUInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class AMDGPUInst <dag outs, dag ins, string asm = "",
// instructions to not match without killing the whole decode process. It is
// mainly used for ARM, but Tablegen expects this field to exist or it fails
// to build the decode table.
field bits<96> SoftFail = 0;
field bits<128> SoftFail = 0; // FIXME: If this is smaller than largest instruction, DecodeEmitter crashes

let DecoderNamespace = Namespace;

Expand Down
19 changes: 19 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4769,6 +4769,25 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
break;
}
case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
OpdsMapping[0] =
Info->mayNeedAGPRs()
? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
: getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);

OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
OpdsMapping[4] =
Info->mayNeedAGPRs()
? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);

OpdsMapping[8] = getVGPROpMapping(MI.getOperand(8).getReg(), MRI, *TRI);
OpdsMapping[10] = getVGPROpMapping(MI.getOperand(10).getReg(), MRI, *TRI);
break;
}
case Intrinsic::amdgcn_smfmac_f32_16x16x32_f16:
case Intrinsic::amdgcn_smfmac_f32_32x32x16_f16:
case Intrinsic::amdgcn_smfmac_f32_16x16x32_bf16:
Expand Down
75 changes: 75 additions & 0 deletions llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,17 @@ static inline DecoderUInt128 eat12Bytes(ArrayRef<uint8_t> &Bytes) {
return DecoderUInt128(Lo, Hi);
}

static inline DecoderUInt128 eat16Bytes(ArrayRef<uint8_t> &Bytes) {
assert(Bytes.size() >= 16);
uint64_t Lo =
support::endian::read<uint64_t, llvm::endianness::little>(Bytes.data());
Bytes = Bytes.slice(8);
uint64_t Hi =
support::endian::read<uint64_t, llvm::endianness::little>(Bytes.data());
Bytes = Bytes.slice(8);
return DecoderUInt128(Lo, Hi);
}

DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
ArrayRef<uint8_t> Bytes_,
uint64_t Address,
Expand Down Expand Up @@ -548,6 +559,15 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,

// Reinitialize Bytes
Bytes = Bytes_.slice(0, MaxInstBytesNum);

} else if (Bytes.size() >= 16 &&
STI.hasFeature(AMDGPU::FeatureGFX950Insts)) {
DecoderUInt128 DecW = eat16Bytes(Bytes);
if (tryDecodeInst(DecoderTableGFX940128, MI, DecW, Address, CS))
break;

// Reinitialize Bytes
Bytes = Bytes_.slice(0, MaxInstBytesNum);
}

if (Bytes.size() >= 8) {
Expand Down Expand Up @@ -759,6 +779,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::SDWA)
convertSDWAInst(MI);

if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI)
convertMAIInst(MI);

int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
AMDGPU::OpName::vdst_in);
if (VDstIn_Idx != -1) {
Expand Down Expand Up @@ -837,6 +860,58 @@ void AMDGPUDisassembler::convertSDWAInst(MCInst &MI) const {
}
}

/// Adjust the register values used by V_MFMA_F8F6F4_f8_f8 instructions to the
/// appropriate subregister for the used format width.
static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI,
MCOperand &MO, uint8_t NumRegs) {
switch (NumRegs) {
case 4:
return MO.setReg(MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3));
case 6:
return MO.setReg(
MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5));
case 8:
// No-op in cases where one operand is still f8/bf8.
return;
default:
llvm_unreachable("Unexpected size for mfma f8f6f4 operand");
}
}

/// f8f6f4 instructions have different pseudos depending on the used formats. In
/// the disassembler table, we only have the variants with the largest register
/// classes which assume using an fp8/bf8 format for both operands. The actual
/// register class depends on the format in blgp and cbsz operands. Adjust the
/// register classes depending on the used format.
void AMDGPUDisassembler::convertMAIInst(MCInst &MI) const {
int BlgpIdx =
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::blgp);
if (BlgpIdx == -1)
return;

int CbszIdx =
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::cbsz);

unsigned CBSZ = MI.getOperand(CbszIdx).getImm();
unsigned BLGP = MI.getOperand(BlgpIdx).getImm();

const AMDGPU::MFMA_F8F6F4_Info *AdjustedRegClassOpcode =
AMDGPU::getMFMA_F8F6F4_WithFormatArgs(CBSZ, BLGP, MI.getOpcode());
if (!AdjustedRegClassOpcode ||
AdjustedRegClassOpcode->Opcode == MI.getOpcode())
return;

MI.setOpcode(AdjustedRegClassOpcode->Opcode);
int Src0Idx =
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src0);
int Src1Idx =
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src1);
adjustMFMA_F8F6F4OpRegClass(MRI, MI.getOperand(Src0Idx),
AdjustedRegClassOpcode->NumRegsSrcA);
adjustMFMA_F8F6F4OpRegClass(MRI, MI.getOperand(Src1Idx),
AdjustedRegClassOpcode->NumRegsSrcB);
}

struct VOPModifiers {
unsigned OpSel = 0;
unsigned OpSelHi = 0;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class AMDGPUDisassembler : public MCDisassembler {
void convertVINTERPInst(MCInst &MI) const;
void convertFMAanyK(MCInst &MI, int ImmLitIdx) const;
void convertSDWAInst(MCInst &MI) const;
void convertMAIInst(MCInst &MI) const;
void convertDPP8Inst(MCInst &MI) const;
void convertMIMGInst(MCInst &MI) const;
void convertVOP3DPPInst(MCInst &MI) const;
Expand Down
Loading
Loading