Skip to content

Commit bd57689

Browse files
stephentoubmichaelgsharp
authored andcommitted
Use FMA in TensorPrimitives (dotnet#92205)
1 parent 86c9493 commit bd57689

File tree

2 files changed

+79
-36
lines changed

2 files changed

+79
-36
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Diagnostics;
45
using System.Runtime.CompilerServices;
56
using System.Runtime.InteropServices;
67
using System.Runtime.Intrinsics;
8+
using System.Runtime.Intrinsics.Arm;
9+
using System.Runtime.Intrinsics.X86;
710

811
namespace System.Numerics.Tensors
912
{
@@ -86,9 +89,9 @@ private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<fl
8689
Vector512<float> xVec = Vector512.LoadUnsafe(ref xRef, (uint)i);
8790
Vector512<float> yVec = Vector512.LoadUnsafe(ref yRef, (uint)i);
8891

89-
dotProductVector += xVec * yVec;
90-
xSumOfSquaresVector += xVec * xVec;
91-
ySumOfSquaresVector += yVec * yVec;
92+
dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector);
93+
xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector);
94+
ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector);
9295

9396
i += Vector512<float>.Count;
9497
}
@@ -117,9 +120,9 @@ private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<fl
117120
Vector256<float> xVec = Vector256.LoadUnsafe(ref xRef, (uint)i);
118121
Vector256<float> yVec = Vector256.LoadUnsafe(ref yRef, (uint)i);
119122

120-
dotProductVector += xVec * yVec;
121-
xSumOfSquaresVector += xVec * xVec;
122-
ySumOfSquaresVector += yVec * yVec;
123+
dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector);
124+
xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector);
125+
ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector);
123126

124127
i += Vector256<float>.Count;
125128
}
@@ -146,9 +149,9 @@ private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<fl
146149
Vector128<float> xVec = Vector128.LoadUnsafe(ref xRef, (uint)i);
147150
Vector128<float> yVec = Vector128.LoadUnsafe(ref yRef, (uint)i);
148151

149-
dotProductVector += xVec * yVec;
150-
xSumOfSquaresVector += xVec * xVec;
151-
ySumOfSquaresVector += yVec * yVec;
152+
dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector);
153+
xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector);
154+
ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector);
152155

153156
i += Vector128<float>.Count;
154157
}
@@ -163,9 +166,9 @@ private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<fl
163166
// Process any remaining elements past the last vector.
164167
for (; (uint)i < (uint)x.Length; i++)
165168
{
166-
dotProduct += x[i] * y[i];
167-
xSumOfSquares += x[i] * x[i];
168-
ySumOfSquares += y[i] * y[i];
169+
dotProduct = MathF.FusedMultiplyAdd(x[i], y[i], dotProduct);
170+
xSumOfSquares = MathF.FusedMultiplyAdd(x[i], x[i], xSumOfSquares);
171+
ySumOfSquares = MathF.FusedMultiplyAdd(y[i], y[i], ySumOfSquares);
169172
}
170173

171174
// Sum(X * Y) / (|X| * |Y|)
@@ -1032,6 +1035,46 @@ private static unsafe void InvokeSpanScalarSpanIntoSpan<TTernaryOperator>(
10321035
}
10331036
}
10341037

1038+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
1039+
private static Vector128<float> FusedMultiplyAdd(Vector128<float> x, Vector128<float> y, Vector128<float> addend)
1040+
{
1041+
if (Fma.IsSupported)
1042+
{
1043+
return Fma.MultiplyAdd(x, y, addend);
1044+
}
1045+
1046+
if (AdvSimd.IsSupported)
1047+
{
1048+
return AdvSimd.FusedMultiplyAdd(addend, x, y);
1049+
}
1050+
1051+
return (x * y) + addend;
1052+
}
1053+
1054+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
1055+
private static Vector256<float> FusedMultiplyAdd(Vector256<float> x, Vector256<float> y, Vector256<float> addend)
1056+
{
1057+
if (Fma.IsSupported)
1058+
{
1059+
return Fma.MultiplyAdd(x, y, addend);
1060+
}
1061+
1062+
return (x * y) + addend;
1063+
}
1064+
1065+
#if NET8_0_OR_GREATER
1066+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
1067+
private static Vector512<float> FusedMultiplyAdd(Vector512<float> x, Vector512<float> y, Vector512<float> addend)
1068+
{
1069+
if (Avx512F.IsSupported)
1070+
{
1071+
return Avx512F.FusedMultiplyAdd(x, y, addend);
1072+
}
1073+
1074+
return (x * y) + addend;
1075+
}
1076+
#endif
1077+
10351078
private readonly struct AddOperator : IBinaryOperator
10361079
{
10371080
public static float Invoke(float x, float y) => x + y;
@@ -1182,11 +1225,11 @@ public static float Invoke(Vector512<float> x)
11821225

11831226
private readonly struct MultiplyAddOperator : ITernaryOperator
11841227
{
1185-
public static float Invoke(float x, float y, float z) => (x * y) + z;
1186-
public static Vector128<float> Invoke(Vector128<float> x, Vector128<float> y, Vector128<float> z) => (x * y) + z;
1187-
public static Vector256<float> Invoke(Vector256<float> x, Vector256<float> y, Vector256<float> z) => (x * y) + z;
1228+
public static float Invoke(float x, float y, float z) => MathF.FusedMultiplyAdd(x, y, z);
1229+
public static Vector128<float> Invoke(Vector128<float> x, Vector128<float> y, Vector128<float> z) => FusedMultiplyAdd(x, y, z);
1230+
public static Vector256<float> Invoke(Vector256<float> x, Vector256<float> y, Vector256<float> z) => FusedMultiplyAdd(x, y, z);
11881231
#if NET8_0_OR_GREATER
1189-
public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y, Vector512<float> z) => (x * y) + z;
1232+
public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y, Vector512<float> z) => FusedMultiplyAdd(x, y, z);
11901233
#endif
11911234
}
11921235

src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public static void AddTwoTensors(int tensorLength)
5959

6060
for (int i = 0; i < tensorLength; i++)
6161
{
62-
Assert.Equal(x[i] + y[i], destination[i]);
62+
Assert.Equal(x[i] + y[i], destination[i], Tolerance);
6363
}
6464
}
6565

@@ -97,7 +97,7 @@ public static void AddTensorAndScalar(int tensorLength)
9797

9898
for (int i = 0; i < tensorLength; i++)
9999
{
100-
Assert.Equal(x[i] + y, destination[i]);
100+
Assert.Equal(x[i] + y, destination[i], Tolerance);
101101
}
102102
}
103103

@@ -124,7 +124,7 @@ public static void SubtractTwoTensors(int tensorLength)
124124

125125
for (int i = 0; i < tensorLength; i++)
126126
{
127-
Assert.Equal(x[i] - y[i], destination[i]);
127+
Assert.Equal(x[i] - y[i], destination[i], Tolerance);
128128
}
129129
}
130130

@@ -162,7 +162,7 @@ public static void SubtractTensorAndScalar(int tensorLength)
162162

163163
for (int i = 0; i < tensorLength; i++)
164164
{
165-
Assert.Equal(x[i] - y, destination[i]);
165+
Assert.Equal(x[i] - y, destination[i], Tolerance);
166166
}
167167
}
168168

@@ -189,7 +189,7 @@ public static void MultiplyTwoTensors(int tensorLength)
189189

190190
for (int i = 0; i < tensorLength; i++)
191191
{
192-
Assert.Equal(x[i] * y[i], destination[i]);
192+
Assert.Equal(x[i] * y[i], destination[i], Tolerance);
193193
}
194194
}
195195

@@ -227,7 +227,7 @@ public static void MultiplyTensorAndScalar(int tensorLength)
227227

228228
for (int i = 0; i < tensorLength; i++)
229229
{
230-
Assert.Equal(x[i] * y, destination[i]);
230+
Assert.Equal(x[i] * y, destination[i], Tolerance);
231231
}
232232
}
233233

@@ -254,7 +254,7 @@ public static void DivideTwoTensors(int tensorLength)
254254

255255
for (int i = 0; i < tensorLength; i++)
256256
{
257-
Assert.Equal(x[i] / y[i], destination[i]);
257+
Assert.Equal(x[i] / y[i], destination[i], Tolerance);
258258
}
259259
}
260260

@@ -292,7 +292,7 @@ public static void DivideTensorAndScalar(int tensorLength)
292292

293293
for (int i = 0; i < tensorLength; i++)
294294
{
295-
Assert.Equal(x[i] / y, destination[i]);
295+
Assert.Equal(x[i] / y, destination[i], Tolerance);
296296
}
297297
}
298298

@@ -318,7 +318,7 @@ public static void NegateTensor(int tensorLength)
318318

319319
for (int i = 0; i < tensorLength; i++)
320320
{
321-
Assert.Equal(-x[i], destination[i]);
321+
Assert.Equal(-x[i], destination[i], Tolerance);
322322
}
323323
}
324324

@@ -345,7 +345,7 @@ public static void AddTwoTensorsAndMultiplyWithThirdTensor(int tensorLength)
345345

346346
for (int i = 0; i < tensorLength; i++)
347347
{
348-
Assert.Equal((x[i] + y[i]) * multiplier[i], destination[i]);
348+
Assert.Equal((x[i] + y[i]) * multiplier[i], destination[i], Tolerance);
349349
}
350350
}
351351

@@ -398,7 +398,7 @@ public static void AddTwoTensorsAndMultiplyWithScalar(int tensorLength)
398398

399399
for (int i = 0; i < tensorLength; i++)
400400
{
401-
Assert.Equal((x[i] + y[i]) * multiplier, destination[i]);
401+
Assert.Equal((x[i] + y[i]) * multiplier, destination[i], Tolerance);
402402
}
403403
}
404404

@@ -439,7 +439,7 @@ public static void AddTensorAndScalarAndMultiplyWithTensor(int tensorLength)
439439

440440
for (int i = 0; i < tensorLength; i++)
441441
{
442-
Assert.Equal((x[i] + y) * multiplier[i], destination[i]);
442+
Assert.Equal((x[i] + y) * multiplier[i], destination[i], Tolerance);
443443
}
444444
}
445445

@@ -480,7 +480,7 @@ public static void MultiplyTwoTensorsAndAddWithThirdTensor(int tensorLength)
480480

481481
for (int i = 0; i < tensorLength; i++)
482482
{
483-
Assert.Equal((x[i] * y[i]) + addend[i], destination[i]);
483+
Assert.Equal((x[i] * y[i]) + addend[i], destination[i], Tolerance);
484484
}
485485
}
486486

@@ -533,7 +533,7 @@ public static void MultiplyTwoTensorsAndAddWithScalar(int tensorLength)
533533

534534
for (int i = 0; i < tensorLength; i++)
535535
{
536-
Assert.Equal((x[i] * y[i]) + addend, destination[i]);
536+
Assert.Equal((x[i] * y[i]) + addend, destination[i], Tolerance);
537537
}
538538
}
539539

@@ -562,7 +562,7 @@ public static void MultiplyTensorAndScalarAndAddWithTensor(int tensorLength)
562562

563563
for (int i = 0; i < tensorLength; i++)
564564
{
565-
Assert.Equal((x[i] * y) + addend[i], destination[i]);
565+
Assert.Equal((x[i] * y) + addend[i], destination[i], Tolerance);
566566
}
567567
}
568568

@@ -589,7 +589,7 @@ public static void ExpTensor(int tensorLength)
589589

590590
for (int i = 0; i < tensorLength; i++)
591591
{
592-
Assert.Equal(MathF.Exp(x[i]), destination[i]);
592+
Assert.Equal(MathF.Exp(x[i]), destination[i], Tolerance);
593593
}
594594
}
595595

@@ -614,7 +614,7 @@ public static void LogTensor(int tensorLength)
614614

615615
for (int i = 0; i < tensorLength; i++)
616616
{
617-
Assert.Equal(MathF.Log(x[i]), destination[i]);
617+
Assert.Equal(MathF.Log(x[i]), destination[i], Tolerance);
618618
}
619619
}
620620

@@ -664,7 +664,7 @@ public static void CoshTensor(int tensorLength)
664664

665665
for (int i = 0; i < tensorLength; i++)
666666
{
667-
Assert.Equal(MathF.Cosh(x[i]), destination[i]);
667+
Assert.Equal(MathF.Cosh(x[i]), destination[i], Tolerance);
668668
}
669669
}
670670

@@ -689,7 +689,7 @@ public static void SinhTensor(int tensorLength)
689689

690690
for (int i = 0; i < tensorLength; i++)
691691
{
692-
Assert.Equal(MathF.Sinh(x[i]), destination[i]);
692+
Assert.Equal(MathF.Sinh(x[i]), destination[i], Tolerance);
693693
}
694694
}
695695

@@ -714,7 +714,7 @@ public static void TanhTensor(int tensorLength)
714714

715715
for (int i = 0; i < tensorLength; i++)
716716
{
717-
Assert.Equal(MathF.Tanh(x[i]), destination[i]);
717+
Assert.Equal(MathF.Tanh(x[i]), destination[i], Tolerance);
718718
}
719719
}
720720

0 commit comments

Comments
 (0)