Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 103 additions & 69 deletions src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Equality.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static bool Equals(ReadOnlySpan<char> left, ReadOnlySpan<char> right)
=> left.Length == right.Length
&& Equals<ushort, ushort, PlainLoader<ushort>>(ref Unsafe.As<char, ushort>(ref MemoryMarshal.GetReference(left)), ref Unsafe.As<char, ushort>(ref MemoryMarshal.GetReference(right)), (uint)right.Length);

private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight right, nuint length)
private static unsafe bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight right, nuint length)
where TLeft : unmanaged, INumberBase<TLeft>
where TRight : unmanaged, INumberBase<TRight>
where TLoader : struct, ILoader<TLeft, TRight>
Expand All @@ -50,15 +50,35 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri

if (!Vector128.IsHardwareAccelerated || length < (uint)Vector128<TLeft>.Count)
{
for (nuint i = 0; i < length; ++i)
uint elementsPerLong = (uint)(sizeof(ulong) / sizeof(TLeft));
if (IntPtr.Size == 8 && Vector128.IsHardwareAccelerated && length >= elementsPerLong)
{
uint valueA = uint.CreateTruncating(Unsafe.Add(ref left, i));
uint valueB = uint.CreateTruncating(Unsafe.Add(ref right, i));

if (valueA != valueB || !UnicodeUtility.IsAsciiCodePoint(valueA))
// First 4 short or 8 byte elements
if (!TLoader.EqualAndAscii64(ref left, ref right))
{
return false;
}

if (length % elementsPerLong != 0)
{
// Last 4 short or 8 byte elements from the end
ref TLeft oneAwayFromLeftEnd = ref Unsafe.Add(ref left, length - elementsPerLong);
ref TRight oneAwayFromRightEnd = ref Unsafe.Add(ref right, length - elementsPerLong);
return TLoader.EqualAndAscii64(ref oneAwayFromLeftEnd, ref oneAwayFromRightEnd);
}
}
else
{
for (nuint i = 0; i < length; ++i)
{
uint valueA = uint.CreateTruncating(Unsafe.Add(ref left, i));
uint valueB = uint.CreateTruncating(Unsafe.Add(ref right, i));

if (valueA != valueB || !UnicodeUtility.IsAsciiCodePoint(valueA))
{
return false;
}
}
}
}
else if (Vector512.IsHardwareAccelerated && length >= (uint)Vector512<TLeft>.Count)
Expand Down Expand Up @@ -124,40 +144,31 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
else
{
ref TLeft currentLeftSearchSpace = ref left;
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref currentLeftSearchSpace, length - TLoader.Count128);
ref TRight currentRightSearchSpace = ref right;
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector128<TRight>.Count);

Vector128<TRight> leftValues;
Vector128<TRight> rightValues;
// Add Vector128<TLeft>.Count because TLeft == TRight
// Or we are in the Widen case where we iterate 2 * TRight.Count which is the same as TLeft.Count
Debug.Assert(Vector128<TLeft>.Count == Vector128<TRight>.Count
|| (typeof(TLoader) == typeof(WideningLoader) && Vector128<TLeft>.Count == Vector128<TRight>.Count * 2));
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector128<TLeft>.Count);

// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
do
{
// it's OK to widen the bytes, it's NOT OK to narrow the chars (we could lose some information)
leftValues = TLoader.Load128(ref currentLeftSearchSpace);
rightValues = Vector128.LoadUnsafe(ref currentRightSearchSpace);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues))
if (!TLoader.EqualAndAscii128(ref currentLeftSearchSpace, ref currentRightSearchSpace))
{
return false;
}

currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, (uint)Vector128<TRight>.Count);
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, TLoader.Count128);
currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector128<TLeft>.Count);
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, Vector128<TLeft>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd));

// If any elements remain, process the last vector in the search space.
if (length % (uint)Vector128<TRight>.Count != 0)
if (length % (uint)Vector128<TLeft>.Count != 0)
{
leftValues = TLoader.Load128(ref oneVectorAwayFromLeftEnd);
rightValues = Vector128.LoadUnsafe(ref oneVectorAwayFromRightEnd);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues))
{
return false;
}
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector128<TLeft>.Count);
return TLoader.EqualAndAscii128(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
}
}

Expand Down Expand Up @@ -458,6 +469,8 @@ private interface ILoader<TLeft, TRight>
static abstract Vector128<TRight> Load128(ref TLeft ptr);
static abstract Vector256<TRight> Load256(ref TLeft ptr);
static abstract Vector512<TRight> Load512(ref TLeft ptr);
static abstract bool EqualAndAscii64(ref TLeft ptr, ref TRight right);
static abstract bool EqualAndAscii128(ref TLeft ptr, ref TRight right);
static abstract bool EqualAndAscii256(ref TLeft left, ref TRight right);
static abstract bool EqualAndAscii512(ref TLeft left, ref TRight right);
}
Expand All @@ -471,19 +484,32 @@ private interface ILoader<TLeft, TRight>
public static Vector256<T> Load256(ref T ptr) => Vector256.LoadUnsafe(ref ptr);
public static Vector512<T> Load512(ref T ptr) => Vector512.LoadUnsafe(ref ptr);

// Compares one 64-bit chunk of `left` against one 64-bit chunk of `right` and verifies the chunk is ASCII.
// Returns true only when the two 64-bit values are identical AND AllCharsInUInt64AreAscii<T> accepts the left value.
// NOTE: no bounds check is performed here; the caller must guarantee sizeof(ulong) readable bytes behind both references.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool EqualAndAscii64(ref T left, ref T right)
{
// Reinterpret each reference as bytes so a single unaligned read covers sizeof(ulong) / sizeof(T) elements.
ulong leftValues = Unsafe.ReadUnaligned<ulong>(ref Unsafe.As<T, byte>(ref left));
ulong rightValues = Unsafe.ReadUnaligned<ulong>(ref Unsafe.As<T, byte>(ref right));

// The ASCII check only runs once the values are known to be equal, so validating the left side alone is sufficient.
return leftValues == rightValues && AllCharsInUInt64AreAscii<T>(leftValues);
}

// Compares one Vector128<T> loaded from `left` against one Vector128<T> loaded from `right` and verifies it is ASCII.
// Returns true only when both vectors are identical AND AllCharsInVectorAreAscii accepts the left vector.
// NOTE: no bounds check is performed here; the caller must guarantee a full vector is readable behind both references.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool EqualAndAscii128(ref T left, ref T right)
{
Vector128<T> leftValues = Vector128.LoadUnsafe(ref left);
Vector128<T> rightValues = Vector128.LoadUnsafe(ref right);

// The ASCII check only runs once the vectors are known to be equal, so validating the left side alone is sufficient.
return leftValues == rightValues && AllCharsInVectorAreAscii(leftValues);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
[CompExactlyDependsOn(typeof(Avx))]
public static bool EqualAndAscii256(ref T left, ref T right)
{
Vector256<T> leftValues = Vector256.LoadUnsafe(ref left);
Vector256<T> rightValues = Vector256.LoadUnsafe(ref right);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues))
{
return false;
}

return true;
return leftValues == rightValues && AllCharsInVectorAreAscii(leftValues);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -492,12 +518,7 @@ public static bool EqualAndAscii512(ref T left, ref T right)
Vector512<T> leftValues = Vector512.LoadUnsafe(ref left);
Vector512<T> rightValues = Vector512.LoadUnsafe(ref right);

if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues))
{
return false;
}

return true;
return leftValues == rightValues && AllCharsInVectorAreAscii(leftValues);
}
}

Expand All @@ -514,15 +535,10 @@ public static Vector128<ushort> Load128(ref byte ptr)
{
return AdvSimd.ZeroExtendWideningLower(Vector64.LoadUnsafe(ref ptr));
}
else if (Sse2.IsSupported)
{
Vector128<byte> vec = Vector128.CreateScalarUnsafe(Unsafe.ReadUnaligned<long>(ref ptr)).AsByte();
return Sse2.UnpackLow(vec, Vector128<byte>.Zero).AsUInt16();
}
else
{
(Vector64<ushort> lower, Vector64<ushort> upper) = Vector64.Widen(Vector64.LoadUnsafe(ref ptr));
return Vector128.Create(lower, upper);
Vector128<byte> vec = Vector128.CreateScalarUnsafe(Unsafe.ReadUnaligned<long>(ref ptr)).AsByte();
return Vector128.WidenLower(vec).AsUInt16();
}
}

Expand All @@ -539,6 +555,42 @@ public static Vector512<ushort> Load512(ref byte ptr)
return Vector512.WidenLower(Vector256.LoadUnsafe(ref ptr).ToVector512());
}

// Compares 8 bytes read from `utf8` against 8 UTF-16 code units (16 bytes) read from `utf16`.
// Returns true only when all 8 bytes pass AllBytesInUInt64AreAscii and, once widened to 16 bits each,
// match the 8 code units read from `utf16`.
// NOTE: no bounds check is performed here; the caller must guarantee 8 readable bytes behind `utf8`
// and 8 readable ushort elements behind `utf16`.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool EqualAndAscii64(ref byte utf8, ref ushort utf16)
{
// The 8 utf8 bytes are widened to 8 ushorts below, so they are compared against twice as many bytes of utf16.
ulong leftValues = Unsafe.ReadUnaligned<ulong>(ref utf8);
if (!AllBytesInUInt64AreAscii(leftValues))
{
return false;
}

// Place the 8 bytes in a vector, widen the lower 8 bytes to 8 ushorts, then view the result as two ulongs.
Vector128<byte> vecNarrow = Vector128.CreateScalarUnsafe(leftValues).AsByte();
Vector128<ulong> vecWide = Vector128.WidenLower(vecNarrow).AsUInt64();

// `right` holds the first 4 code units; `rightNext` holds the following 4 (offset of sizeof(ulong) / 2 ushort elements).
ulong right = Unsafe.ReadUnaligned<ulong>(ref Unsafe.As<ushort, byte>(ref utf16));
ulong rightNext = Unsafe.ReadUnaligned<ulong>(ref Unsafe.As<ushort, byte>(ref Unsafe.Add(ref utf16, sizeof(ulong) / 2)));

// A branchless version of "leftLower == right && leftUpper == rightNext":
// any differing bit leaves a non-zero bit in the OR of the two XORs.
return ((vecWide[0] ^ right) | (vecWide[1] ^ rightNext)) == 0;
}

// Compares one Vector128<byte> loaded from `utf8` against two consecutive Vector128<ushort> loaded from `utf16`.
// Returns true only when no utf8 byte has its 0x80 bit set and the widened bytes match both utf16 vectors.
// NOTE: no bounds check is performed here; the caller must guarantee Vector128<byte>.Count readable elements
// behind both `utf8` and `utf16`.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool EqualAndAscii128(ref byte utf8, ref ushort utf16)
{
// We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
Debug.Assert(Vector128<byte>.Count == Vector128<ushort>.Count * 2);

Vector128<byte> leftNotWidened = Vector128.LoadUnsafe(ref utf8);
(Vector128<ushort> leftLower, Vector128<ushort> leftUpper) = Vector128.Widen(leftNotWidened);
Vector128<ushort> right = Vector128.LoadUnsafe(ref utf16);
// Second utf16 vector starts Vector128<ushort>.Count elements after the first.
Vector128<ushort> rightNext = Vector128.LoadUnsafe(ref utf16, (uint)Vector128<ushort>.Count);

// Keeps only the 0x80 bit of every utf8 byte; the mask is non-zero when any byte is outside the ASCII range.
Vector128<ushort> notAsciiCharMask = (leftNotWidened & Vector128.Create((byte)0x80)).AsUInt16();
// A branchless version of "leftLower == right && leftUpper == rightNext && AllCharsInVectorAreAscii":
// the OR of the mask and both XORs is zero only when all three conditions hold.
return (notAsciiCharMask | (leftLower ^ right) | (leftUpper ^ rightNext)) == Vector128<ushort>.Zero;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
[CompExactlyDependsOn(typeof(Avx))]
public static bool EqualAndAscii256(ref byte utf8, ref ushort utf16)
Expand All @@ -547,22 +599,13 @@ public static bool EqualAndAscii256(ref byte utf8, ref ushort utf16)
Debug.Assert(Vector256<byte>.Count == Vector256<ushort>.Count * 2);

Vector256<byte> leftNotWidened = Vector256.LoadUnsafe(ref utf8);
if (!AllCharsInVectorAreAscii(leftNotWidened))
{
return false;
}

(Vector256<ushort> leftLower, Vector256<ushort> leftUpper) = Vector256.Widen(leftNotWidened);
Vector256<ushort> right = Vector256.LoadUnsafe(ref utf16);
Vector256<ushort> rightNext = Vector256.LoadUnsafe(ref utf16, (uint)Vector256<ushort>.Count);

// A branchless version of "leftLower != right || leftUpper != rightNext"
if (((leftLower ^ right) | (leftUpper ^ rightNext)) != Vector256<ushort>.Zero)
{
return false;
}

return true;
Vector256<ushort> notAsciiCharMask = (leftNotWidened & Vector256.Create((byte)0x80)).AsUInt16();
// A branchless version of "leftLower != right || leftUpper != rightNext || !AllCharsInVectorAreAscii"
return (notAsciiCharMask | (leftLower ^ right) | (leftUpper ^ rightNext)) == Vector256<ushort>.Zero;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -572,22 +615,13 @@ public static bool EqualAndAscii512(ref byte utf8, ref ushort utf16)
Debug.Assert(Vector512<byte>.Count == Vector512<ushort>.Count * 2);

Vector512<byte> leftNotWidened = Vector512.LoadUnsafe(ref utf8);
if (!AllCharsInVectorAreAscii(leftNotWidened))
{
return false;
}

(Vector512<ushort> leftLower, Vector512<ushort> leftUpper) = Vector512.Widen(leftNotWidened);
Vector512<ushort> right = Vector512.LoadUnsafe(ref utf16);
Vector512<ushort> rightNext = Vector512.LoadUnsafe(ref utf16, (uint)Vector512<ushort>.Count);

// A branchless version of "leftLower != right || leftUpper != rightNext"
if (((leftLower ^ right) | (leftUpper ^ rightNext)) != Vector512<ushort>.Zero)
{
return false;
}

return true;
Vector512<ushort> notAsciiCharMask = (leftNotWidened & Vector512.Create((byte)0x80)).AsUInt16();
// A branchless version of "leftLower != right || leftUpper != rightNext || !AllCharsInVectorAreAscii"
return (notAsciiCharMask | (leftLower ^ right) | (leftUpper ^ rightNext)) == Vector512<ushort>.Zero;
}
}
}
Expand Down
52 changes: 39 additions & 13 deletions src/libraries/System.Text.Encoding/tests/Ascii/EqualsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,35 @@ public abstract class AsciiEqualityTests<TLeft, TRight>
protected abstract bool Equals(ReadOnlySpan<TLeft> left, ReadOnlySpan<TRight> right);
protected abstract bool EqualsIgnoreCase(ReadOnlySpan<TLeft> left, ReadOnlySpan<TRight> right);

// Input lengths used to generate test data: 1, plus one-below / exactly / one-above each of
// sizeof(long) / sizeof(short) and the Vector128<short>, Vector256<short> and Vector512<short> element counts.
// Presumably chosen to land on both sides of each size threshold in the implementation under test;
// confirm against the code paths being covered.
private static readonly int[] BufferLengths =
[
1,
(sizeof(long) / sizeof(short)) - 1,
(sizeof(long) / sizeof(short)),
(sizeof(long) / sizeof(short)) + 1,
Vector128<short>.Count - 1,
Vector128<short>.Count,
Vector128<short>.Count + 1,
Vector256<short>.Count - 1,
Vector256<short>.Count,
Vector256<short>.Count + 1,
Vector512<short>.Count - 1,
Vector512<short>.Count,
Vector512<short>.Count + 1
];

public static IEnumerable<object[]> ValidAsciiInputs
{
get
{
yield return new object[] { "test" };

for (char textLength = (char)0; textLength <= 127; textLength++)
foreach (int textLength in BufferLengths)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a lot of test cases, I am not sure how long it's going to take to run all of them with debug builds.

Before your change, this test was covering strings that were 0 to 127 chars long. Now the max length is reduced to 33. I am not convinced that this is the right thing to do, please revert this particular change or convince me that it's the right thing to do.

{
yield return new object[] { new string(textLength, textLength) };
for (char chr = (char)0; chr <= 127; chr++)
{
yield return new object[] { new string(chr, textLength) };
}
}
}
}
Expand All @@ -55,15 +75,18 @@ public static IEnumerable<object[]> DifferentInputs
{
yield return new object[] { "tak", "nie" };

for (char i = (char)1; i <= 127; i++)
foreach (int textLength in BufferLengths)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, I don't see benefits of this change.

{
if (i != '?') // ASCIIEncoding maps invalid ASCII to ?
for (char chr = (char)0; chr <= 127; chr++)
{
yield return new object[] { new string(i, i), string.Create(i, i, (destination, iteration) =>
if (chr != '?')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment was valuable, please restore it

Suggested change
if (chr != '?')
if (chr != '?') // ASCIIEncoding maps invalid ASCII to ?

{
destination.Fill(iteration);
destination[iteration / 2] = (char)128;
})};
yield return new object[] { new string(chr, textLength), string.Create(textLength, chr, (destination, character) =>
{
destination.Fill(character);
destination[destination.Length / 2] = (char)128;
})};
}
}
}
}
Expand Down Expand Up @@ -91,11 +114,14 @@ public static IEnumerable<object[]> EqualIgnoringCaseConsiderations
{
yield return new object[] { "aBc", "AbC" };

for (char i = (char)0; i <= 127; i++)
foreach (int textLength in BufferLengths)
{
char left = i;
char right = char.IsAsciiLetterUpper(left) ? char.ToLower(left) : char.IsAsciiLetterLower(left) ? char.ToUpper(left) : left;
yield return new object[] { new string(left, i), new string(right, i) };
for (char chr = (char)0; chr <= 127; chr++)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the value of this change, but we should decrease the number of test cases. How about just focusing on Ascii letters? Because for other characters we would be duplicating other tests work.

{
char left = chr;
char right = char.IsAsciiLetterUpper(left) ? char.ToLower(left) : char.IsAsciiLetterLower(left) ? char.ToUpper(left) : left;
yield return new object[] { new string(left, textLength), new string(right, textLength) };
}
}
}
}
Expand All @@ -112,7 +138,7 @@ public static IEnumerable<object[]> ContainingNonAsciiCharactersBuffers
{
get
{
foreach (int length in new[] { 1, Vector128<byte>.Count - 1, Vector128<byte>.Count, Vector256<byte>.Count + 1 })
foreach (int length in BufferLengths)
{
for (int index = 0; index < length; index++)
{
Expand Down