diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index cf38fc5f058f2..0dfdd9209b40e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -742,39 +742,47 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal, /// 1. The icmp predicate is inverted /// 2. The select operands are reversed /// 3. The magnitude of C2 and C1 are flipped -static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal, - Value *FalseVal, - InstCombiner::BuilderTy &Builder) { +static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal, + Value *FalseVal, + InstCombiner::BuilderTy &Builder) { // Only handle integer compares. Also, if this is a vector select, we need a // vector compare. if (!TrueVal->getType()->isIntOrIntVectorTy() || - TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy()) + TrueVal->getType()->isVectorTy() != CondVal->getType()->isVectorTy()) return nullptr; - Value *CmpLHS = IC->getOperand(0); - Value *CmpRHS = IC->getOperand(1); - unsigned C1Log; bool NeedAnd = false; - CmpInst::Predicate Pred = IC->getPredicate(); - if (IC->isEquality()) { - if (!match(CmpRHS, m_Zero())) - return nullptr; + CmpPredicate Pred; + Value *CmpLHS, *CmpRHS; - const APInt *C1; - if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1)))) - return nullptr; + if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) { + if (ICmpInst::isEquality(Pred)) { + if (!match(CmpRHS, m_Zero())) + return nullptr; - C1Log = C1->logBase2(); - } else { - auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred); - if (!Res || !Res->Mask.isPowerOf2()) - return nullptr; + const APInt *C1; + if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1)))) + return nullptr; - CmpLHS = Res->X; - Pred = Res->Pred; - C1Log = Res->Mask.logBase2(); - NeedAnd = true; + C1Log = C1->logBase2(); + } else { + auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred); + if (!Res || !Res->Mask.isPowerOf2()) + return nullptr; + + CmpLHS = Res->X; + Pred = Res->Pred; + C1Log = Res->Mask.logBase2(); + NeedAnd = true; + } + } else if (auto *Trunc = dyn_cast(CondVal)) { + CmpLHS = Trunc->getOperand(0); + C1Log = 0; + Pred = ICmpInst::ICMP_NE; + NeedAnd = !Trunc->hasNoUnsignedWrap(); + } else { + return nullptr; } Value *Y, *V = CmpLHS; @@ -808,7 +816,7 @@ static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal, // Make sure we don't create more instructions than we save. if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) > - (IC->hasOneUse() + BinOp->hasOneUse())) + (CondVal->hasOneUse() + BinOp->hasOneUse())) return nullptr; if (NeedAnd) { @@ -1986,9 +1994,6 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder)) return V; - if (Value *V = foldSelectICmpAndBinOp(ICI, TrueVal, FalseVal, Builder)) - return replaceInstUsesWith(SI, V); - if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); @@ -3946,6 +3951,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Instruction *Result = foldSelectInstWithICmp(SI, ICI)) return Result; + if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + if (Instruction *Add = foldAddSubSelect(SI, Builder)) return Add; if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder)) diff --git a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll index 67dec9178eeca..ca2e23c1d082e 100644 --- a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll +++ b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll @@ -1754,9 +1754,9 @@ define i8 @select_icmp_eq_and_1_0_lshr_tv(i8 %x, i8 %y) { define i8 @select_trunc_or_2(i8 %x, i8 %y) { ; CHECK-LABEL: @select_trunc_or_2( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2 -; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]] +; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1 +; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2 +; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]] ; CHECK-NEXT: ret i8 [[SELECT]] ; %trunc = trunc i8 %x to i1 @@ -1767,9 +1767,9 @@ define i8 @select_trunc_or_2(i8 %x, i8 %y) { define i8 @select_not_trunc_or_2(i8 %x, i8 %y) { ; CHECK-LABEL: @select_not_trunc_or_2( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2 -; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]] +; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1 +; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2 +; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]] ; CHECK-NEXT: ret i8 [[SELECT]] ; %trunc = trunc i8 %x to i1 @@ -1781,9 +1781,8 @@ define i8 @select_not_trunc_or_2(i8 %x, i8 %y) { define i8 @select_trunc_nuw_or_2(i8 %x, i8 %y) { ; CHECK-LABEL: @select_trunc_nuw_or_2( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2 -; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]] +; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1 +; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP1]] ; CHECK-NEXT: ret i8 [[SELECT]] ; %trunc = trunc nuw i8 %x to i1 @@ -1794,9 +1793,9 @@ define i8 @select_trunc_nuw_or_2(i8 %x, i8 %y) { define i8 @select_trunc_nsw_or_2(i8 %x, i8 %y) { ; CHECK-LABEL: @select_trunc_nsw_or_2( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc nsw i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2 -; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]] +; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1 +; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2 +; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]] ; CHECK-NEXT: ret i8 [[SELECT]] ; %trunc = trunc nsw i8 %x to i1 @@ -1807,9 +1806,9 @@ define i8 @select_trunc_nsw_or_2(i8 %x, i8 %y) { define <2 x i8> @select_trunc_or_2_vec(<2 x i8> %x, <2 x i8> %y) { ; CHECK-LABEL: @select_trunc_or_2_vec( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc <2 x i8> [[X:%.*]] to <2 x i1> -; CHECK-NEXT: [[OR:%.*]] = or <2 x i8> [[Y:%.*]], splat (i8 2) -; CHECK-NEXT: [[SELECT:%.*]] = select <2 x i1> [[TRUNC]], <2 x i8> [[OR]], <2 x i8> [[Y]] +; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i8> [[X:%.*]], splat (i8 1) +; CHECK-NEXT: [[TMP2:%.*]] = and <2 x i8> [[TMP1]], splat (i8 2) +; CHECK-NEXT: [[SELECT:%.*]] = or <2 x i8> [[Y:%.*]], [[TMP2]] ; CHECK-NEXT: ret <2 x i8> [[SELECT]] ; %trunc = trunc <2 x i8> %x to <2 x i1>