Skip to content

Commit 4088f05

Browse files
stephentoubmichaelgsharp
authored andcommitted
Vectorize TensorPrimitives.ConvertToHalf (dotnet#92715)
1 parent 50e3948 commit 4088f05

File tree

2 files changed

+344
-4
lines changed

2 files changed

+344
-4
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs

Lines changed: 293 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,301 @@ public static void ConvertToHalf(ReadOnlySpan<float> source, Span<Half> destinat
3131
ThrowHelper.ThrowArgument_DestinationTooShort();
3232
}
3333

34-
for (int i = 0; i < source.Length; i++)
34+
ref float sourceRef = ref MemoryMarshal.GetReference(source);
35+
ref ushort destinationRef = ref Unsafe.As<Half, ushort>(ref MemoryMarshal.GetReference(destination));
36+
int i = 0, twoVectorsFromEnd;
37+
38+
#if NET8_0_OR_GREATER
39+
if (Vector512.IsHardwareAccelerated)
3540
{
36-
destination[i] = (Half)source[i];
41+
twoVectorsFromEnd = source.Length - (Vector512<float>.Count * 2);
42+
if (i <= twoVectorsFromEnd)
43+
{
44+
// Loop handling two input vectors / one output vector at a time.
45+
do
46+
{
47+
Vector512<uint> lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
48+
Vector512<uint> upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512<float>.Count)));
49+
Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
50+
51+
i += Vector512<float>.Count * 2;
52+
}
53+
while (i <= twoVectorsFromEnd);
54+
55+
// Handle any remaining elements with final vectors.
56+
if (i != source.Length)
57+
{
58+
i = source.Length - (Vector512<float>.Count * 2);
59+
60+
Vector512<uint> lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
61+
Vector512<uint> upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512<float>.Count)));
62+
Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
63+
}
64+
65+
return;
66+
}
3767
}
68+
#endif
69+
70+
if (Vector256.IsHardwareAccelerated)
71+
{
72+
twoVectorsFromEnd = source.Length - (Vector256<float>.Count * 2);
73+
if (i <= twoVectorsFromEnd)
74+
{
75+
// Loop handling two input vectors / one output vector at a time.
76+
do
77+
{
78+
Vector256<uint> lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
79+
Vector256<uint> upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256<float>.Count)));
80+
Vector256<ushort> halfs = Vector256.Narrow(lower, upper);
81+
halfs.StoreUnsafe(ref destinationRef, (uint)i);
82+
83+
i += Vector256<float>.Count * 2;
84+
}
85+
while (i <= twoVectorsFromEnd);
86+
87+
// Handle any remaining elements with final vectors.
88+
if (i != source.Length)
89+
{
90+
i = source.Length - (Vector256<float>.Count * 2);
91+
92+
Vector256<uint> lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
93+
Vector256<uint> upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256<float>.Count)));
94+
Vector256.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
95+
}
96+
97+
return;
98+
}
99+
}
100+
101+
if (Vector128.IsHardwareAccelerated)
102+
{
103+
twoVectorsFromEnd = source.Length - (Vector128<float>.Count * 2);
104+
if (i <= twoVectorsFromEnd)
105+
{
106+
// Loop handling two input vectors / one output vector at a time.
107+
do
108+
{
109+
Vector128<uint> lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
110+
Vector128<uint> upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128<float>.Count)));
111+
Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
112+
113+
i += Vector128<float>.Count * 2;
114+
}
115+
while (i <= twoVectorsFromEnd);
116+
117+
// Handle any remaining elements with final vectors.
118+
if (i != source.Length)
119+
{
120+
i = source.Length - (Vector128<float>.Count * 2);
121+
122+
Vector128<uint> lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
123+
Vector128<uint> upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128<float>.Count)));
124+
Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i);
125+
}
126+
127+
return;
128+
}
129+
}
130+
131+
while (i < source.Length)
132+
{
133+
Unsafe.Add(ref destinationRef, i) = BitConverter.HalfToUInt16Bits((Half)Unsafe.Add(ref sourceRef, i));
134+
i++;
135+
}
136+
137+
// This implements a vectorized version of the `explicit operator Half(float value)` conversion.
138+
// See detailed description of the algorithm used here:
139+
// https://github.com/dotnet/runtime/blob/ca8d6f0420096831766ec11c7d400e4f7ccc7a34/src/libraries/System.Private.CoreLib/src/System/Half.cs#L606-L714
140+
// The cast operator converts a float to a Half represented as a UInt32, then narrows to a UInt16, and reinterpret casts to Half.
141+
// This does the same, with an input VectorXx<float> and an output VectorXx<uint>.
142+
// The calling loops handle two input vectors at a time: each input float is double the size of each output Half,
143+
// so we need two vectors of floats to produce one vector of Halfs. Half isn't supported in VectorXx<T>,
144+
// so we convert the VectorXx<float> to a VectorXx<uint>, and the caller then uses this twice, narrows the combination
145+
// into a VectorXx<ushort>, and then saves that out to the destination `ref Half` reinterpreted as `ref ushort`.
146+
147+
#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
148+
const uint MinExp = 0x3880_0000u; // Minimum exponent for rounding
149+
const uint Exponent126 = 0x3f00_0000u; // Exponent displacement #1
150+
const uint SingleBiasedExponentMask = 0x7F80_0000; // float.BiasedExponentMask; // Exponent mask
151+
const uint Exponent13 = 0x0680_0000u; // Exponent displacement #2
152+
const float MaxHalfValueBelowInfinity = 65520.0f; // Maximum value that is not Infinity in Half
153+
const uint ExponentMask = 0x7C00; // Mask for exponent bits in Half
154+
const uint SingleSignMask = 0x8000_0000u; // float.SignMask; // Mask for sign bit in float
155+
#pragma warning restore IDE0059
156+
157+
// Converts every float lane of `value` to the bit pattern of the equivalent Half, widened into a uint lane.
// The steps follow the scalar float-to-Half cast; the statement order below is significant.
static Vector128<uint> SingleToHalfAsWidenedUInt32_Vector128(Vector128<float> value)
158+
{
159+
Vector128<uint> bitValue = value.AsUInt32();
160+
161+
// Extract the sign bit and shift it from bit 31 down to bit 15
162+
Vector128<uint> sign = Vector128.ShiftRightLogical(bitValue & Vector128.Create(SingleSignMask), 16);
163+
164+
// Detect NaN by comparing each lane with itself (0u if value is NaN; otherwise, ~0u)
165+
Vector128<uint> realMask = Vector128.Equals(value, value).AsUInt32();
166+
167+
// Clear the sign bit; the sign was saved above and is merged back in at the end
168+
value = Vector128.Abs(value);
169+
170+
// Clamp to MaxHalfValueBelowInfinity to rectify values that are Infinity in Half.
171+
value = Vector128.Min(Vector128.Create(MaxHalfValueBelowInfinity), value);
172+
173+
// Rectify the lower exponent: take the larger of the magnitude and the float whose bits are MinExp
174+
Vector128<uint> exponentOffset0 = Vector128.Max(value, Vector128.Create(MinExp).AsSingle()).AsUInt32();
175+
176+
// Keep only the biased exponent bits
177+
exponentOffset0 &= Vector128.Create(SingleBiasedExponentMask);
178+
179+
// Increase the exponent by 13 (Exponent13 is 13 placed in the float exponent field)
180+
exponentOffset0 += Vector128.Create(Exponent13);
181+
182+
// Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction)
183+
value += exponentOffset0.AsSingle();
184+
bitValue = value.AsUInt32();
185+
186+
// ExponentMask in NaN lanes, zero elsewhere; only exponent bits will be modified if NaN
187+
Vector128<uint> maskedHalfExponentForNaN = ~realMask & Vector128.Create(ExponentMask);
188+
189+
// Decrease the exponent by 126 (Exponent126 is 126 placed in the float exponent field)
190+
bitValue -= Vector128.Create(Exponent126);
191+
192+
// Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part.
193+
Vector128<uint> newExponent = Vector128.ShiftRightLogical(bitValue, 13);
194+
195+
// Zero the lane if the value was NaN (realMask is 0u there), clearing the fraction parts.
196+
bitValue &= realMask;
197+
198+
// Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow.
199+
bitValue += newExponent;
200+
201+
// Clear the Half exponent bits if value is NaN
202+
bitValue &= ~maskedHalfExponentForNaN;
203+
204+
// Combine the sign bit with the possible NaN exponent
205+
Vector128<uint> signAndMaskedExponent = maskedHalfExponentForNaN | sign;
206+
207+
// Merge the sign bit and possible NaN exponent into the result
208+
bitValue |= signAndMaskedExponent;
209+
210+
// The final result, still one uint per lane
211+
return bitValue;
212+
}
213+
214+
// Converts every float lane of `value` to the bit pattern of the equivalent Half, widened into a uint lane.
// The steps follow the scalar float-to-Half cast; the statement order below is significant.
static Vector256<uint> SingleToHalfAsWidenedUInt32_Vector256(Vector256<float> value)
215+
{
216+
Vector256<uint> bitValue = value.AsUInt32();
217+
218+
// Extract the sign bit and shift it from bit 31 down to bit 15
219+
Vector256<uint> sign = Vector256.ShiftRightLogical(bitValue & Vector256.Create(SingleSignMask), 16);
220+
221+
// Detect NaN by comparing each lane with itself (0u if value is NaN; otherwise, ~0u)
222+
Vector256<uint> realMask = Vector256.Equals(value, value).AsUInt32();
223+
224+
// Clear the sign bit; the sign was saved above and is merged back in at the end
225+
value = Vector256.Abs(value);
226+
227+
// Clamp to MaxHalfValueBelowInfinity to rectify values that are Infinity in Half.
228+
value = Vector256.Min(Vector256.Create(MaxHalfValueBelowInfinity), value);
229+
230+
// Rectify the lower exponent: take the larger of the magnitude and the float whose bits are MinExp
231+
Vector256<uint> exponentOffset0 = Vector256.Max(value, Vector256.Create(MinExp).AsSingle()).AsUInt32();
232+
233+
// Keep only the biased exponent bits
234+
exponentOffset0 &= Vector256.Create(SingleBiasedExponentMask);
235+
236+
// Increase the exponent by 13 (Exponent13 is 13 placed in the float exponent field)
237+
exponentOffset0 += Vector256.Create(Exponent13);
238+
239+
// Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction)
240+
value += exponentOffset0.AsSingle();
241+
bitValue = value.AsUInt32();
242+
243+
// ExponentMask in NaN lanes, zero elsewhere; only exponent bits will be modified if NaN
244+
Vector256<uint> maskedHalfExponentForNaN = ~realMask & Vector256.Create(ExponentMask);
245+
246+
// Decrease the exponent by 126 (Exponent126 is 126 placed in the float exponent field)
247+
bitValue -= Vector256.Create(Exponent126);
248+
249+
// Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part.
250+
Vector256<uint> newExponent = Vector256.ShiftRightLogical(bitValue, 13);
251+
252+
// Zero the lane if the value was NaN (realMask is 0u there), clearing the fraction parts.
253+
bitValue &= realMask;
254+
255+
// Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow.
256+
bitValue += newExponent;
257+
258+
// Clear the Half exponent bits if value is NaN
259+
bitValue &= ~maskedHalfExponentForNaN;
260+
261+
// Combine the sign bit with the possible NaN exponent
262+
Vector256<uint> signAndMaskedExponent = maskedHalfExponentForNaN | sign;
263+
264+
// Merge the sign bit and possible NaN exponent into the result
265+
bitValue |= signAndMaskedExponent;
266+
267+
// The final result, still one uint per lane
268+
return bitValue;
269+
}
270+
271+
#if NET8_0_OR_GREATER
272+
// Converts every float lane of `value` to the bit pattern of the equivalent Half, widened into a uint lane.
// The steps follow the scalar float-to-Half cast; the statement order below is significant.
static Vector512<uint> SingleToHalfAsWidenedUInt32_Vector512(Vector512<float> value)
273+
{
274+
Vector512<uint> bitValue = value.AsUInt32();
275+
276+
// Extract the sign bit and shift it from bit 31 down to bit 15
277+
Vector512<uint> sign = Vector512.ShiftRightLogical(bitValue & Vector512.Create(SingleSignMask), 16);
278+
279+
// Detect NaN by comparing each lane with itself (0u if value is NaN; otherwise, ~0u)
280+
Vector512<uint> realMask = Vector512.Equals(value, value).AsUInt32();
281+
282+
// Clear the sign bit; the sign was saved above and is merged back in at the end
283+
value = Vector512.Abs(value);
284+
285+
// Clamp to MaxHalfValueBelowInfinity to rectify values that are Infinity in Half.
286+
value = Vector512.Min(Vector512.Create(MaxHalfValueBelowInfinity), value);
287+
288+
// Rectify the lower exponent: take the larger of the magnitude and the float whose bits are MinExp
289+
Vector512<uint> exponentOffset0 = Vector512.Max(value, Vector512.Create(MinExp).AsSingle()).AsUInt32();
290+
291+
// Keep only the biased exponent bits
292+
exponentOffset0 &= Vector512.Create(SingleBiasedExponentMask);
293+
294+
// Increase the exponent by 13 (Exponent13 is 13 placed in the float exponent field)
295+
exponentOffset0 += Vector512.Create(Exponent13);
296+
297+
// Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction)
298+
value += exponentOffset0.AsSingle();
299+
bitValue = value.AsUInt32();
300+
301+
// ExponentMask in NaN lanes, zero elsewhere; only exponent bits will be modified if NaN
302+
Vector512<uint> maskedHalfExponentForNaN = ~realMask & Vector512.Create(ExponentMask);
303+
304+
// Decrease the exponent by 126 (Exponent126 is 126 placed in the float exponent field)
305+
bitValue -= Vector512.Create(Exponent126);
306+
307+
// Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part.
308+
Vector512<uint> newExponent = Vector512.ShiftRightLogical(bitValue, 13);
309+
310+
// Zero the lane if the value was NaN (realMask is 0u there), clearing the fraction parts.
311+
bitValue &= realMask;
312+
313+
// Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow.
314+
bitValue += newExponent;
315+
316+
// Clear the Half exponent bits if value is NaN
317+
bitValue &= ~maskedHalfExponentForNaN;
318+
319+
// Combine the sign bit with the possible NaN exponent
320+
Vector512<uint> signAndMaskedExponent = maskedHalfExponentForNaN | sign;
321+
322+
// Merge the sign bit and possible NaN exponent into the result
323+
bitValue |= signAndMaskedExponent;
324+
325+
// The final result, still one uint per lane
326+
return bitValue;
327+
}
328+
#endif
38329
}
39330

40331
/// <summary>

src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ public static void ConvertToHalf(int tensorLength)
1616
using BoundedMemory<float> source = CreateAndFillTensor(tensorLength);
1717
foreach (int destLength in new[] { source.Length, source.Length + 1 })
1818
{
19-
Half[] destination = new Half[destLength];
19+
using BoundedMemory<Half> destination = BoundedMemory.Allocate<Half>(destLength);
20+
destination.Span.Fill(Half.Zero);
2021

2122
TensorPrimitives.ConvertToHalf(source, destination);
2223

@@ -35,6 +36,28 @@ public static void ConvertToHalf(int tensorLength)
3536
}
3637
}
3738

39+
// Checks that TensorPrimitives.ConvertToHalf matches the scalar (Half) cast element by element
// when the source contains NaN, both infinities, and both zeros.
[Theory]
40+
[MemberData(nameof(TensorLengths))]
41+
public static void ConvertToHalf_SpecialValues(int tensorLength)
42+
{
43+
using BoundedMemory<float> source = CreateAndFillTensor(tensorLength);
44+
using BoundedMemory<Half> destination = BoundedMemory.Allocate<Half>(tensorLength);
45+
46+
// NaN, infinities, and 0s, each written at an independently drawn index
47+
source[s_random.Next(source.Length)] = float.NaN;
48+
source[s_random.Next(source.Length)] = float.PositiveInfinity;
49+
source[s_random.Next(source.Length)] = float.NegativeInfinity;
50+
source[s_random.Next(source.Length)] = 0;
51+
source[s_random.Next(source.Length)] = float.NegativeZero;
52+
53+
TensorPrimitives.ConvertToHalf(source, destination);
54+
55+
// Every element must equal the result of the scalar cast
56+
for (int i = 0; i < source.Length; i++)
57+
{
58+
Assert.Equal((Half)source[i], destination[i]);
59+
}
60+
}
60+
3861
[Theory]
3962
[MemberData(nameof(TensorLengths))]
4063
public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength)
@@ -51,7 +74,7 @@ public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength)
5174
[MemberData(nameof(TensorLengthsIncluding0))]
5275
public static void ConvertToSingle(int tensorLength)
5376
{
54-
Half[] source = new Half[tensorLength];
77+
using BoundedMemory<Half> source = BoundedMemory.Allocate<Half>(tensorLength);
5578
for (int i = 0; i < source.Length; i++)
5679
{
5780
source[i] = (Half)s_random.NextSingle();
@@ -78,6 +101,32 @@ public static void ConvertToSingle(int tensorLength)
78101
}
79102
}
80103
}
104+
// Checks that TensorPrimitives.ConvertToSingle matches the scalar (float) cast element by element
// when the source contains NaN, both infinities, and both zeros.
[Theory]
105+
[MemberData(nameof(TensorLengths))]
106+
public static void ConvertToSingle_SpecialValues(int tensorLength)
107+
{
108+
using BoundedMemory<Half> source = BoundedMemory.Allocate<Half>(tensorLength);
109+
// Fill the source with s_random.NextSingle() values cast to Half
110+
for (int i = 0; i < source.Length; i++)
111+
{
112+
source[i] = (Half)s_random.NextSingle();
113+
}
114+
115+
using BoundedMemory<float> destination = CreateTensor(tensorLength);
116+
117+
// NaN, infinities, and 0s, each written at an independently drawn index
118+
source[s_random.Next(source.Length)] = Half.NaN;
119+
source[s_random.Next(source.Length)] = Half.PositiveInfinity;
120+
source[s_random.Next(source.Length)] = Half.NegativeInfinity;
121+
source[s_random.Next(source.Length)] = Half.Zero;
122+
source[s_random.Next(source.Length)] = Half.NegativeZero;
123+
124+
TensorPrimitives.ConvertToSingle(source, destination);
125+
126+
// Every element must equal the result of the scalar cast
127+
for (int i = 0; i < source.Length; i++)
128+
{
129+
Assert.Equal((float)source[i], destination[i]);
130+
}
131+
}
81130

82131
[Theory]
83132
[MemberData(nameof(TensorLengths))]

0 commit comments

Comments
 (0)