Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions src/coreclr/jit/emitarm64sve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4560,14 +4560,13 @@ void emitter::emitInsSve_R_R_R_I(instruction ins,
case INS_sve_cdot:
assert(insScalableOptsNone(sopt));
assert(insOptsScalableWords(opt));
assert(isVectorRegister(reg1)); // ddddd
assert(isVectorRegister(reg2)); // nnnnn
assert(isVectorRegister(reg3)); // mmmmm
assert(isValidRot(imm)); // rr
assert(isValidVectorElemsize(optGetSveElemsize(opt))); // xx
assert(isVectorRegister(reg1)); // ddddd
assert(isVectorRegister(reg2)); // nnnnn
assert(isVectorRegister(reg3)); // mmmmm
assert(isValidRot(emitDecodeRotationImm0_to_270(imm))); // rr
assert(isValidVectorElemsize(optGetSveElemsize(opt))); // xx

// Convert rot to bitwise representation
imm = emitEncodeRotationImm0_to_270(imm);
fmt = IF_SVE_EJ_3A;
break;

Expand Down Expand Up @@ -5764,12 +5763,12 @@ void emitter::emitInsSve_R_R_R_I_I(instruction ins,
switch (ins)
{
case INS_sve_cdot:
assert(isVectorRegister(reg1)); // ddddd
assert(isVectorRegister(reg2)); // nnnnn
assert(isLowVectorRegister(reg3)); // mmmm
assert(isValidRot(imm2)); // rr
// Convert imm2 from rotation value (0-270) to bitwise representation (0-3)
imm = (imm1 << 2) | emitEncodeRotationImm0_to_270(imm2);
assert(isVectorRegister(reg1)); // ddddd
assert(isVectorRegister(reg2)); // nnnnn
assert(isLowVectorRegister(reg3)); // mmmm
assert(isValidRot(emitDecodeRotationImm0_to_270(imm2))); // rr

imm = (imm1 << 2) | imm2;

if (opt == INS_OPTS_SCALABLE_B)
{
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,7 @@ struct HWIntrinsicInfo
}

case NI_Sve_MultiplyAddRotateComplexBySelectedScalar:
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
{
assert(sig->numArgs == 5);
*imm1Pos = 0;
Expand Down
19 changes: 19 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ void HWIntrinsicInfo::lookupImmBounds(
break;

case NI_Sve_MultiplyAddRotateComplex:
case NI_Sve2_DotProductRotateComplex:
immLowerBound = 0;
immUpperBound = 3;
break;
Expand All @@ -510,6 +511,23 @@ void HWIntrinsicInfo::lookupImmBounds(
}
break;

case NI_Sve2_DotProductRotateComplexBySelectedIndex:
if (immNumber == 1)
{
// Bounds for rotation
immLowerBound = 0;
immUpperBound = 3;
}
else
{
// Bounds for index
assert(immNumber == 2);
assert(baseType == TYP_BYTE || baseType == TYP_SHORT);
immLowerBound = 0;
immUpperBound = (baseType == TYP_BYTE) ? 3 : 1;
}
break;

case NI_Sve_TrigonometricMultiplyAddCoefficient:
immLowerBound = 0;
immUpperBound = 7;
Expand Down Expand Up @@ -3197,6 +3215,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
}

case NI_Sve_MultiplyAddRotateComplexBySelectedScalar:
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
{
assert(sig->numArgs == 5);
assert(!isScalar);
Expand Down
110 changes: 110 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2737,6 +2737,116 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
GetEmitter()->emitInsSve_R_R_R(ins, emitSize, targetReg, op3Reg, op1Reg, INS_OPTS_SCALABLE_D);
break;

case NI_Sve2_DotProductRotateComplex:
{
assert(isRMW);
assert(hasImmediateOperand);

HWIntrinsicImmOpHelper helper(this, intrin.op4, node, (targetReg != op1Reg) ? 2 : 1);

for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);

GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

GetEmitter()->emitInsSve_R_R_R_I(ins, emitSize, targetReg, op2Reg, op3Reg, helper.ImmValue(), opt);
}
break;
}

case NI_Sve2_DotProductRotateComplexBySelectedIndex:
{
assert(isRMW);
assert(hasImmediateOperand);

// If both immediates are constant, we don't need a jump table
if (intrin.op4->IsCnsIntOrI() && intrin.op5->IsCnsIntOrI())
{
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);
GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

assert(intrin.op4->isContainedIntOrIImmed() && intrin.op5->isContainedIntOrIImmed());
GetEmitter()->emitInsSve_R_R_R_I_I(ins, emitSize, targetReg, op2Reg, op3Reg,
intrin.op4->AsIntCon()->gtIconVal,
intrin.op5->AsIntCon()->gtIconVal, opt);
}
else
{
// Use the helper to generate a table. The table can only use a single lookup value, therefore
// the two immediates index and rotation must be combined to a single value
assert(!intrin.op4->isContainedIntOrIImmed() && !intrin.op5->isContainedIntOrIImmed());
emitAttr scalarSize = emitActualTypeSize(node->GetSimdBaseType());

var_types baseType = node->GetSimdBaseType();

if (baseType == TYP_BYTE)
{
GetEmitter()->emitIns_R_R_I(INS_lsl, scalarSize, op5Reg, op5Reg, 2);
GetEmitter()->emitIns_R_R_R(INS_orr, scalarSize, op4Reg, op4Reg, op5Reg);

// index and rotation both take values 0 to 3 so must be
// combined to a single value (0 to 15)
HWIntrinsicImmOpHelper helper(this, op4Reg, 0, 15, node, (targetReg != op1Reg) ? 2 : 1);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);
GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

const int value = helper.ImmValue();
const ssize_t index = value & 3;
const ssize_t rotation = (value >> 2) & 3;
GetEmitter()->emitInsSve_R_R_R_I_I(ins, emitSize, targetReg, op2Reg, op3Reg, index,
rotation, opt);
}

GetEmitter()->emitIns_R_R_I(INS_and, scalarSize, op4Reg, op4Reg, 3);
GetEmitter()->emitIns_R_R_I(INS_lsr, scalarSize, op5Reg, op5Reg, 2);
}
else
{
assert(baseType == TYP_SHORT);
GetEmitter()->emitIns_R_R_I(INS_lsl, scalarSize, op5Reg, op5Reg, 1);
GetEmitter()->emitIns_R_R_R(INS_orr, scalarSize, op4Reg, op4Reg, op5Reg);

// index (0 to 1, in op4Reg) and rotation (0 to 3, in op5Reg) must be
// combined to a single value (0 to 7)
HWIntrinsicImmOpHelper helper(this, op4Reg, 0, 7, node, (targetReg != op1Reg) ? 2 : 1);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);
GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

const int value = helper.ImmValue();
const ssize_t index = value & 1;
const ssize_t rotation = value >> 1;
GetEmitter()->emitInsSve_R_R_R_I_I(ins, emitSize, targetReg, op2Reg, op3Reg, index,
rotation, opt);
}

GetEmitter()->emitIns_R_R_I(INS_and, scalarSize, op4Reg, op4Reg, 1);
GetEmitter()->emitIns_R_R_I(INS_lsr, scalarSize, op5Reg, op5Reg, 1);
}
}

break;
}

case NI_Sve2_SubtractWideningEven:
{
var_types returnType = node->AsHWIntrinsic()->GetSimdBaseType();
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ HARDWARE_INTRINSIC(Sve2, BitwiseClearXor,
HARDWARE_INTRINSIC(Sve2, BitwiseSelect, -1, 3, {INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve2, BitwiseSelectLeftInverted, -1, 3, {INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve2, BitwiseSelectRightInverted, -1, 3, {INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve2, DotProductRotateComplex, -1, 4, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_cdot, INS_invalid, INS_sve_cdot, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasRMWSemantics|HW_Flag_SpecialCodeGen|HW_Flag_HasImmediateOperand)
HARDWARE_INTRINSIC(Sve2, DotProductRotateComplexBySelectedIndex, -1, 5, {INS_sve_cdot, INS_invalid, INS_sve_cdot, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasRMWSemantics|HW_Flag_SpecialCodeGen|HW_Flag_HasImmediateOperand|HW_Flag_LowVectorOperation|HW_Flag_SpecialImport|HW_Flag_BaseTypeFromSecondArg)
HARDWARE_INTRINSIC(Sve2, FusedAddHalving, -1, -1, {INS_sve_shadd, INS_sve_uhadd, INS_sve_shadd, INS_sve_uhadd, INS_sve_shadd, INS_sve_uhadd, INS_sve_shadd, INS_sve_uhadd, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve2, FusedAddRoundedHalving, -1, -1, {INS_sve_srhadd, INS_sve_urhadd, INS_sve_srhadd, INS_sve_urhadd, INS_sve_srhadd, INS_sve_urhadd, INS_sve_srhadd, INS_sve_urhadd, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve2, FusedSubtractHalving, -1, -1, {INS_sve_shsub, INS_sve_uhsub, INS_sve_shsub, INS_sve_uhsub, INS_sve_shsub, INS_sve_uhsub, INS_sve_shsub, INS_sve_uhsub, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4091,6 +4091,7 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
case NI_Sve_FusedMultiplyAddBySelectedScalar:
case NI_Sve_FusedMultiplySubtractBySelectedScalar:
case NI_Sve_MultiplyAddRotateComplex:
case NI_Sve2_DotProductRotateComplex:
assert(hasImmediateOperand);
assert(varTypeIsIntegral(intrin.op4));
if (intrin.op4->IsCnsIntOrI())
Expand Down Expand Up @@ -4148,6 +4149,7 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
break;

case NI_Sve_MultiplyAddRotateComplexBySelectedScalar:
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
assert(hasImmediateOperand);
assert(varTypeIsIntegral(intrin.op4));
assert(varTypeIsIntegral(intrin.op5));
Expand Down
7 changes: 7 additions & 0 deletions src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1710,6 +1710,7 @@ void LinearScan::BuildHWIntrinsicImmediate(GenTreeHWIntrinsic* intrinsicTree, co
break;

case NI_Sve_MultiplyAddRotateComplexBySelectedScalar:
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
// This API has two immediates, one of which is used to index pairs of floats in a vector.
// For a vector width of 128 bits, this means the index's range is [0, 1],
// which means we will skip the above jump table register check,
Expand All @@ -1734,6 +1735,7 @@ void LinearScan::BuildHWIntrinsicImmediate(GenTreeHWIntrinsic* intrinsicTree, co
break;

case NI_Sve_MultiplyAddRotateComplex:
case NI_Sve2_DotProductRotateComplex:
needBranchTargetReg = !intrin.op4->isContainedIntOrIImmed();
break;

Expand Down Expand Up @@ -2164,6 +2166,7 @@ SingleTypeRegSet LinearScan::getOperandCandidates(GenTreeHWIntrinsic* intrinsicT
case NI_Sve_FusedMultiplyAddBySelectedScalar:
case NI_Sve_FusedMultiplySubtractBySelectedScalar:
case NI_Sve_MultiplyAddRotateComplexBySelectedScalar:
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
case NI_Sve2_MultiplyAddBySelectedScalar:
case NI_Sve2_MultiplyBySelectedScalarWideningEvenAndAdd:
case NI_Sve2_MultiplyBySelectedScalarWideningOddAndAdd:
Expand All @@ -2185,6 +2188,10 @@ SingleTypeRegSet LinearScan::getOperandCandidates(GenTreeHWIntrinsic* intrinsicT
if (isLowVectorOpNum)
{
unsigned baseElementSize = genTypeSize(intrin.baseType);
if (intrin.id == NI_Sve2_DotProductRotateComplexBySelectedIndex)
{
baseElementSize = intrin.baseType == TYP_BYTE ? 4 : 8;
}

if (baseElementSize == 8)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,33 @@ internal Arm64() { }
/// </summary>
public static Vector<ulong> BitwiseSelectRightInverted(Vector<ulong> select, Vector<ulong> left, Vector<ulong> right) { throw new PlatformNotSupportedException(); }

// Complex dot product

/// <summary>
/// svint32_t svcdot[_s32](svint32_t op1, svint8_t op2, svint8_t op3, uint64_t imm_rotation)
/// CDOT Ztied1.S, Zop2.B, Zop3.B, #imm_rotation
/// </summary>
public static Vector<int> DotProductRotateComplex(Vector<int> op1, Vector<sbyte> op2, Vector<sbyte> op3, [ConstantExpected(Min = 0, Max = (byte)(3))] byte rotation) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svcdot[_s64](svint64_t op1, svint16_t op2, svint16_t op3, uint64_t imm_rotation)
/// CDOT Ztied1.D, Zop2.H, Zop3.H, #imm_rotation
/// </summary>
public static Vector<long> DotProductRotateComplex(Vector<long> op1, Vector<short> op2, Vector<short> op3, [ConstantExpected(Min = 0, Max = (byte)(3))] byte rotation) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint32_t svcdot_lane[_s32](svint32_t op1, svint8_t op2, svint8_t op3, uint64_t imm_index, uint64_t imm_rotation)
/// CDOT Ztied1.S, Zop2.B, Zop3.B[imm_index], #imm_rotation
/// </summary>
public static Vector<int> DotProductRotateComplexBySelectedIndex(Vector<int> op1, Vector<sbyte> op2, Vector<sbyte> op3, [ConstantExpected(Min = 0, Max = (byte)(3))] byte imm_index, [ConstantExpected(Min = 0, Max = (byte)(3))] byte rotation) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svcdot_lane[_s64](svint64_t op1, svint16_t op2, svint16_t op3, uint64_t imm_index, uint64_t imm_rotation)
/// CDOT Ztied1.D, Zop2.H, Zop3.H[imm_index], #imm_rotation
/// </summary>
public static Vector<long> DotProductRotateComplexBySelectedIndex(Vector<long> op1, Vector<short> op2, Vector<short> op3, [ConstantExpected(Min = 0, Max = (byte)(1))] byte imm_index, [ConstantExpected(Min = 0, Max = (byte)(3))] byte rotation) { throw new PlatformNotSupportedException(); }


// Halving add

/// <summary>
Expand Down
Loading
Loading