Skip to content

Commit fdff01f

Browse files
stephentoub authored and michaelgsharp committed
Vectorize TensorPrimitives.ConvertToSingle (dotnet#92779)
* Vectorize TensorPrimitives.ConvertToSingle * Address PR feedback
1 parent 02416c2 commit fdff01f

File tree

2 files changed

+250
-4
lines changed

2 files changed

+250
-4
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs

Lines changed: 248 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,10 +353,256 @@ public static void ConvertToSingle(ReadOnlySpan<Half> source, Span<float> destin
353353
ThrowHelper.ThrowArgument_DestinationTooShort();
354354
}
355355

356-
for (int i = 0; i < source.Length; i++)
356+
ref short sourceRef = ref Unsafe.As<Half, short>(ref MemoryMarshal.GetReference(source));
357+
ref float destinationRef = ref MemoryMarshal.GetReference(destination);
358+
int i = 0, oneVectorFromEnd;
359+
360+
#if NET8_0_OR_GREATER
361+
if (Vector512.IsHardwareAccelerated)
362+
{
363+
oneVectorFromEnd = source.Length - Vector512<short>.Count;
364+
if (i <= oneVectorFromEnd)
365+
{
366+
// Loop handling one input vector / two output vectors at a time.
367+
do
368+
{
369+
(Vector512<int> lower, Vector512<int> upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
370+
HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
371+
HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512<float>.Count));
372+
373+
i += Vector512<short>.Count;
374+
}
375+
while (i <= oneVectorFromEnd);
376+
377+
// Handle any remaining elements with a final input vector.
378+
if (i != source.Length)
379+
{
380+
i = source.Length - Vector512<short>.Count;
381+
382+
(Vector512<int> lower, Vector512<int> upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
383+
HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
384+
HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512<float>.Count));
385+
}
386+
387+
return;
388+
}
389+
}
390+
#endif
391+
392+
if (Vector256.IsHardwareAccelerated)
393+
{
394+
oneVectorFromEnd = source.Length - Vector256<short>.Count;
395+
if (i <= oneVectorFromEnd)
396+
{
397+
// Loop handling one input vector / two output vectors at a time.
398+
do
399+
{
400+
(Vector256<int> lower, Vector256<int> upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
401+
HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
402+
HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256<float>.Count));
403+
404+
i += Vector256<short>.Count;
405+
}
406+
while (i <= oneVectorFromEnd);
407+
408+
// Handle any remaining elements with a final input vector.
409+
if (i != source.Length)
410+
{
411+
i = source.Length - Vector256<short>.Count;
412+
413+
(Vector256<int> lower, Vector256<int> upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
414+
HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
415+
HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256<float>.Count));
416+
}
417+
418+
return;
419+
}
420+
}
421+
422+
if (Vector128.IsHardwareAccelerated)
357423
{
358-
destination[i] = (float)source[i];
424+
oneVectorFromEnd = source.Length - Vector128<short>.Count;
425+
if (i <= oneVectorFromEnd)
426+
{
427+
// Loop handling one input vector / two output vectors at a time.
428+
do
429+
{
430+
(Vector128<int> lower, Vector128<int> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
431+
HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
432+
HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128<float>.Count));
433+
434+
i += Vector128<short>.Count;
435+
}
436+
while (i <= oneVectorFromEnd);
437+
438+
// Handle any remaining elements with a final input vector.
439+
if (i != source.Length)
440+
{
441+
i = source.Length - Vector128<short>.Count;
442+
443+
(Vector128<int> lower, Vector128<int> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
444+
HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
445+
HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128<float>.Count));
446+
}
447+
448+
return;
449+
}
359450
}
451+
452+
while (i < source.Length)
453+
{
454+
Unsafe.Add(ref destinationRef, i) = (float)Unsafe.As<short, Half>(ref Unsafe.Add(ref sourceRef, i));
455+
i++;
456+
}
457+
458+
// This implements a vectorized version of the `explicit operator float(Half value)` conversion operator.
459+
// See detailed description of the algorithm used here:
460+
// https://github.com/dotnet/runtime/blob/3bf40a378f00cb5bf18ff62796bc7097719b974c/src/libraries/System.Private.CoreLib/src/System/Half.cs#L1010-L1040
461+
// The cast operator converts a Half represented as uint to a float. This does the same, with an input VectorXx<uint> and an output VectorXx<float>.
462+
// The VectorXx<uint> is created by reading a vector of Halfs as a VectorXx<short>, widening it to two VectorXx<int>s, and casting those to VectorXx<uint>s.
463+
// We loop handling one input vector at a time, producing two output float vectors.
464+
465+
#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
466+
const uint ExponentLowerBound = 0x3880_0000u; // The smallest positive normal number in Half, converted to Single
467+
const uint ExponentOffset = 0x3800_0000u; // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13)
468+
const uint SingleSignMask = 0x8000_0000; // float.SignMask; // Mask for sign bit in Single
469+
const uint HalfExponentMask = 0x7C00; // Mask for exponent bits in Half
470+
const uint HalfToSingleBitsMask = 0x0FFF_E000; // Mask for bits in Single converted from Half
471+
#pragma warning restore IDE0059
472+
473+
// Converts each 32-bit lane of value, which holds the 16 bits of a Half widened to 32 bits, to the equivalent Single.
// The caller widens from short to int before reinterpreting as uint, so the Half sign bit is expected to be sign-extended into bit 31.
static Vector128<float> HalfAsWidenedUInt32ToSingle_Vector128(Vector128<uint> value)
474+
{
475+
// Extract the sign bit of value (bit 31 of each widened lane)
476+
Vector128<uint> sign = value & Vector128.Create(SingleSignMask);
477+
478+
// Working copy of the bits; the sign was already copied to the upper bits when the caller widened the input
479+
Vector128<uint> bitValueInProcess = value;
480+
481+
// Extract the exponent bits of value in place (BiasedExponent is not used here because it performs an unnecessary shift)
482+
Vector128<uint> offsetExponent = bitValueInProcess & Vector128.Create(HalfExponentMask);
483+
484+
// ~0u when value is zero or subnormal (all exponent bits clear), 0 otherwise
485+
Vector128<uint> subnormalMask = Vector128.Equals(offsetExponent, Vector128<uint>.Zero);
486+
487+
// ~0u when value is either Infinity or NaN (all exponent bits set), 0 otherwise
488+
Vector128<uint> infinityOrNaNMask = Vector128.Equals(offsetExponent, Vector128.Create(HalfExponentMask));
489+
490+
// 0x3880_0000u if value is zero or subnormal, 0 otherwise
491+
Vector128<uint> maskedExponentLowerBound = subnormalMask & Vector128.Create(ExponentLowerBound);
492+
493+
// 0x3880_0000u if value is zero or subnormal, 0x3800_0000u otherwise
494+
Vector128<uint> offsetMaskedExponentLowerBound = Vector128.Create(ExponentOffset) | maskedExponentLowerBound;
495+
496+
// Shift left so the exponent/fraction boundary lines up with that of IEEE 754 Binary32 (Single)
497+
bitValueInProcess = Vector128.ShiftLeft(bitValueInProcess, 13);
498+
499+
// Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN.
// The select mask is all-ones for lanes that are NOT Infinity/NaN, so those lanes keep the offset unchanged.
500+
offsetMaskedExponentLowerBound = Vector128.ConditionalSelect(Vector128.Equals(infinityOrNaNMask, Vector128<uint>.Zero),
501+
offsetMaskedExponentLowerBound,
502+
Vector128.ShiftLeft(offsetMaskedExponentLowerBound, 1));
503+
504+
// Keep only the shifted exponent and fraction bits of value, discarding the sign-extension bits above them
505+
bitValueInProcess &= Vector128.Create(HalfToSingleBitsMask);
506+
507+
// Add the offset to move the exponent from Half's range into Single's range
508+
bitValueInProcess += offsetMaskedExponentLowerBound;
509+
510+
// If value is zero or subnormal, subtract the lower bound as a float to remove the unnecessary 1 on top of the fraction bits.
// All other lanes subtract 0, since maskedExponentLowerBound is 0 for them.
511+
Vector128<uint> absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32();
512+
513+
// Merge the sign bit with the rest
514+
return (absoluteValue | sign).AsSingle();
515+
}
516+
517+
// Converts each 32-bit lane of value, which holds the 16 bits of a Half widened to 32 bits, to the equivalent Single.
// The caller widens from short to int before reinterpreting as uint, so the Half sign bit is expected to be sign-extended into bit 31.
static Vector256<float> HalfAsWidenedUInt32ToSingle_Vector256(Vector256<uint> value)
518+
{
519+
// Extract the sign bit of value (bit 31 of each widened lane)
520+
Vector256<uint> sign = value & Vector256.Create(SingleSignMask);
521+
522+
// Working copy of the bits; the sign was already copied to the upper bits when the caller widened the input
523+
Vector256<uint> bitValueInProcess = value;
524+
525+
// Extract the exponent bits of value in place (BiasedExponent is not used here because it performs an unnecessary shift)
526+
Vector256<uint> offsetExponent = bitValueInProcess & Vector256.Create(HalfExponentMask);
527+
528+
// ~0u when value is zero or subnormal (all exponent bits clear), 0 otherwise
529+
Vector256<uint> subnormalMask = Vector256.Equals(offsetExponent, Vector256<uint>.Zero);
530+
531+
// ~0u when value is either Infinity or NaN (all exponent bits set), 0 otherwise
532+
Vector256<uint> infinityOrNaNMask = Vector256.Equals(offsetExponent, Vector256.Create(HalfExponentMask));
533+
534+
// 0x3880_0000u if value is zero or subnormal, 0 otherwise
535+
Vector256<uint> maskedExponentLowerBound = subnormalMask & Vector256.Create(ExponentLowerBound);
536+
537+
// 0x3880_0000u if value is zero or subnormal, 0x3800_0000u otherwise
538+
Vector256<uint> offsetMaskedExponentLowerBound = Vector256.Create(ExponentOffset) | maskedExponentLowerBound;
539+
540+
// Shift left so the exponent/fraction boundary lines up with that of IEEE 754 Binary32 (Single)
541+
bitValueInProcess = Vector256.ShiftLeft(bitValueInProcess, 13);
542+
543+
// Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN.
// The select mask is all-ones for lanes that are NOT Infinity/NaN, so those lanes keep the offset unchanged.
544+
offsetMaskedExponentLowerBound = Vector256.ConditionalSelect(Vector256.Equals(infinityOrNaNMask, Vector256<uint>.Zero),
545+
offsetMaskedExponentLowerBound,
546+
Vector256.ShiftLeft(offsetMaskedExponentLowerBound, 1));
547+
548+
// Keep only the shifted exponent and fraction bits of value, discarding the sign-extension bits above them
549+
bitValueInProcess &= Vector256.Create(HalfToSingleBitsMask);
550+
551+
// Add the offset to move the exponent from Half's range into Single's range
552+
bitValueInProcess += offsetMaskedExponentLowerBound;
553+
554+
// If value is zero or subnormal, subtract the lower bound as a float to remove the unnecessary 1 on top of the fraction bits.
// All other lanes subtract 0, since maskedExponentLowerBound is 0 for them.
555+
Vector256<uint> absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32();
556+
557+
// Merge the sign bit with the rest
558+
return (absoluteValue | sign).AsSingle();
559+
}
560+
561+
#if NET8_0_OR_GREATER
562+
// Converts each 32-bit lane of value, which holds the 16 bits of a Half widened to 32 bits, to the equivalent Single.
// The caller widens from short to int before reinterpreting as uint, so the Half sign bit is expected to be sign-extended into bit 31.
static Vector512<float> HalfAsWidenedUInt32ToSingle_Vector512(Vector512<uint> value)
563+
{
564+
// Extract the sign bit of value (bit 31 of each widened lane)
565+
Vector512<uint> sign = value & Vector512.Create(SingleSignMask);
566+
567+
// Working copy of the bits; the sign was already copied to the upper bits when the caller widened the input
568+
Vector512<uint> bitValueInProcess = value;
569+
570+
// Extract the exponent bits of value in place (BiasedExponent is not used here because it performs an unnecessary shift)
571+
Vector512<uint> offsetExponent = bitValueInProcess & Vector512.Create(HalfExponentMask);
572+
573+
// ~0u when value is zero or subnormal (all exponent bits clear), 0 otherwise
574+
Vector512<uint> subnormalMask = Vector512.Equals(offsetExponent, Vector512<uint>.Zero);
575+
576+
// ~0u when value is either Infinity or NaN (all exponent bits set), 0 otherwise
577+
Vector512<uint> infinityOrNaNMask = Vector512.Equals(offsetExponent, Vector512.Create(HalfExponentMask));
578+
579+
// 0x3880_0000u if value is zero or subnormal, 0 otherwise
580+
Vector512<uint> maskedExponentLowerBound = subnormalMask & Vector512.Create(ExponentLowerBound);
581+
582+
// 0x3880_0000u if value is zero or subnormal, 0x3800_0000u otherwise
583+
Vector512<uint> offsetMaskedExponentLowerBound = Vector512.Create(ExponentOffset) | maskedExponentLowerBound;
584+
585+
// Shift left so the exponent/fraction boundary lines up with that of IEEE 754 Binary32 (Single)
586+
bitValueInProcess = Vector512.ShiftLeft(bitValueInProcess, 13);
587+
588+
// Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN.
// The select mask is all-ones for lanes that are NOT Infinity/NaN, so those lanes keep the offset unchanged.
589+
offsetMaskedExponentLowerBound = Vector512.ConditionalSelect(Vector512.Equals(infinityOrNaNMask, Vector512<uint>.Zero),
590+
offsetMaskedExponentLowerBound,
591+
Vector512.ShiftLeft(offsetMaskedExponentLowerBound, 1));
592+
593+
// Keep only the shifted exponent and fraction bits of value, discarding the sign-extension bits above them
594+
bitValueInProcess &= Vector512.Create(HalfToSingleBitsMask);
595+
596+
// Add the offset to move the exponent from Half's range into Single's range
597+
bitValueInProcess += offsetMaskedExponentLowerBound;
598+
599+
// If value is zero or subnormal, subtract the lower bound as a float to remove the unnecessary 1 on top of the fraction bits.
// All other lanes subtract 0, since maskedExponentLowerBound is 0 for them.
600+
Vector512<uint> absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32();
601+
602+
// Merge the sign bit with the rest
603+
return (absoluteValue | sign).AsSingle();
604+
}
605+
#endif
360606
}
361607

362608
private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<float> y)

src/libraries/System.Private.CoreLib/src/System/Half.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,15 +1044,15 @@ public static explicit operator float(Half value)
10441044
// BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13)
10451045
const uint ExponentOffset = 0x3800_0000u;
10461046
// Mask for sign bit in Single
1047-
const uint FloatSignMask = float.SignMask;
1047+
const uint SingleSignMask = float.SignMask;
10481048
// Mask for exponent bits in Half
10491049
const uint HalfExponentMask = BiasedExponentMask;
10501050
// Mask for bits in Single converted from Half
10511051
const int HalfToSingleBitsMask = 0x0FFF_E000;
10521052
// Extract the internal representation of value
10531053
short valueInInt16Bits = BitConverter.HalfToInt16Bits(value);
10541054
// Extract sign bit of value
1055-
uint sign = (uint)(int)valueInInt16Bits & FloatSignMask;
1055+
uint sign = (uint)(int)valueInInt16Bits & SingleSignMask;
10561056
// Copy sign bit to upper bits
10571057
uint bitValueInProcess = (uint)valueInInt16Bits;
10581058
// Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift)

0 commit comments

Comments
 (0)