Skip to content

Commit f9b7826

Browse files
[release/10.0] Fix codegen for SVE Scatter*With*Offsets* and GatherVector*With*Offsets* (#119959)
* Fix SVE Scatter*With*Offsets* codegen
* Fix SVE GatherVector*With*Offsets* codegen
1 parent 1e004d9 commit f9b7826

File tree

6 files changed

+365
-248
lines changed

6 files changed

+365
-248
lines changed

src/coreclr/jit/hwintrinsiccodegenarm64.cpp

Lines changed: 32 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2238,11 +2238,20 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
22382238
// GatherVector...(Vector<T> mask, T* address, Vector<T2> indices)
22392239

22402240
emitAttr baseSize = emitActualTypeSize(intrin.baseType);
2241-
bool isLoadingBytes =
2242-
((ins == INS_sve_ld1b) || (ins == INS_sve_ld1sb) || (ins == INS_sve_ldff1b) ||
2243-
(ins == INS_sve_ldff1sb) || (intrin.id == NI_Sve_GatherVectorWithByteOffsetFirstFaulting) ||
2241+
bool isLoadingFromOffsets =
2242+
((intrin.id == NI_Sve_GatherVectorByteZeroExtend) ||
2243+
(intrin.id == NI_Sve_GatherVectorByteZeroExtendFirstFaulting) ||
2244+
(intrin.id == NI_Sve_GatherVectorInt16WithByteOffsetsSignExtend) ||
2245+
(intrin.id == NI_Sve_GatherVectorInt16WithByteOffsetsSignExtendFirstFaulting) ||
2246+
(intrin.id == NI_Sve_GatherVectorInt32WithByteOffsetsSignExtend) ||
2247+
(intrin.id == NI_Sve_GatherVectorInt32WithByteOffsetsSignExtendFirstFaulting) ||
2248+
(intrin.id == NI_Sve_GatherVectorSByteSignExtend) ||
2249+
(intrin.id == NI_Sve_GatherVectorSByteSignExtendFirstFaulting) ||
2250+
(intrin.id == NI_Sve_GatherVectorUInt16WithByteOffsetsZeroExtend) ||
2251+
(intrin.id == NI_Sve_GatherVectorUInt16WithByteOffsetsZeroExtendFirstFaulting) ||
2252+
(intrin.id == NI_Sve_GatherVectorUInt32WithByteOffsetsZeroExtend) ||
22442253
(intrin.id == NI_Sve_GatherVectorUInt32WithByteOffsetsZeroExtendFirstFaulting) ||
2245-
(intrin.id == NI_Sve_GatherVectorUInt16WithByteOffsetsZeroExtendFirstFaulting));
2254+
(intrin.id == NI_Sve_GatherVectorWithByteOffsetFirstFaulting));
22462255
insScalableOpts sopt = INS_SCALABLE_OPTS_NONE;
22472256

22482257
if (baseSize == EA_4BYTE)
@@ -2251,13 +2260,13 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
22512260
opt = varTypeIsUnsigned(node->GetAuxiliaryType()) ? INS_OPTS_SCALABLE_S_UXTW
22522261
: INS_OPTS_SCALABLE_S_SXTW;
22532262

2254-
sopt = isLoadingBytes ? INS_SCALABLE_OPTS_NONE : INS_SCALABLE_OPTS_MOD_N;
2263+
sopt = isLoadingFromOffsets ? INS_SCALABLE_OPTS_NONE : INS_SCALABLE_OPTS_MOD_N;
22552264
}
22562265
else
22572266
{
22582267
// Index is multiplied.
22592268
assert(baseSize == EA_8BYTE);
2260-
sopt = isLoadingBytes ? INS_SCALABLE_OPTS_NONE : INS_SCALABLE_OPTS_LSL_N;
2269+
sopt = isLoadingFromOffsets ? INS_SCALABLE_OPTS_NONE : INS_SCALABLE_OPTS_LSL_N;
22612270
}
22622271

22632272
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt, sopt);
@@ -2300,12 +2309,8 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
23002309

23012310
case NI_Sve_Scatter:
23022311
case NI_Sve_Scatter16BitNarrowing:
2303-
case NI_Sve_Scatter16BitWithByteOffsetsNarrowing:
23042312
case NI_Sve_Scatter32BitNarrowing:
2305-
case NI_Sve_Scatter32BitWithByteOffsetsNarrowing:
23062313
case NI_Sve_Scatter8BitNarrowing:
2307-
case NI_Sve_Scatter8BitWithByteOffsetsNarrowing:
2308-
case NI_Sve_ScatterWithByteOffsets:
23092314
{
23102315
if (!varTypeIsSIMD(intrin.op2->gtType))
23112316
{
@@ -2340,6 +2345,23 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
23402345
break;
23412346
}
23422347

2348+
case NI_Sve_Scatter16BitWithByteOffsetsNarrowing:
2349+
case NI_Sve_Scatter32BitWithByteOffsetsNarrowing:
2350+
case NI_Sve_Scatter8BitWithByteOffsetsNarrowing:
2351+
case NI_Sve_ScatterWithByteOffsets:
2352+
{
2353+
emitAttr baseSize = emitActualTypeSize(intrin.baseType);
2354+
2355+
if (baseSize == EA_4BYTE)
2356+
{
2357+
opt = varTypeIsUnsigned(node->GetAuxiliaryType()) ? INS_OPTS_SCALABLE_S_UXTW
2358+
: INS_OPTS_SCALABLE_S_SXTW;
2359+
}
2360+
2361+
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, op4Reg, op1Reg, op2Reg, op3Reg, opt);
2362+
break;
2363+
}
2364+
23432365
case NI_Sve_StoreNarrowing:
23442366
opt = emitter::optGetSveInsOpt(emitTypeSize(intrin.baseType));
23452367
GetEmitter()->emitIns_R_R_R_I(ins, emitSize, op3Reg, op1Reg, op2Reg, 0, opt);

src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Arm/Sve.cs

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -9074,73 +9074,73 @@ internal Arm64() { }
90749074
/// void svst1_scatter_[s64]offset[_f64](svbool_t pg, float64_t *base, svint64_t offsets, svfloat64_t data)
90759075
/// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
90769076
/// </summary>
9077-
public static unsafe void ScatterWithByteOffsets(Vector<double> mask, double* address, Vector<long> offsets, Vector<double> data) => Scatter(mask, address, offsets, data);
9077+
public static unsafe void ScatterWithByteOffsets(Vector<double> mask, double* address, Vector<long> offsets, Vector<double> data) => ScatterWithByteOffsets(mask, address, offsets, data);
90789078

90799079
/// <summary>
90809080
/// void svst1_scatter_[u64]offset[_f64](svbool_t pg, float64_t *base, svuint64_t offsets, svfloat64_t data)
90819081
/// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
90829082
/// </summary>
9083-
public static unsafe void ScatterWithByteOffsets(Vector<double> mask, double* address, Vector<ulong> offsets, Vector<double> data) => Scatter(mask, address, offsets, data);
9083+
public static unsafe void ScatterWithByteOffsets(Vector<double> mask, double* address, Vector<ulong> offsets, Vector<double> data) => ScatterWithByteOffsets(mask, address, offsets, data);
90849084

90859085
/// <summary>
90869086
/// void svst1_scatter_[s32]offset[_s32](svbool_t pg, int32_t *base, svint32_t offsets, svint32_t data)
90879087
/// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, SXTW]
90889088
/// </summary>
9089-
public static unsafe void ScatterWithByteOffsets(Vector<int> mask, int* address, Vector<int> offsets, Vector<int> data) => Scatter(mask, address, offsets, data);
9089+
public static unsafe void ScatterWithByteOffsets(Vector<int> mask, int* address, Vector<int> offsets, Vector<int> data) => ScatterWithByteOffsets(mask, address, offsets, data);
90909090

90919091
/// <summary>
90929092
/// void svst1_scatter_[u32]offset[_s32](svbool_t pg, int32_t *base, svuint32_t offsets, svint32_t data)
90939093
/// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, UXTW]
90949094
/// </summary>
9095-
public static unsafe void ScatterWithByteOffsets(Vector<int> mask, int* address, Vector<uint> offsets, Vector<int> data) => Scatter(mask, address, offsets, data);
9095+
public static unsafe void ScatterWithByteOffsets(Vector<int> mask, int* address, Vector<uint> offsets, Vector<int> data) => ScatterWithByteOffsets(mask, address, offsets, data);
90969096

90979097
/// <summary>
90989098
/// void svst1_scatter_[s64]offset[_s64](svbool_t pg, int64_t *base, svint64_t offsets, svint64_t data)
90999099
/// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
91009100
/// </summary>
9101-
public static unsafe void ScatterWithByteOffsets(Vector<long> mask, long* address, Vector<long> offsets, Vector<long> data) => Scatter(mask, address, offsets, data);
9101+
public static unsafe void ScatterWithByteOffsets(Vector<long> mask, long* address, Vector<long> offsets, Vector<long> data) => ScatterWithByteOffsets(mask, address, offsets, data);
91029102

91039103
/// <summary>
91049104
/// void svst1_scatter_[u64]offset[_s64](svbool_t pg, int64_t *base, svuint64_t offsets, svint64_t data)
91059105
/// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
91069106
/// </summary>
9107-
public static unsafe void ScatterWithByteOffsets(Vector<long> mask, long* address, Vector<ulong> offsets, Vector<long> data) => Scatter(mask, address, offsets, data);
9107+
public static unsafe void ScatterWithByteOffsets(Vector<long> mask, long* address, Vector<ulong> offsets, Vector<long> data) => ScatterWithByteOffsets(mask, address, offsets, data);
91089108

91099109
/// <summary>
91109110
/// void svst1_scatter_[s32]offset[_f32](svbool_t pg, float32_t *base, svint32_t offsets, svfloat32_t data)
91119111
/// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, SXTW]
91129112
/// </summary>
9113-
public static unsafe void ScatterWithByteOffsets(Vector<float> mask, float* address, Vector<int> offsets, Vector<float> data) => Scatter(mask, address, offsets, data);
9113+
public static unsafe void ScatterWithByteOffsets(Vector<float> mask, float* address, Vector<int> offsets, Vector<float> data) => ScatterWithByteOffsets(mask, address, offsets, data);
91149114

91159115
/// <summary>
91169116
/// void svst1_scatter_[u32]offset[_f32](svbool_t pg, float32_t *base, svuint32_t offsets, svfloat32_t data)
91179117
/// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, UXTW]
91189118
/// </summary>
9119-
public static unsafe void ScatterWithByteOffsets(Vector<float> mask, float* address, Vector<uint> offsets, Vector<float> data) => Scatter(mask, address, offsets, data);
9119+
public static unsafe void ScatterWithByteOffsets(Vector<float> mask, float* address, Vector<uint> offsets, Vector<float> data) => ScatterWithByteOffsets(mask, address, offsets, data);
91209120

91219121
/// <summary>
91229122
/// void svst1_scatter_[s32]offset[_u32](svbool_t pg, uint32_t *base, svint32_t offsets, svuint32_t data)
91239123
/// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, SXTW]
91249124
/// </summary>
9125-
public static unsafe void ScatterWithByteOffsets(Vector<uint> mask, uint* address, Vector<int> offsets, Vector<uint> data) => Scatter(mask, address, offsets, data);
9125+
public static unsafe void ScatterWithByteOffsets(Vector<uint> mask, uint* address, Vector<int> offsets, Vector<uint> data) => ScatterWithByteOffsets(mask, address, offsets, data);
91269126

91279127
/// <summary>
91289128
/// void svst1_scatter_[u32]offset[_u32](svbool_t pg, uint32_t *base, svuint32_t offsets, svuint32_t data)
91299129
/// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, UXTW]
91309130
/// </summary>
9131-
public static unsafe void ScatterWithByteOffsets(Vector<uint> mask, uint* address, Vector<uint> offsets, Vector<uint> data) => Scatter(mask, address, offsets, data);
9131+
public static unsafe void ScatterWithByteOffsets(Vector<uint> mask, uint* address, Vector<uint> offsets, Vector<uint> data) => ScatterWithByteOffsets(mask, address, offsets, data);
91329132

91339133
/// <summary>
91349134
/// void svst1_scatter_[s64]offset[_u64](svbool_t pg, uint64_t *base, svint64_t offsets, svuint64_t data)
91359135
/// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
91369136
/// </summary>
9137-
public static unsafe void ScatterWithByteOffsets(Vector<ulong> mask, ulong* address, Vector<long> offsets, Vector<ulong> data) => Scatter(mask, address, offsets, data);
9137+
public static unsafe void ScatterWithByteOffsets(Vector<ulong> mask, ulong* address, Vector<long> offsets, Vector<ulong> data) => ScatterWithByteOffsets(mask, address, offsets, data);
91389138

91399139
/// <summary>
91409140
/// void svst1_scatter_[u64]offset[_u64](svbool_t pg, uint64_t *base, svuint64_t offsets, svuint64_t data)
91419141
/// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
91429142
/// </summary>
9143-
public static unsafe void ScatterWithByteOffsets(Vector<ulong> mask, ulong* address, Vector<ulong> offsets, Vector<ulong> data) => Scatter(mask, address, offsets, data);
9143+
public static unsafe void ScatterWithByteOffsets(Vector<ulong> mask, ulong* address, Vector<ulong> offsets, Vector<ulong> data) => ScatterWithByteOffsets(mask, address, offsets, data);
91449144

91459145

91469146
// Write to the first-fault register

0 commit comments

Comments
 (0)