[InstCombine] Simplify nonnull pointers #128111

Open
wants to merge 6 commits into
base: main
@dtcxzyw dtcxzyw commented Feb 21, 2025

This patch is a follow-up to #127979. It introduces a helper, simplifyNonNullOperand, to avoid duplicated logic. It also addresses the one-use issue in visitLoadInst, as discussed in #127979 (comment).
The nonnull attribute is now supported as well. Proof: https://alive2.llvm.org/ce/z/MCKgT9

Based on #128107.
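
For readers without an LLVM checkout, the core of the helper can be sketched on a toy value model. This is only an illustration under stated assumptions: `Kind` and the `Value` struct below are hypothetical stand-ins, not the real `llvm::Value` hierarchy, and the sketch omits the caller-side guards that the patch keeps (`NullPointerIsDefined` for loads and stores, the nonnull attribute for calls and returns).

```cpp
#include <cassert>

// Hypothetical toy IR: NOT the real llvm::Value / llvm::SelectInst classes.
enum class Kind { Null, Pointer, Select };

struct Value {
  Kind kind;
  const Value *trueVal;   // meaningful only when kind == Kind::Select
  const Value *falseVal;  // meaningful only when kind == Kind::Select
};

// Same shape as simplifyNonNullOperand in the patch: when the user requires a
// non-null pointer and one arm of a select is null, only the other arm can be
// observed. Returns nullptr when no simplification applies.
const Value *simplifyNonNullOperand(const Value *v) {
  if (v->kind == Kind::Select) {
    if (v->trueVal->kind == Kind::Null)
      return v->falseVal;
    if (v->falseVal->kind == Kind::Null)
      return v->trueVal;
  }
  return nullptr;
}
```

In this model, a select with a null arm on either side simplifies to the other arm, and any non-select value is left alone (the helper returns `nullptr`, which the callers in the patch treat as "no change").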

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-llvm-ir

Author: Yingwei Zheng (dtcxzyw)


Full diff: https://github.com/llvm/llvm-project/pull/128111.diff

9 Files Affected:

  • (modified) llvm/include/llvm/IR/Function.h (+5)
  • (modified) llvm/lib/IR/Function.cpp (+11)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+14-4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp (+21-25)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+9-1)
  • (modified) llvm/test/Transforms/InstCombine/nonnull-select.ll (+9-20)
  • (modified) llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll (+2-5)
  • (added) llvm/test/Transforms/PhaseOrdering/memset-combine.ll (+20)
diff --git a/llvm/include/llvm/IR/Function.h b/llvm/include/llvm/IR/Function.h
index 29041688124bc..7ea8673bedad1 100644
--- a/llvm/include/llvm/IR/Function.h
+++ b/llvm/include/llvm/IR/Function.h
@@ -731,6 +731,11 @@ class LLVM_ABI Function : public GlobalObject, public ilist_node<Function> {
   /// create a Function) from the Function Src to this one.
   void copyAttributesFrom(const Function *Src);
 
+  /// Return true if the return value is known to be not null.
+  /// This may be because it has the nonnull attribute, or because at least
+  /// one byte is dereferenceable and the pointer is in addrspace(0).
+  bool isReturnNonNull() const;
+
   /// deleteBody - This method deletes the body of the function, and converts
   /// the linkage to external.
   ///
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 5666f0a53866f..d22cf65769e26 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -873,6 +873,17 @@ void Function::copyAttributesFrom(const Function *Src) {
     setPrologueData(Src->getPrologueData());
 }
 
+bool Function::isReturnNonNull() const {
+  if (hasRetAttribute(Attribute::NonNull))
+    return true;
+
+  if (AttributeSets.getRetDereferenceableBytes() > 0 &&
+      !NullPointerIsDefined(this, getReturnType()->getPointerAddressSpace()))
+    return true;
+
+  return false;
+}
+
 MemoryEffects Function::getMemoryEffects() const {
   return getAttributes().getMemoryEffects();
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 400ebcf493713..c8b3d29c3aa98 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3993,10 +3993,20 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
   unsigned ArgNo = 0;
 
   for (Value *V : Call.args()) {
-    if (V->getType()->isPointerTy() &&
-        !Call.paramHasAttr(ArgNo, Attribute::NonNull) &&
-        isKnownNonZero(V, getSimplifyQuery().getWithInstruction(&Call)))
-      ArgNos.push_back(ArgNo);
+    if (V->getType()->isPointerTy()) {
+      // Simplify the nonnull operand before nonnull inference to avoid
+      // unnecessary queries.
+      if (Call.paramHasNonNullAttr(ArgNo, /*AllowUndefOrPoison=*/true)) {
+        if (Value *Res = simplifyNonNullOperand(V)) {
+          replaceOperand(Call, ArgNo, Res);
+          Changed = true;
+        }
+      }
+
+      if (!Call.paramHasAttr(ArgNo, Attribute::NonNull) &&
+          isKnownNonZero(V, getSimplifyQuery().getWithInstruction(&Call)))
+        ArgNos.push_back(ArgNo);
+    }
     ArgNo++;
   }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 83e1da98deeda..71c80d4c401f8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -455,6 +455,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   Instruction *hoistFNegAboveFMulFDiv(Value *FNegOp, Instruction &FMFSource);
 
+  /// Simplify \p V given that it is known to be non-null.
+  /// Returns the simplified value if possible, otherwise returns nullptr.
+  Value *simplifyNonNullOperand(Value *V);
+
 public:
   /// Create and insert the idiom we use to indicate a block is unreachable
   /// without having to rewrite the CFG from within InstCombine.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index d5534c15cca76..89fc1051b18dc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -982,6 +982,19 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) {
   return false;
 }
 
+/// TODO: Recursively simplify nonnull value to handle one-use inbounds GEPs.
+Value *InstCombinerImpl::simplifyNonNullOperand(Value *V) {
+  if (auto *Sel = dyn_cast<SelectInst>(V)) {
+    if (isa<ConstantPointerNull>(Sel->getOperand(1)))
+      return Sel->getOperand(2);
+
+    if (isa<ConstantPointerNull>(Sel->getOperand(2)))
+      return Sel->getOperand(1);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
   Value *Op = LI.getOperand(0);
   if (Value *Res = simplifyLoadInst(&LI, Op, SQ.getWithInstruction(&LI)))
@@ -1059,20 +1072,13 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
         V2->copyMetadata(LI, Metadata::PoisonGeneratingIDs);
         return SelectInst::Create(SI->getCondition(), V1, V2);
       }
-
-      // load (select (cond, null, P)) -> load P
-      if (isa<ConstantPointerNull>(SI->getOperand(1)) &&
-          !NullPointerIsDefined(SI->getFunction(),
-                                LI.getPointerAddressSpace()))
-        return replaceOperand(LI, 0, SI->getOperand(2));
-
-      // load (select (cond, P, null)) -> load P
-      if (isa<ConstantPointerNull>(SI->getOperand(2)) &&
-          !NullPointerIsDefined(SI->getFunction(),
-                                LI.getPointerAddressSpace()))
-        return replaceOperand(LI, 0, SI->getOperand(1));
     }
   }
+
+  if (!NullPointerIsDefined(LI.getFunction(), LI.getPointerAddressSpace()))
+    if (Value *V = simplifyNonNullOperand(Op))
+      return replaceOperand(LI, 0, V);
+
   return nullptr;
 }
 
@@ -1437,19 +1443,9 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
   if (isa<UndefValue>(Val))
     return eraseInstFromFunction(SI);
 
-  // TODO: Add a helper to simplify the pointer operand for all memory
-  // instructions.
-  // store val, (select (cond, null, P)) -> store val, P
-  // store val, (select (cond, P, null)) -> store val, P
-  if (!NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace())) {
-    if (SelectInst *Sel = dyn_cast<SelectInst>(Ptr)) {
-      if (isa<ConstantPointerNull>(Sel->getOperand(1)))
-        return replaceOperand(SI, 1, Sel->getOperand(2));
-
-      if (isa<ConstantPointerNull>(Sel->getOperand(2)))
-        return replaceOperand(SI, 1, Sel->getOperand(1));
-    }
-  }
+  if (!NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace()))
+    if (Value *V = simplifyNonNullOperand(Ptr))
+      return replaceOperand(SI, 1, V);
 
   return nullptr;
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 5621511570b58..d3af06f63fcd2 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3587,7 +3587,15 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) {
 
 Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) {
   Value *RetVal = RI.getReturnValue();
-  if (!RetVal || !AttributeFuncs::isNoFPClassCompatibleType(RetVal->getType()))
+  if (!RetVal)
+    return nullptr;
+
+  if (RetVal->getType()->isPointerTy() && RI.getFunction()->isReturnNonNull()) {
+    if (Value *V = simplifyNonNullOperand(RetVal))
+      return replaceOperand(RI, 0, V);
+  }
+
+  if (!AttributeFuncs::isNoFPClassCompatibleType(RetVal->getType()))
     return nullptr;
 
   Function *F = RI.getFunction();
diff --git a/llvm/test/Transforms/InstCombine/nonnull-select.ll b/llvm/test/Transforms/InstCombine/nonnull-select.ll
index 3fab2dfb41a42..cc000b4c88164 100644
--- a/llvm/test/Transforms/InstCombine/nonnull-select.ll
+++ b/llvm/test/Transforms/InstCombine/nonnull-select.ll
@@ -5,10 +5,7 @@
 
 define nonnull ptr @pr48975(ptr %.0) {
 ; CHECK-LABEL: @pr48975(
-; CHECK-NEXT:    [[DOT1:%.*]] = load ptr, ptr [[DOT0:%.*]], align 8
-; CHECK-NEXT:    [[DOT2:%.*]] = icmp eq ptr [[DOT1]], null
-; CHECK-NEXT:    [[DOT4:%.*]] = select i1 [[DOT2]], ptr null, ptr [[DOT0]]
-; CHECK-NEXT:    ret ptr [[DOT4]]
+; CHECK-NEXT:    ret ptr [[DOT4:%.*]]
 ;
   %.1 = load ptr, ptr %.0, align 8
   %.2 = icmp eq ptr %.1, null
@@ -18,8 +15,7 @@ define nonnull ptr @pr48975(ptr %.0) {
 
 define nonnull ptr @nonnull_ret(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_ret(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr %p, ptr null
   ret ptr %res
@@ -27,8 +23,7 @@ define nonnull ptr @nonnull_ret(i1 %cond, ptr %p) {
 
 define nonnull ptr @nonnull_ret2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_ret2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr null, ptr %p
   ret ptr %res
@@ -36,8 +31,7 @@ define nonnull ptr @nonnull_ret2(i1 %cond, ptr %p) {
 
 define nonnull noundef ptr @nonnull_noundef_ret(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_ret(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr %p, ptr null
   ret ptr %res
@@ -45,8 +39,7 @@ define nonnull noundef ptr @nonnull_noundef_ret(i1 %cond, ptr %p) {
 
 define nonnull noundef ptr @nonnull_noundef_ret2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_ret2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    ret ptr [[RES]]
+; CHECK-NEXT:    ret ptr [[RES:%.*]]
 ;
   %res = select i1 %cond, ptr null, ptr %p
   ret ptr %res
@@ -55,8 +48,7 @@ define nonnull noundef ptr @nonnull_noundef_ret2(i1 %cond, ptr %p) {
 
 define void @nonnull_call(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_call(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    call void @f(ptr nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr %p, ptr null
@@ -66,8 +58,7 @@ define void @nonnull_call(i1 %cond, ptr %p) {
 
 define void @nonnull_call2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_call2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    call void @f(ptr nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr null, ptr %p
@@ -77,8 +68,7 @@ define void @nonnull_call2(i1 %cond, ptr %p) {
 
 define void @nonnull_noundef_call(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_call(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr [[P:%.*]], ptr null
-; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr %p, ptr null
@@ -88,8 +78,7 @@ define void @nonnull_noundef_call(i1 %cond, ptr %p) {
 
 define void @nonnull_noundef_call2(i1 %cond, ptr %p) {
 ; CHECK-LABEL: @nonnull_noundef_call2(
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
-; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES]])
+; CHECK-NEXT:    call void @f(ptr noundef nonnull [[RES:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %res = select i1 %cond, ptr null, ptr %p
diff --git a/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll b/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
index d8ef0723cf09e..f6bf57a678786 100644
--- a/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
+++ b/llvm/test/Transforms/PhaseOrdering/load-store-sameval.ll
@@ -1,24 +1,21 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt -passes='instcombine,early-cse<memssa>' -S %s | FileCheck %s
 
-; FIXME: We can remove the store instruction in the exit block
 define i32 @load_store_sameval(ptr %p, i1 %cond1, i1 %cond2) {
 ; CHECK-LABEL: define i32 @load_store_sameval(
 ; CHECK-SAME: ptr [[P:%.*]], i1 [[COND1:%.*]], i1 [[COND2:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[SPEC_SELECT:%.*]] = select i1 [[COND1]], ptr null, ptr [[P]]
-; CHECK-NEXT:    [[PRE:%.*]] = load i32, ptr [[SPEC_SELECT]], align 4
+; CHECK-NEXT:    [[PRE:%.*]] = load i32, ptr [[P]], align 4
 ; CHECK-NEXT:    br label %[[BLOCK:.*]]
 ; CHECK:       [[BLOCK]]:
 ; CHECK-NEXT:    br label %[[BLOCK2:.*]]
 ; CHECK:       [[BLOCK2]]:
 ; CHECK-NEXT:    br i1 [[COND2]], label %[[BLOCK3:.*]], label %[[EXIT:.*]]
 ; CHECK:       [[BLOCK3]]:
-; CHECK-NEXT:    [[LOAD:%.*]] = load double, ptr [[SPEC_SELECT]], align 8
+; CHECK-NEXT:    [[LOAD:%.*]] = load double, ptr [[P]], align 8
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp une double [[LOAD]], 0.000000e+00
 ; CHECK-NEXT:    br i1 [[CMP]], label %[[BLOCK]], label %[[BLOCK2]]
 ; CHECK:       [[EXIT]]:
-; CHECK-NEXT:    store i32 [[PRE]], ptr [[P]], align 4
 ; CHECK-NEXT:    ret i32 0
 ;
 entry:
diff --git a/llvm/test/Transforms/PhaseOrdering/memset-combine.ll b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
new file mode 100644
index 0000000000000..d1de11258ed91
--- /dev/null
+++ b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
@@ -0,0 +1,20 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt < %s -passes=instcombine,memcpyopt -S | FileCheck %s
+
+; FIXME: These two memset calls should be merged into a single one.
+define void @merge_memset(ptr %p, i1 %cond) {
+; CHECK-LABEL: define void @merge_memset(
+; CHECK-SAME: ptr [[P:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], ptr null, ptr [[P]]
+; CHECK-NEXT:    tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(4096) [[P]], i8 0, i64 4096, i1 false)
+; CHECK-NEXT:    [[OFF:%.*]] = getelementptr inbounds nuw i8, ptr [[SEL]], i64 4096
+; CHECK-NEXT:    tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(768) [[OFF]], i8 0, i64 768, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %sel = select i1 %cond, ptr null, ptr %p
+  tail call void @llvm.memset.p0.i64(ptr noundef nonnull %sel, i8 0, i64 4096, i1 false)
+  %off = getelementptr inbounds nuw i8, ptr %sel, i64 4096
+  tail call void @llvm.memset.p0.i64(ptr noundef nonnull %off, i8 0, i64 768, i1 false)
+  ret void
+}
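
The predicate added in `Function::isReturnNonNull` combines two independent conditions. A minimal standalone sketch of that decision follows, assuming a hypothetical `RetAttrs` summary struct in place of the real attribute queries (`hasRetAttribute`, `getRetDereferenceableBytes`, `NullPointerIsDefined`):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical summary of a function's return attributes; not an LLVM type.
struct RetAttrs {
  bool nonNull;                        // return carries the nonnull attribute
  std::uint64_t dereferenceableBytes;  // N from dereferenceable(N), 0 if absent
  bool nullIsDefined;                  // stand-in for NullPointerIsDefined(F, AS)
};

// Mirrors the control flow of the patch: nonnull alone is sufficient;
// otherwise at least one dereferenceable byte implies non-null only when
// null is not a valid address for this function and address space.
bool isReturnNonNull(const RetAttrs &r) {
  if (r.nonNull)
    return true;
  if (r.dereferenceableBytes > 0 && !r.nullIsDefined)
    return true;
  return false;
}
```

Note that the second condition depends on whether null is defined, so `dereferenceable(N)` by itself does not make the return value non-null in this sketch.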

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-llvm-transforms



nikic commented Feb 21, 2025

          replaceOperand(Call, ArgNo, Res);
          Changed = true;
        }
      }
Can we use an else here instead of querying nonnull again? (This will no longer infer nonnull for dereferenceable arguments, but we shouldn't need to?)
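
The difference between the two control flows can be sketched on a toy model. Everything here is an assumption made for illustration: `ArgInfo` is a hypothetical per-argument summary, and the stand-in `paramHasNonNullAttr` assumes (as the review comment implies) that the real query also accepts dereferenceable arguments.

```cpp
#include <cassert>
#include <vector>

// Hypothetical per-argument facts; the real code queries CallBase and
// ValueTracking instead of reading flags.
struct ArgInfo {
  bool hasNonNullAttr;      // explicit nonnull attribute on the argument
  bool hasDereferenceable;  // dereferenceable(N), N > 0, null not defined
  bool knownNonZero;        // what isKnownNonZero would report
};

// Stand-in for the combined query: explicit nonnull or dereferenceable.
bool paramHasNonNullAttr(const ArgInfo &a) {
  return a.hasNonNullAttr || a.hasDereferenceable;
}

// As written in the patch: simplification and inference are checked
// independently, so nonnull is queried twice.
std::vector<unsigned> inferAsInPatch(const std::vector<ArgInfo> &args) {
  std::vector<unsigned> argNos;
  for (unsigned i = 0; i < args.size(); ++i) {
    if (paramHasNonNullAttr(args[i])) {
      // the operand would be simplified here
    }
    if (!args[i].hasNonNullAttr && args[i].knownNonZero)
      argNos.push_back(i);
  }
  return argNos;
}

// As suggested in review: inference only runs when the first check fails.
std::vector<unsigned> inferWithElse(const std::vector<ArgInfo> &args) {
  std::vector<unsigned> argNos;
  for (unsigned i = 0; i < args.size(); ++i) {
    if (paramHasNonNullAttr(args[i])) {
      // the operand would be simplified here
    } else if (args[i].knownNonZero) {
      argNos.push_back(i);
    }
  }
  return argNos;
}
```

In this sketch the only behavioral difference is the case named in the comment: an argument that is dereferenceable but lacks an explicit nonnull attribute is no longer added to the inference list.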

@nikic nikic left a comment

LGTM
