Skip to content

WIP: AMDGPU: Always select the VGPR version of MFMAs #145025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: users/arsenm/amdgpu-add-pass-rewrite-vgpr-mfma-to-agpr
Choose a base branch
from

Conversation

arsenm
Copy link
Contributor

@arsenm arsenm commented Jun 20, 2025

We do not want to use AGPRs unless absolutely required due
to register pressure. Rely on a post-regalloc pass to replace
VGPR MFMAs with the AGPR version if it avoids the copies introduced
due to live range splitting.

We do not want to use AGPRs unless absolutely required due
to register pressure. Rely on a post-regalloc pass to replace
VGPR MFMAs with the AGPR version if it avoids the copies introduced
due to live range splitting.
Copy link
Contributor Author

arsenm commented Jun 20, 2025

Warning

This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.
Learn more

This stack of pull requests is managed by Graphite. Learn more about stacking.

@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

Changes

We do not want to use AGPRs unless absolutely required due
to register pressure. Rely on a post-regalloc pass to replace
VGPR MFMAs with the AGPR version if it avoids the copies introduced
due to live range splitting.


Full diff: https://github.com/llvm/llvm-project/pull/145025.diff

5 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+4-6)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+1-19)
  • (modified) llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp (-6)
  • (modified) llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h (-6)
  • (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+30-25)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index dca55dafcc5e3..8331fe333e637 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4865,31 +4865,29 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
       // for srcA/srcB?
       //
       // vdst, srcA, srcB, srcC
-      const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
       OpdsMapping[0] =
-          Info->mayNeedAGPRs()
+          !Subtarget.hasGFX90AInsts()
               ? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
               : getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
       OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
       OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
       OpdsMapping[4] =
-          Info->mayNeedAGPRs()
+          !Subtarget.hasGFX90AInsts()
               ? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
               : getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
       break;
     }
     case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
     case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
-      const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
       OpdsMapping[0] =
-          Info->mayNeedAGPRs()
+          !Subtarget.hasGFX90AInsts()
               ? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
               : getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
 
       OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
       OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
       OpdsMapping[4] =
-          Info->mayNeedAGPRs()
+          !Subtarget.hasGFX90AInsts()
               ? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
               : getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
 
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 07d79d677104a..11c9adb3371d5 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -16076,7 +16076,6 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
 
   MachineFunction *MF = MI.getParent()->getParent();
   MachineRegisterInfo &MRI = MF->getRegInfo();
-  SIMachineFunctionInfo *Info = MF->getInfo<SIMachineFunctionInfo>();
 
   if (TII->isVOP3(MI.getOpcode())) {
     // Make sure constant bus requirements are respected.
@@ -16087,7 +16086,6 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
     // use between vgpr and agpr as agpr tuples tend to be big.
     if (!MI.getDesc().operands().empty()) {
       unsigned Opc = MI.getOpcode();
-      bool HasAGPRs = Info->mayNeedAGPRs();
       const SIRegisterInfo *TRI = Subtarget->getRegisterInfo();
       int16_t Src2Idx = AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src2);
       for (auto I :
@@ -16095,7 +16093,7 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
             AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src1), Src2Idx}) {
         if (I == -1)
           break;
-        if ((I == Src2Idx) && (HasAGPRs))
+        if (I == Src2Idx)
           break;
         MachineOperand &Op = MI.getOperand(I);
         if (!Op.isReg() || !Op.getReg().isVirtual())
@@ -16129,22 +16127,6 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
             TII->legalizeOpWithMove(MI, Src1Idx);
         }
       }
-
-      if (!HasAGPRs)
-        return;
-
-      // Resolve the rest of AV operands to AGPRs.
-      if (auto *Src2 = TII->getNamedOperand(MI, AMDGPU::OpName::src2)) {
-        if (Src2->isReg() && Src2->getReg().isVirtual()) {
-          auto *RC = TRI->getRegClassForReg(MRI, Src2->getReg());
-          if (TRI->isVectorSuperClass(RC)) {
-            auto *NewRC = TRI->getEquivalentAGPRClass(RC);
-            MRI.setRegClass(Src2->getReg(), NewRC);
-            if (Src2->isTied())
-              MRI.setRegClass(MI.getOperand(0).getReg(), NewRC);
-          }
-        }
-      }
     }
 
     return;
diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp
index 1673bfa152674..7a279d7bede7d 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp
@@ -63,12 +63,6 @@ SIMachineFunctionInfo::SIMachineFunctionInfo(const Function &F,
     PSInputAddr = AMDGPU::getInitialPSInputAddr(F);
   }
 
-  MayNeedAGPRs = ST.hasMAIInsts();
-  if (ST.hasGFX90AInsts() &&
-      ST.getMaxNumVGPRs(F) <= AMDGPU::VGPR_32RegClass.getNumRegs() &&
-      !mayUseAGPRs(F))
-    MayNeedAGPRs = false; // We will select all MAI with VGPR operands.
-
   if (AMDGPU::isChainCC(CC)) {
     // Chain functions don't receive an SP from their caller, but are free to
     // set one up. For now, we can use s32 to match what amdgpu_gfx functions
diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
index 0e7635a045588..b9157b9a8c7e6 100644
--- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
+++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
@@ -493,8 +493,6 @@ class SIMachineFunctionInfo final : public AMDGPUMachineFunction,
   // user arguments. This is an offset from the KernargSegmentPtr.
   bool ImplicitArgPtr : 1;
 
-  bool MayNeedAGPRs : 1;
-
   // The hard-wired high half of the address of the global information table
   // for AMDPAL OS type. 0xffffffff represents no hard-wired high half, since
   // current hardware only allows a 16 bit value.
@@ -1165,10 +1163,6 @@ class SIMachineFunctionInfo final : public AMDGPUMachineFunction,
 
   unsigned getMaxMemoryClusterDWords() const { return MaxMemoryClusterDWords; }
 
-  bool mayNeedAGPRs() const {
-    return MayNeedAGPRs;
-  }
-
   // \returns true if a function has a use of AGPRs via inline asm or
   // has a call which may use it.
   bool mayUseAGPRs(const Function &F) const;
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index e8db879ca5077..6b6b74234cfef 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -856,17 +856,11 @@ defvar MayNotNeedAGPRs_gisel = [{
   return !MF.getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs();
 }];
 
-class AgprMAIFrag<SDPatternOperator Op, bit HasAbid = true,
-                  bit Scaled = false> :
-  MAIFrag<Op, MayNeedAGPRs, HasAbid, Scaled> {
-  let GISelPredicateCode = MayNeedAGPRs_gisel;
-}
+class AgprMAIFrag<SDPatternOperator Op, bit HasAbid = true, bit Scaled = false>
+    : MAIFrag<Op, [{}], HasAbid, Scaled> {}
 
-class VgprMAIFrag<SDPatternOperator Op, bit HasAbid = true,
-                   bit Scaled = false> :
-  MAIFrag<Op, MayNotNeedAGPRs, HasAbid, Scaled> {
-  let GISelPredicateCode = MayNotNeedAGPRs_gisel;
-}
+class VgprMAIFrag<SDPatternOperator Op, bit HasAbid = true, bit Scaled = false>
+    : MAIFrag<Op, [{}], HasAbid, Scaled> {}
 
 let isAsCheapAsAMove = 1, isReMaterializable = 1 in {
   defm V_ACCVGPR_READ_B32  : VOP3Inst<"v_accvgpr_read_b32",  VOPProfileAccRead>;
@@ -917,10 +911,14 @@ multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag,
                          !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, AgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
                  MFMATable<0, "AGPR", NAME # "_e64">;
 
-      let OtherPredicates = [isGFX90APlus], Mnemonic = OpName in
-      def _vgprcd_e64 : MAIInst<OpName # "_vgprcd", !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"),
-                                !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, VgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
-                        MFMATable<0, "VGPR", NAME # "_vgprcd_e64", NAME # "_e64">;
+      let OtherPredicates = [isGFX90APlus], Mnemonic = OpName,
+          AddedComplexity = 10 in def _vgprcd_e64
+          : MAIInst<OpName#"_vgprcd",
+                    !cast<VOPProfileMAI>("VOPProfileMAI_"#P#"_VCD"),
+                    !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag,
+                        VgprMAIFrag<node, HasAbid, Scaled>),
+                    Scaled>,
+          MFMATable<0, "VGPR", NAME#"_vgprcd_e64", NAME#"_e64">;
     }
 
     if NoDstOverlap then {
@@ -931,16 +929,22 @@ multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag,
                                  !if(!eq(node, null_frag), null_frag, AgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
                          MFMATable<1, "AGPR", NAME # "_e64", NAME # "_mac_e64">;
 
-        let OtherPredicates = [isGFX90APlus] in
-        def _mac_vgprcd_e64 : MAIInst<OpName # "_mac_vgprcd", !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"),
-                                      !if(!eq(node, null_frag), null_frag, VgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
-                              MFMATable<1, "VGPR", NAME # "_vgprcd_e64", NAME # "_mac_e64">;
+        let OtherPredicates = [isGFX90APlus],
+            AddedComplexity = 10 in def _mac_vgprcd_e64
+            : MAIInst<OpName#"_mac_vgprcd",
+                      !cast<VOPProfileMAI>("VOPProfileMAI_"#P#"_VCD"),
+                      !if(!eq(node, null_frag), null_frag,
+                          VgprMAIFrag<node, HasAbid, Scaled>),
+                      Scaled>,
+            MFMATable<1, "VGPR", NAME#"_vgprcd_e64", NAME#"_mac_e64">;
       }
     }
   } // End isConvergent = 1, mayRaiseFPException = 0, ReadsModeReg = 1
 }
 
-// Provide a wrapper around MAIInst that provides the appended operands from V_MFMA_LD_SCALE_B32
+// Provide a wrapper around MAIInst that provides the appended operands from
+// V_MFMA_LD_SCALE_B32. AGPR variants are never selected; VGPR is selected and
+// may later be rewritten to AGPR.
 multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOperator node> {
   defvar VariantSuffix = !subst(!toupper(OpName), "", NAME); // Drop the main opcode name prefix to get the "_fN_fM" suffix.
   defvar UnscaledOpName = UnscaledOpName_#VariantSuffix;
@@ -949,9 +953,9 @@ multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOper
 
   defvar NoDstOverlap = !cast<VOPProfileMAI>(!cast<MAIInst>(UnscaledOpName#"_e64").Pfl).NoDstOverlap;
 
-  def _e64 : ScaledMAIInst<OpName,
-        !cast<MAIInst>(UnscaledOpName#"_e64"), !if(NoDstOverlap, null_frag, AgprMAIFrag<node, HasAbid, true>)>,
-      MFMATable<0, "AGPR", NAME # "_e64">;
+  def _e64
+      : ScaledMAIInst<OpName, !cast<MAIInst>(UnscaledOpName#"_e64"), null_frag>,
+        MFMATable<0, "AGPR", NAME#"_e64">;
 
   def _vgprcd_e64 : ScaledMAIInst<OpName # "_vgprcd",
           !cast<MAIInst>(UnscaledOpName#"_vgprcd_e64"), !if(NoDstOverlap, null_frag, VgprMAIFrag<node, HasAbid, true>)>,
@@ -961,9 +965,10 @@ multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOper
    let Constraints = !if(NoDstOverlap, "$vdst = $src2", ""),
        isConvertibleToThreeAddress = NoDstOverlap,
        Mnemonic = UnscaledOpName_ in {
-     def _mac_e64 : ScaledMAIInst<OpName # "_mac",
-          !cast<MAIInst>(UnscaledOpName # "_mac_e64"), AgprMAIFrag<node, HasAbid, true>>,
-        MFMATable<1, "AGPR", NAME # "_e64">;
+     def _mac_e64
+         : ScaledMAIInst<OpName#"_mac",
+                         !cast<MAIInst>(UnscaledOpName#"_mac_e64"), null_frag>,
+           MFMATable<1, "AGPR", NAME#"_e64">;
 
      def _mac_vgprcd_e64 : ScaledMAIInst<OpName # " _mac_vgprcd",
           !cast<MAIInst>(UnscaledOpName # "_mac_vgprcd_e64"), VgprMAIFrag<node, HasAbid, true>>,

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants