Skip to content

Commit 284aeaf

Browse files
Optimize VectorX<T>.ConditionalSelect for constant masks (#104092)
* Optimize ConditionalSelect for const mask

  This adds a check in the JIT for constant masks (`GT_CNS_VEC`; everything else gets lowered to it) and enables the optimization to `BlendVariable` (the `(v)pblendvb` instruction). This currently does not work for masks loaded from an array in a field/variable. Also, this optimization is not triggered on platforms supporting AVX512F(/VL?), since there the operation is optimized earlier to a `vpternlogd` instruction.

* Cleanup code and separate it into functions
* fix build
* Misc fixes
* Final build fixes
* Address review comments
* Address review comments again
* address the rest of the comments
* Remove scalar assertion

Co-authored-by: Tanner Gooding <[email protected]>

---------

Co-authored-by: Tanner Gooding <[email protected]>
1 parent 8432f0d commit 284aeaf

File tree

4 files changed

+102
-4
lines changed

4 files changed

+102
-4
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29707,6 +29707,92 @@ bool GenTree::IsInvariant() const
2970729707
return OperIsConst() || OperIs(GT_LCL_ADDR) || OperIs(GT_FTN_ADDR);
2970829708
}
2970929709

29710+
//-------------------------------------------------------------------
29711+
// IsVectorPerElementMask: returns true if this node is a vector constant per-element mask
29712+
// (every element has either all bits set or none of them).
29713+
//
29714+
// Arguments:
29715+
// simdBaseType - the base type of the constant being checked.
29716+
// simdSize - the size of the SIMD type of the intrinsic.
29717+
//
29718+
// Returns:
29719+
// True if this node is a vector constant per-element mask.
29720+
//
29721+
bool GenTree::IsVectorPerElementMask(var_types simdBaseType, unsigned simdSize) const
29722+
{
29723+
#ifdef FEATURE_SIMD
29724+
if (IsCnsVec())
29725+
{
29726+
const GenTreeVecCon* vecCon = AsVecCon();
29727+
29728+
int elementCount = vecCon->ElementCount(simdSize, simdBaseType);
29729+
29730+
switch (simdBaseType)
29731+
{
29732+
case TYP_BYTE:
29733+
case TYP_UBYTE:
29734+
return ElementsAreAllBitsSetOrZero(&vecCon->gtSimdVal.u8[0], elementCount);
29735+
case TYP_SHORT:
29736+
case TYP_USHORT:
29737+
return ElementsAreAllBitsSetOrZero(&vecCon->gtSimdVal.u16[0], elementCount);
29738+
case TYP_INT:
29739+
case TYP_UINT:
29740+
case TYP_FLOAT:
29741+
return ElementsAreAllBitsSetOrZero(&vecCon->gtSimdVal.u32[0], elementCount);
29742+
case TYP_LONG:
29743+
case TYP_ULONG:
29744+
case TYP_DOUBLE:
29745+
return ElementsAreAllBitsSetOrZero(&vecCon->gtSimdVal.u64[0], elementCount);
29746+
default:
29747+
unreached();
29748+
}
29749+
}
29750+
else if (OperIsHWIntrinsic())
29751+
{
29752+
const GenTreeHWIntrinsic* intrinsic = AsHWIntrinsic();
29753+
const NamedIntrinsic intrinsicId = intrinsic->GetHWIntrinsicId();
29754+
29755+
if (HWIntrinsicInfo::ReturnsPerElementMask(intrinsicId))
29756+
{
29757+
// We directly return a per-element mask
29758+
return true;
29759+
}
29760+
29761+
bool isScalar = false;
29762+
genTreeOps oper = intrinsic->HWOperGet(&isScalar);
29763+
29764+
switch (oper)
29765+
{
29766+
case GT_AND:
29767+
case GT_AND_NOT:
29768+
case GT_OR:
29769+
case GT_XOR:
29770+
{
29771+
// We are a binary bitwise operation where both inputs are per-element masks
29772+
return intrinsic->Op(1)->IsVectorPerElementMask(simdBaseType, simdSize) &&
29773+
intrinsic->Op(2)->IsVectorPerElementMask(simdBaseType, simdSize);
29774+
}
29775+
29776+
case GT_NOT:
29777+
{
29778+
// We are an unary bitwise operation where the input is a per-element mask
29779+
return intrinsic->Op(1)->IsVectorPerElementMask(simdBaseType, simdSize);
29780+
}
29781+
29782+
default:
29783+
{
29784+
assert(!GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic(oper));
29785+
break;
29786+
}
29787+
}
29788+
29789+
return false;
29790+
}
29791+
#endif // FEATURE_SIMD
29792+
29793+
return false;
29794+
}
29795+
2971029796
//------------------------------------------------------------------------
2971129797
// IsNeverNegative: returns true if the given tree is known to be never
2971229798
// negative, i. e. the upper bit will always be zero.

src/coreclr/jit/gentree.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,6 +2318,7 @@ struct GenTree
23182318
bool Precedes(GenTree* other);
23192319

23202320
bool IsInvariant() const;
2321+
bool IsVectorPerElementMask(var_types simdBaseType, unsigned simdSize) const;
23212322

23222323
bool IsNeverNegative(Compiler* comp) const;
23232324
bool IsNeverNegativeOne(Compiler* comp) const;

src/coreclr/jit/lowerxarch.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3034,12 +3034,12 @@ GenTree* Lowering::LowerHWIntrinsicCndSel(GenTreeHWIntrinsic* node)
30343034
GenTree* op3 = node->Op(3);
30353035

30363036
// If the condition vector comes from a hardware intrinsic that
3037-
// returns a per-element mask (marked with HW_Flag_ReturnsPerElementMask),
3038-
// we can optimize the entire conditional select to
3039-
// a single BlendVariable instruction (if supported by the architecture)
3037+
// returns a per-element mask, we can optimize the entire
3038+
// conditional select to a single BlendVariable instruction
3039+
// (if supported by the architecture)
30403040

30413041
// First, determine if the condition is a per-element mask
3042-
if (op1->OperIsHWIntrinsic() && HWIntrinsicInfo::ReturnsPerElementMask(op1->AsHWIntrinsic()->GetHWIntrinsicId()))
3042+
if (op1->IsVectorPerElementMask(simdBaseType, simdSize))
30433043
{
30443044
// Next, determine if the target architecture supports BlendVariable
30453045
NamedIntrinsic blendVariableId = NI_Illegal;

src/coreclr/jit/simd.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ static bool ElementsAreSame(T* array, size_t size)
1515
return true;
1616
}
1717

18+
//------------------------------------------------------------------------
// ElementsAreAllBitsSetOrZero: checks whether every element of an array is a valid
//                              per-element mask value.
//
// Arguments:
//    array - pointer to the first element to inspect
//    size  - the number of elements to inspect
//
// Return Value:
//    true if each of the first `size` elements is either zero or equal to
//    static_cast<T>(~0); trivially true when `size` is 0. false otherwise.
//
// Notes:
//    static_cast<T>(~0) is the all-bits-set value only when T is an integer type,
//    so this is meant to be instantiated with integer element types.
//
template <typename T>
static bool ElementsAreAllBitsSetOrZero(T* array, size_t size)
{
    for (size_t i = 0; i < size; i++)
    {
        if (array[i] != static_cast<T>(0) && array[i] != static_cast<T>(~0))
        {
            return false;
        }
    }
    return true;
}
28+
1829
struct simd8_t
1930
{
2031
union

0 commit comments

Comments
 (0)