Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ee2c0b6
Optimize FMA codegen based on the overwritten operand
weilinwa Jul 20, 2021
46d0011
Improve function/var names
weilinwa Aug 27, 2021
cce4bda
Add assertions
weilinwa Aug 27, 2021
b825291
Get use of FMA with TryGetUse
weilinwa Sep 7, 2021
f615e39
Decide FMA form with two conditions, OverwrittenOpNum and isContained
weilinwa Sep 8, 2021
b698036
Fix op reg error in codegen
weilinwa Sep 10, 2021
7d9c0d6
Decide form using lastUse and isContained in no overwritten case
weilinwa Sep 15, 2021
1344d92
Clean up code
weilinwa Sep 18, 2021
029a9b5
Separate default case overwrittenOpNum==0
weilinwa Sep 20, 2021
f2a371f
Apply format patch
weilinwa Sep 29, 2021
9955389
Change variable and function names
weilinwa Oct 1, 2021
7c56653
Update regOptional for op1 and resolve some other comments
weilinwa Oct 5, 2021
1d51caa
Optimize FMA codegen based on the overwritten operand
weilinwa Jul 20, 2021
091133e
Improve function/var names
weilinwa Aug 27, 2021
9a6ae44
Add assertions
weilinwa Aug 27, 2021
ffcff76
Get use of FMA with TryGetUse
weilinwa Sep 7, 2021
5641f8f
Decide FMA form with two conditions, OverwrittenOpNum and isContained
weilinwa Sep 8, 2021
b7312ac
Fix op reg error in codegen
weilinwa Sep 10, 2021
a325fe3
Decide form using lastUse and isContained in no overwritten case
weilinwa Sep 15, 2021
0f950dd
Clean up code
weilinwa Sep 18, 2021
33a596d
Separate default case overwrittenOpNum==0
weilinwa Sep 20, 2021
5da9368
Apply format patch
weilinwa Sep 29, 2021
c3a9f07
Change variable and function names
weilinwa Oct 1, 2021
9e356aa
Update regOptional for op1 and resolve some other comments
weilinwa Oct 5, 2021
f8159bc
Change var names
weilinwa Oct 13, 2021
18bbe4d
Resolve merge conflicts.
weilinwa Oct 13, 2021
2ca2524
Fix jit format
weilinwa Oct 13, 2021
17bd967
Fix build node error for op1 is regOptional
weilinwa Oct 14, 2021
eed5912
Use targetReg instead of GetResultOpNumForFMA in codegen
weilinwa Oct 28, 2021
43c5034
Update variable names
weilinwa Nov 2, 2021
5ef70a5
Refactor LSRA to fix the assertion failure caused by the lastUse status change
weilinwa Nov 7, 2021
bfa6924
Add check to prioritize contained op in lsra
weilinwa Nov 7, 2021
12f260b
Update for jit format
weilinwa Nov 7, 2021
5ca658e
Simplify code
weilinwa Nov 17, 2021
ec4ef66
Resolve comments
weilinwa Nov 17, 2021
aa93a85
Comment out assert because of lastUse change
weilinwa Nov 19, 2021
c66a018
Fix some copiesUpperBits related errors
weilinwa Nov 22, 2021
ff5a433
Merge branch 'main' into fma_opt
weilinwa Nov 22, 2021
a4657c7
Update src/coreclr/jit/lsraxarch.cpp
weilinwa Nov 30, 2021
75d7a37
Add link to the new issue
weilinwa Nov 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19791,6 +19791,39 @@ uint16_t GenTreeLclVarCommon::GetLclOffs() const
}
}

#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
//------------------------------------------------------------------------
// GetOverwrittenOpNumForFMA: determine which FMA operand, if any, is the same
//                            local variable that the user of this node stores to
//
// Arguments:
//    use - the node that consumes the value produced by this intrinsic
//    op1 - first operand of the FMA intrinsic
//    op2 - second operand of the FMA intrinsic
//    op3 - third operand of the FMA intrinsic
//
// Return Value:
//    1, 2 or 3 when 'use' is a GT_STORE_LCL_VAR and op1, op2 or op3 (checked in
//    that order, first match wins) is a local with the same local number as the
//    store; 0 when 'use' is not a GT_STORE_LCL_VAR or no operand matches.
//
unsigned GenTreeHWIntrinsic::GetOverwrittenOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3)
{
    // only FMA intrinsic node should call into this function
    assert(HWIntrinsicInfo::lookupIsa(gtHWIntrinsicId) == InstructionSet_FMA);

    if (use->OperIs(GT_STORE_LCL_VAR))
    {
        // The user stores into a local; look for an operand that reads that same local.
        const unsigned storedLclNum = use->AsLclVarCommon()->GetLclNum();
        GenTree* const operands[]   = {op1, op2, op3};

        for (unsigned opNum = 1; opNum <= 3; opNum++)
        {
            GenTree* const operand = operands[opNum - 1];

            if (operand->IsLocal() && (operand->AsLclVarCommon()->GetLclNum() == storedLclNum))
            {
                // Operand numbers are 1-based: 1->op1, 2->op2, 3->op3
                return opNum;
            }
        }
    }

    // Either the user is not a local store or none of the operands is the stored local.
    return 0;
}
#endif // TARGET_XARCH && FEATURE_HW_INTRINSICS

#ifdef TARGET_ARM
//------------------------------------------------------------------------
// IsOffsetMisaligned: check if the field needs a special handling on arm.
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -5187,6 +5187,8 @@ struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic
bool OperIsMemoryLoadOrStore() const; // Returns true for the HW Intrinsic instructions that have MemoryLoad or
// MemoryStore semantics, false otherwise

// For an FMA intrinsic node: returns the 1-based number of the operand (1->op1, 2->op2, 3->op3)
// that is a local with the same local number as the one written by 'use' when 'use' is a
// GT_STORE_LCL_VAR, or 0 if 'use' is not such a store or no operand matches.
unsigned GetOverwrittenOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3);

#if DEBUGGABLE_GENTREE
GenTreeHWIntrinsic() : GenTreeJitIntrinsic()
{
Expand Down
87 changes: 57 additions & 30 deletions src/coreclr/jit/hwintrinsiccodegenxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2133,47 +2133,74 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node)
// Intrinsics with CopyUpperBits semantics cannot have op1 be contained
assert(!copiesUpperBits || !op1->isContained());

if (op2->isContained() || op2->isUsedFromSpillTemp())
unsigned overwrittenOpNum = 0;
LIR::Use use;
if (LIR::AsRange(compiler->compCurBB).TryGetUse(node, &use))
{
// 132 form: op1 = (op1 * op3) + [op2]
overwrittenOpNum = node->GetOverwrittenOpNumForFMA(use.User(), op1, op2, op3);
}

ins = (instruction)(ins - 1);
op1Reg = op1->GetRegNum();
op2Reg = op3->GetRegNum();
op3 = op2;
if (overwrittenOpNum == 1)
{
if (op2->isContained())
{
// op1 = (op1 * [op2]) + op3
// 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
ins = (instruction)(ins - 1);
op1Reg = op1->GetRegNum();
op2Reg = op3->GetRegNum();
op3 = op2;
}
else
{
// op1 = (op1 * op2) + [op3]
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
op1Reg = op1->GetRegNum();
op2Reg = op2->GetRegNum();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert(op3->isContained());?

}
}
else if (op1->isContained() || op1->isUsedFromSpillTemp())
else if (overwrittenOpNum == 3)
{
// 231 form: op3 = (op2 * op3) + [op1]

// 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1
// One of the following:
// op3 = ([op1] * op2) + op3
// op3 = (op1 * [op2]) + op3
ins = (instruction)(ins + 1);
op1Reg = op3->GetRegNum();
op2Reg = op2->GetRegNum();
op3 = op1;
if (op1->isContained())
{
// op3 = ([op1] * op2) + op3
op2Reg = op2->GetRegNum();
op3 = op1;
}
else
{
// op3 = (op1 * [op2]) + op3
op2Reg = op1->GetRegNum();
op3 = op2;
}
}
else
{
// 213 form: op1 = (op2 * op1) + [op3]

op1Reg = op1->GetRegNum();
op2Reg = op2->GetRegNum();

isCommutative = !copiesUpperBits;
}

if (isCommutative && (op1Reg != targetReg) && (op2Reg == targetReg))
{
assert(node->isRMWHWIntrinsic(compiler));
assert(overwrittenOpNum == 2 || overwrittenOpNum == 0);
if (op1->isContained())
{
// op2 = ([op1] * op2) + op3
// 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
ins = (instruction)(ins - 1);
op1Reg = op2->GetRegNum();
op2Reg = op3->GetRegNum();
op3 = op1;

// We have "reg2 = (reg1 * reg2) +/- op3" where "reg1 != reg2" on a RMW intrinsic.
//
// For non-commutative intrinsics, we should have ensured that op2 was marked
// delay free in order to prevent it from getting assigned the same register
// as target. However, for commutative intrinsics, we can just swap the operands
// in order to have "reg2 = reg2 op reg1" which will end up producing the right code.
}
else
{
// op2 = (op1 * op2) + [op3]
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
op1Reg = op2->GetRegNum();
op2Reg = op1->GetRegNum();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also needs isCommutative = copiesUpperBits?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also assert(op3->isContained()); ?

}

op2Reg = op1Reg;
op1Reg = targetReg;
}

genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, op1Reg, op2Reg, op3);
Expand Down
100 changes: 77 additions & 23 deletions src/coreclr/jit/lowerxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6314,38 +6314,92 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
if ((intrinsicId >= NI_FMA_MultiplyAdd) && (intrinsicId <= NI_FMA_MultiplySubtractNegatedScalar))
{
bool supportsRegOptional = false;
unsigned overwrittenOpNum = 0;
LIR::Use use;

if (IsContainableHWIntrinsicOp(node, op3, &supportsRegOptional))
if (BlockRange().TryGetUse(node, &use))
{
// 213 form: op1 = (op2 * op1) + [op3]
MakeSrcContained(node, op3);
overwrittenOpNum = node->GetOverwrittenOpNumForFMA(use.User(), op1, op2, op3);
}
else if (IsContainableHWIntrinsicOp(node, op2, &supportsRegOptional))
{
// 132 form: op1 = (op1 * op3) + [op2]
MakeSrcContained(node, op2);
}
else if (IsContainableHWIntrinsicOp(node, op1, &supportsRegOptional))

switch (overwrittenOpNum)
{
// Intrinsics with CopyUpperBits semantics cannot have op1 be contained
case 1:
{
if (IsContainableHWIntrinsicOp(node, op2, &supportsRegOptional))
{
// op1 = (op1 * [op2]) + op3
// 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
MakeSrcContained(node, op2);
}
else if (IsContainableHWIntrinsicOp(node, op3, &supportsRegOptional))
{
// op1 = (op1 * op2) + [op3]
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
MakeSrcContained(node, op3);

if (!HWIntrinsicInfo::CopiesUpperBits(intrinsicId))
}
else
{
assert(supportsRegOptional);
// op1 = (op1 * op2) + [op3]
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
op3->SetRegOptional();
}
break;
}
case 3:
{
// 231 form: op3 = (op2 * op3) + [op1]
MakeSrcContained(node, op1);
// 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1
// One of the following:
// op3 = ([op1] * op2) + op3
// op3 = (op1 * [op2]) + op3
if (IsContainableHWIntrinsicOp(node, op1, &supportsRegOptional) && !HWIntrinsicInfo::CopiesUpperBits(intrinsicId))
{
// Intrinsics with CopyUpperBits semantics cannot have op1 be contained

MakeSrcContained(node, op1);
}
else if (IsContainableHWIntrinsicOp(node, op2, &supportsRegOptional))
{
MakeSrcContained(node, op2);
}
else
{
assert(supportsRegOptional);
op2->SetRegOptional();
}
break;
}
}
else
{
assert(supportsRegOptional);
default:
{
assert(overwrittenOpNum == 2 || overwrittenOpNum == 0);
if (IsContainableHWIntrinsicOp(node, op1, &supportsRegOptional) && !HWIntrinsicInfo::CopiesUpperBits(intrinsicId))
{
// op2 = ([op1] * op2) + op3
// 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
MakeSrcContained(node, op1);
}
else if (IsContainableHWIntrinsicOp(node, op3, &supportsRegOptional))
{
// op2 = (op1 * op2) + [op3]
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
MakeSrcContained(node, op3);
}
else
{
assert(supportsRegOptional);

// TODO-XArch-CQ: Technically any one of the three operands can
// be reg-optional. With a limitation on op1 where
// it can only be so if CopyUpperBits is off.
// https://github.com/dotnet/runtime/issues/6358
// TODO-XArch-CQ: Technically any one of the three operands can
// be reg-optional. With a limitation on op1 where
// it can only be so if CopyUpperBits is off.
// https://github.com/dotnet/runtime/issues/6358

// 213 form: op1 = (op2 * op1) + op3
op3->SetRegOptional();
// op2 = (op1 * op2) + [op3]
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
op3->SetRegOptional();
}
}
}
}
else
Expand Down
80 changes: 59 additions & 21 deletions src/coreclr/jit/lsraxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2334,47 +2334,85 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree)

const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId);

unsigned overwrittenOpNum = 0;
LIR::Use use;
if (LIR::AsRange(blockSequence[curBBSeqNum]).TryGetUse(intrinsicTree, &use))
{
overwrittenOpNum = intrinsicTree->GetOverwrittenOpNumForFMA(use.User(), op1, op2, op3);
}

// Intrinsics with CopyUpperBits semantics cannot have op1 be contained
assert(!copiesUpperBits || !op1->isContained());

if (op2->isContained())
if (overwrittenOpNum == 1)
{
// 132 form: op1 = (op1 * op3) + [op2]
if (op2->isContained())
{
// op1 = (op1 * [op2]) + op3
// 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
tgtPrefUse = BuildUse(op1);

tgtPrefUse = BuildUse(op1);
srcCount += 1;
srcCount += BuildOperandUses(op2);
srcCount += BuildDelayFreeUses(op3, op1);
}
else
{
//assert(op3->isContained());
// op1 = (op1 * op2) + [op3]
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]
tgtPrefUse = BuildUse(op1);

srcCount += 1;
srcCount += BuildOperandUses(op2);
srcCount += BuildDelayFreeUses(op3, op1);
srcCount += 1;
srcCount += op3->isContained() ? BuildOperandUses(op3) : BuildDelayFreeUses(op3, op1);
srcCount += BuildDelayFreeUses(op2, op1);

}
}
else if (op1->isContained())
else if (overwrittenOpNum == 3)
{
// 231 form: op3 = (op2 * op3) + [op1]

// 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1
// One of the following:
// op3 = ([op1] * op2) + op3
// op3 = (op1 * [op2]) + op3
tgtPrefUse = BuildUse(op3);

srcCount += BuildOperandUses(op1);
srcCount += BuildDelayFreeUses(op2, op1);
srcCount += 1;
if (op1->isContained())
{
srcCount += BuildOperandUses(op1);
srcCount += BuildDelayFreeUses(op2, op3);
}
else
{
//assert(op2->isContained());
srcCount += op2->isContained() ? BuildOperandUses(op2) : BuildDelayFreeUses(op2, op3);
srcCount += BuildDelayFreeUses(op1, op3);
}

}
else
{
// 213 form: op1 = (op2 * op1) + [op3]
assert(overwrittenOpNum == 2 || overwrittenOpNum == 0);

tgtPrefUse = BuildUse(op1);
srcCount += 1;

if (copiesUpperBits)
if (op1->isContained())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment explaining why we check the containment of op1 in the case where no operand is overwritten?

{
srcCount += BuildDelayFreeUses(op2, op1);
// op2 = ([op1] * op2) + op3
// 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2
tgtPrefUse = BuildUse(op2);
srcCount += 1;
srcCount += BuildOperandUses(op1);
srcCount += BuildDelayFreeUses(op3, op2);
}
else
{
tgtPrefUse2 = BuildUse(op2);
// op2 = (op1 * op2) + [op3]
// 213 form: XMM1 = (XMM2 * XMM1) + [XMM3]

tgtPrefUse = BuildUse(op2);
srcCount += 1;
srcCount += op3->isContained() ? BuildOperandUses(op3) : BuildDelayFreeUses(op3, op1);
srcCount += BuildDelayFreeUses(op1, op2);
}

srcCount += op3->isContained() ? BuildOperandUses(op3) : BuildDelayFreeUses(op3, op1);
}

buildUses = false;
Expand Down