Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -739,31 +739,5 @@ public static void Tanh(ReadOnlySpan<float> x, Span<float> destination)
destination[i] = MathF.Tanh(x[i]);
}
}

/// <summary>Mask used to handle remaining elements after vectorized handling of the input.</summary>
/// <remarks>
/// Logically a table of rows of 16 uints each, with one row per remaining-element count N for
/// N = 0 through 16 (17 rows in total, even though the name says 16x16). Row N has its last N
/// entries set to 0xFFFFFFFF and the 16 - N entries before them set to zero. The Nth row should
/// be used to handle N remaining elements at the end of the input, where elements in the vector
/// prior to that will be zeroed. Vectors narrower than 16 elements read only the tail of a row
/// (for example, the last four or the last eight entries).
/// </remarks>
private static ReadOnlySpan<uint> RemainderUInt32Mask_16x16 => new uint[]
{
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1404,22 +1404,25 @@ private static float GetFirstNaN(Vector512<float> vector) =>

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector128<float> LoadRemainderMaskSingleVector128(int validItems) =>
Vector128.LoadUnsafe(
ref Unsafe.As<uint, float>(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)),
(uint)((validItems * 16) + 12)); // last four floats in the row
Vector128.ConditionalSelect(
Vector128.LessThan(Vector128.Create(3, 2, 1, 0), Vector128.Create(validItems)).AsSingle(),
Vector128<float>.AllBitsSet,
Vector128<float>.Zero);
Comment on lines +1407 to +1410
Copy link
Member

@tannergooding tannergooding Sep 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to be slower and result in a 128-bit constant (or more for other sizes) emitted at runtime per method this gets inlined into.

I think the table is ultimately better and we can get rid of it longer term using a proper JIT intrinsic.


If we really don't want the table, then achieving this with just a broadcast + comparison should be sufficient, since LessThan already produces a per-element mask of AllBitsSet (true) and Zero (false), which makes the extra ConditionalSelect redundant. You can use GreaterThanOrEqual if you need the mask inverted.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll stick with the table. I was on the fence anyway. Just don't like its bulkiness.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is, since TrailingMask wants to handle only trailing elements, it will skip processing the first n items that have already been processed.

So if 2 items remain, you have 3 < 2, 2 < 2, 1 < 2, 0 < 2, which already produces Zero, Zero, AllBitsSet, AllBitsSet.


[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector256<float> LoadRemainderMaskSingleVector256(int validItems) =>
Vector256.LoadUnsafe(
ref Unsafe.As<uint, float>(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)),
(uint)((validItems * 16) + 8)); // last eight floats in the row
Vector256.ConditionalSelect(
Vector256.LessThan(Vector256.Create(7, 6, 5, 4, 3, 2, 1, 0), Vector256.Create(validItems)).AsSingle(),
Vector256<float>.AllBitsSet,
Vector256<float>.Zero);

#if NET8_0_OR_GREATER
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe Vector512<float> LoadRemainderMaskSingleVector512(int validItems) =>
Vector512.LoadUnsafe(
ref Unsafe.As<uint, float>(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)),
(uint)(validItems * 16)); // all sixteen floats in the row
Vector512.ConditionalSelect(
Copy link
Member

@EgorBo EgorBo Sep 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for AVX512 it would look like:

var mmask16 mask = (1 << n) - 1;
var maskedAnd = Avx512.And(vec1, vec2, mask);

but we decided not to expose mask registers, so if we expose these as a public API, we can intrinsify at least the AVX512 version to be cheap 🙂, if it matters. Presumably, the performance of handling trailing elements is not that important, especially for large data.

Vector512.LessThan(Vector512.Create(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), Vector512.Create(validItems)).AsSingle(),
Vector512<float>.AllBitsSet,
Vector512<float>.Zero);
#endif

private readonly struct AddOperator : IAggregationOperator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,8 @@ private static void InvokeSpanScalarSpanIntoSpan<TTernaryOperator>(
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static ref Vector<float> AsVector(ref float start, int offset) =>
ref Unsafe.As<float, Vector<float>>(
private static ref Vector<T> AsVector<T>(ref T start, int offset = 0) where T : unmanaged =>
ref Unsafe.As<T, Vector<T>>(
ref Unsafe.Add(ref start, offset));

private static unsafe bool IsNegative(float f) => *(int*)&f < 0;
Expand All @@ -640,11 +640,14 @@ private static unsafe Vector<float> LoadRemainderMaskSingleVector(int validItems
{
Debug.Assert(Vector<float>.Count is 4 or 8 or 16);

return AsVector(
ref Unsafe.As<uint, float>(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)),
(validItems * 16) + (16 - Vector<float>.Count));
return Vector.ConditionalSelect(
(Vector<float>)Vector.GreaterThan(AsVector(ref MemoryMarshal.GetReference<int>(s_0through15)), new Vector<int>(Vector<int>.Count - 1 - validItems)),
~Vector<float>.Zero,
Vector<float>.Zero);
}

// Element indices 0 through 15 in ascending order. LoadRemainderMaskSingleVector reinterprets the
// start of this array as a Vector<int> and compares it against a broadcast threshold to build
// the remainder mask; 16 entries match the largest Vector<float>.Count that method asserts.
private static readonly int[] s_0through15 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];

private readonly struct AddOperator : IAggregationOperator
{
public float Invoke(float x, float y) => x + y;
Expand Down