Skip to content

Commit 781e002

Browse files
This vectorizes TensorPrimitives.Log2 (#92897)
* Add a way to support operations that can't be vectorized on netstandard * Updating TensorPrimitives.Log2 to be vectorized on .NET Core * Update src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs Co-authored-by: Stephen Toub <[email protected]> * Ensure we do an arithmetic right shift in the Log2 vectorization * Ensure the code can compile on .NET 7 * Ensure that edge cases are properly handled and don't resolve to `x` * Ensure that Log2 special results are explicitly handled. --------- Co-authored-by: Stephen Toub <[email protected]>
1 parent f41715c commit 781e002

File tree

3 files changed

+302
-16
lines changed

3 files changed

+302
-16
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -598,20 +598,8 @@ public static void Log(ReadOnlySpan<float> x, Span<float> destination)
598598
/// operating systems or architectures.
599599
/// </para>
600600
/// </remarks>
601-
public static void Log2(ReadOnlySpan<float> x, Span<float> destination)
602-
{
603-
if (x.Length > destination.Length)
604-
{
605-
ThrowHelper.ThrowArgument_DestinationTooShort();
606-
}
607-
608-
ValidateInputOutputSpanNonOverlapping(x, destination);
609-
610-
for (int i = 0; i < x.Length; i++)
611-
{
612-
destination[i] = Log2(x[i]);
613-
}
614-
}
601+
public static void Log2(ReadOnlySpan<float> x, Span<float> destination) =>
602+
InvokeSpanIntoSpan<Log2Operator>(x, destination);
615603

616604
/// <summary>Searches for the largest single-precision floating-point number in the specified tensor.</summary>
617605
/// <param name="x">The tensor, represented as a span.</param>

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2579,6 +2579,286 @@ public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y)
25792579
#endif
25802580
}
25812581

2582+
private readonly struct Log2Operator : IUnaryOperator
2583+
{
2584+
// This code is based on `vrs4_log2f` from amd/aocl-libm-ose
2585+
// Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved.
2586+
//
2587+
// Licensed under the BSD 3-Clause "New" or "Revised" License
2588+
// See THIRD-PARTY-NOTICES.TXT for the full license text
2589+
2590+
// Spec:
2591+
// log2f(x)
2592+
// = log2f(x) if x ∈ F and x > 0
2593+
// = x if x = qNaN
2594+
// = 0 if x = 1
2595+
// = -inf if x ∈ {-0, 0}
2596+
// = NaN otherwise
2597+
//
2598+
// Assumptions/Expectations
2599+
// - Maximum ULP is observed to be at 4
2600+
// - Some FPU Exceptions may not be available
2601+
// - Performance is at least 3x
2602+
//
2603+
// Implementation Notes:
2604+
// 1. Range Reduction:
2605+
// x = 2^n*(1+f) .... (1)
2606+
// where n is exponent and is an integer
2607+
// (1+f) is mantissa ∈ [1,2). i.e., 1 ≤ 1+f < 2 .... (2)
2608+
//
2609+
// From (1), taking log on both sides
2610+
// log2(x) = log2(2^n * (1+f))
2611+
// = n + log2(1+f) .... (3)
2612+
//
2613+
// let z = 1 + f
2614+
// log2(z) = log2(k) + log2(z) - log2(k)
2615+
// log2(z) = log2(kz) - log2(k)
2616+
//
2617+
// From (2), range of z is [1, 2)
2618+
// by simply dividing range by 'k', z is in [1/k, 2/k) .... (4)
2619+
// Best choice of k is the one which gives equal and opposite values
2620+
// at extrema, i.e. (1/k) - 1 = -((2/k) - 1):
2621+
//                   +-       -+
2622+
//      1            |  2      |
2623+
//     --- - 1 = -   | --- - 1 |   .... (5)
2624+
//      k            |  k      |
//                   +-       -+
2625+
//
2626+
// Solving for k, k = 3/2,
2627+
// From (4), using 'k' value, range is therefore [-0.3333, 0.3333]
2628+
//
2629+
// 2. Polynomial Approximation:
2630+
// For more information, refer to tools/sollya/vrs4_logf.sollya
2631+
//
2632+
// 7th Deg - Error abs: 0x1.04c4ac98p-22 rel: 0x1.2216e6f8p-19
2633+
2634+
private const uint V_MIN = 0x00800000;
2635+
private const uint V_MAX = 0x7F800000;
2636+
private const uint V_MASK = 0x007FFFFF;
2637+
private const uint V_OFF = 0x3F2AAAAB;
2638+
2639+
private const float C0 = 0.0f;
2640+
private const float C1 = 1.4426951f;
2641+
private const float C2 = -0.72134554f;
2642+
private const float C3 = 0.48089063f;
2643+
private const float C4 = -0.36084408f;
2644+
private const float C5 = 0.2888971f;
2645+
private const float C6 = -0.23594281f;
2646+
private const float C7 = 0.19948183f;
2647+
private const float C8 = -0.22616665f;
2648+
private const float C9 = 0.21228963f;
2649+
2650+
public static float Invoke(float x) => MathF.Log2(x);
2651+
2652+
public static Vector128<float> Invoke(Vector128<float> x)
2653+
{
2654+
Vector128<float> specialResult = x;
2655+
2656+
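// x is zero, subnormal, negative, infinity, or NaN (the unsigned compare wraps for bit patterns below V_MIN and is true for anything at or above V_MAX, including all sign-bit-set values)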
2657+
Vector128<uint> specialMask = Vector128.GreaterThanOrEqual(x.AsUInt32() - Vector128.Create(V_MIN), Vector128.Create(V_MAX - V_MIN));
2658+
2659+
if (specialMask != Vector128<uint>.Zero)
2660+
{
2661+
// float.IsZero(x) ? float.NegativeInfinity : x
2662+
Vector128<float> zeroMask = Vector128.Equals(x, Vector128<float>.Zero);
2663+
2664+
specialResult = Vector128.ConditionalSelect(
2665+
zeroMask,
2666+
Vector128.Create(float.NegativeInfinity),
2667+
specialResult
2668+
);
2669+
2670+
// (x < 0) ? float.NaN : x
2671+
Vector128<float> lessThanZeroMask = Vector128.LessThan(x, Vector128<float>.Zero);
2672+
2673+
specialResult = Vector128.ConditionalSelect(
2674+
lessThanZeroMask,
2675+
Vector128.Create(float.NaN),
2676+
specialResult
2677+
);
2678+
2679+
// float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x)
2680+
Vector128<float> temp = zeroMask
2681+
| lessThanZeroMask
2682+
| ~Vector128.Equals(x, x)
2683+
| Vector128.Equals(x, Vector128.Create(float.PositiveInfinity));
2684+
2685+
// subnormal
2686+
Vector128<float> subnormalMask = Vector128.AndNot(specialMask.AsSingle(), temp);
2687+
2688+
x = Vector128.ConditionalSelect(
2689+
subnormalMask,
2690+
((x * 8388608.0f).AsUInt32() - Vector128.Create(23u << 23)).AsSingle(),
2691+
x
2692+
);
2693+
2694+
specialMask = temp.AsUInt32();
2695+
}
2696+
2697+
Vector128<uint> vx = x.AsUInt32() - Vector128.Create(V_OFF);
2698+
Vector128<float> n = Vector128.ConvertToSingle(Vector128.ShiftRightArithmetic(vx.AsInt32(), 23));
2699+
2700+
vx = (vx & Vector128.Create(V_MASK)) + Vector128.Create(V_OFF);
2701+
2702+
Vector128<float> r = vx.AsSingle() - Vector128.Create(1.0f);
2703+
2704+
Vector128<float> r2 = r * r;
2705+
Vector128<float> r4 = r2 * r2;
2706+
Vector128<float> r8 = r4 * r4;
2707+
2708+
Vector128<float> poly = (Vector128.Create(C9) * r + Vector128.Create(C8)) * r8
2709+
+ (((Vector128.Create(C7) * r + Vector128.Create(C6)) * r2
2710+
+ (Vector128.Create(C5) * r + Vector128.Create(C4))) * r4
2711+
+ ((Vector128.Create(C3) * r + Vector128.Create(C2)) * r2
2712+
+ (Vector128.Create(C1) * r + Vector128.Create(C0))));
2713+
2714+
return Vector128.ConditionalSelect(
2715+
specialMask.AsSingle(),
2716+
specialResult,
2717+
n + poly
2718+
);
2719+
}
2720+
2721+
public static Vector256<float> Invoke(Vector256<float> x)
2722+
{
2723+
Vector256<float> specialResult = x;
2724+
2725+
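// x is zero, subnormal, negative, infinity, or NaN (the unsigned compare wraps for bit patterns below V_MIN and is true for anything at or above V_MAX, including all sign-bit-set values)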
2726+
Vector256<uint> specialMask = Vector256.GreaterThanOrEqual(x.AsUInt32() - Vector256.Create(V_MIN), Vector256.Create(V_MAX - V_MIN));
2727+
2728+
if (specialMask != Vector256<uint>.Zero)
2729+
{
2730+
// float.IsZero(x) ? float.NegativeInfinity : x
2731+
Vector256<float> zeroMask = Vector256.Equals(x, Vector256<float>.Zero);
2732+
2733+
specialResult = Vector256.ConditionalSelect(
2734+
zeroMask,
2735+
Vector256.Create(float.NegativeInfinity),
2736+
specialResult
2737+
);
2738+
2739+
// (x < 0) ? float.NaN : x
2740+
Vector256<float> lessThanZeroMask = Vector256.LessThan(x, Vector256<float>.Zero);
2741+
2742+
specialResult = Vector256.ConditionalSelect(
2743+
lessThanZeroMask,
2744+
Vector256.Create(float.NaN),
2745+
specialResult
2746+
);
2747+
2748+
// float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x)
2749+
Vector256<float> temp = zeroMask
2750+
| lessThanZeroMask
2751+
| ~Vector256.Equals(x, x)
2752+
| Vector256.Equals(x, Vector256.Create(float.PositiveInfinity));
2753+
2754+
// subnormal
2755+
Vector256<float> subnormalMask = Vector256.AndNot(specialMask.AsSingle(), temp);
2756+
2757+
x = Vector256.ConditionalSelect(
2758+
subnormalMask,
2759+
((x * 8388608.0f).AsUInt32() - Vector256.Create(23u << 23)).AsSingle(),
2760+
x
2761+
);
2762+
2763+
specialMask = temp.AsUInt32();
2764+
}
2765+
2766+
Vector256<uint> vx = x.AsUInt32() - Vector256.Create(V_OFF);
2767+
Vector256<float> n = Vector256.ConvertToSingle(Vector256.ShiftRightArithmetic(vx.AsInt32(), 23));
2768+
2769+
vx = (vx & Vector256.Create(V_MASK)) + Vector256.Create(V_OFF);
2770+
2771+
Vector256<float> r = vx.AsSingle() - Vector256.Create(1.0f);
2772+
2773+
Vector256<float> r2 = r * r;
2774+
Vector256<float> r4 = r2 * r2;
2775+
Vector256<float> r8 = r4 * r4;
2776+
2777+
Vector256<float> poly = (Vector256.Create(C9) * r + Vector256.Create(C8)) * r8
2778+
+ (((Vector256.Create(C7) * r + Vector256.Create(C6)) * r2
2779+
+ (Vector256.Create(C5) * r + Vector256.Create(C4))) * r4
2780+
+ ((Vector256.Create(C3) * r + Vector256.Create(C2)) * r2
2781+
+ (Vector256.Create(C1) * r + Vector256.Create(C0))));
2782+
2783+
return Vector256.ConditionalSelect(
2784+
specialMask.AsSingle(),
2785+
specialResult,
2786+
n + poly
2787+
);
2788+
}
2789+
2790+
#if NET8_0_OR_GREATER
2791+
public static Vector512<float> Invoke(Vector512<float> x)
2792+
{
2793+
Vector512<float> specialResult = x;
2794+
2795+
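// x is zero, subnormal, negative, infinity, or NaN (the unsigned compare wraps for bit patterns below V_MIN and is true for anything at or above V_MAX, including all sign-bit-set values)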
2796+
Vector512<uint> specialMask = Vector512.GreaterThanOrEqual(x.AsUInt32() - Vector512.Create(V_MIN), Vector512.Create(V_MAX - V_MIN));
2797+
2798+
if (specialMask != Vector512<uint>.Zero)
2799+
{
2800+
// float.IsZero(x) ? float.NegativeInfinity : x
2801+
Vector512<float> zeroMask = Vector512.Equals(x, Vector512<float>.Zero);
2802+
2803+
specialResult = Vector512.ConditionalSelect(
2804+
zeroMask,
2805+
Vector512.Create(float.NegativeInfinity),
2806+
specialResult
2807+
);
2808+
2809+
// (x < 0) ? float.NaN : x
2810+
Vector512<float> lessThanZeroMask = Vector512.LessThan(x, Vector512<float>.Zero);
2811+
2812+
specialResult = Vector512.ConditionalSelect(
2813+
lessThanZeroMask,
2814+
Vector512.Create(float.NaN),
2815+
specialResult
2816+
);
2817+
2818+
// float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x)
2819+
Vector512<float> temp = zeroMask
2820+
| lessThanZeroMask
2821+
| ~Vector512.Equals(x, x)
2822+
| Vector512.Equals(x, Vector512.Create(float.PositiveInfinity));
2823+
2824+
// subnormal
2825+
Vector512<float> subnormalMask = Vector512.AndNot(specialMask.AsSingle(), temp);
2826+
2827+
x = Vector512.ConditionalSelect(
2828+
subnormalMask,
2829+
((x * 8388608.0f).AsUInt32() - Vector512.Create(23u << 23)).AsSingle(),
2830+
x
2831+
);
2832+
2833+
specialMask = temp.AsUInt32();
2834+
}
2835+
2836+
Vector512<uint> vx = x.AsUInt32() - Vector512.Create(V_OFF);
2837+
Vector512<float> n = Vector512.ConvertToSingle(Vector512.ShiftRightArithmetic(vx.AsInt32(), 23));
2838+
2839+
vx = (vx & Vector512.Create(V_MASK)) + Vector512.Create(V_OFF);
2840+
2841+
Vector512<float> r = vx.AsSingle() - Vector512.Create(1.0f);
2842+
2843+
Vector512<float> r2 = r * r;
2844+
Vector512<float> r4 = r2 * r2;
2845+
Vector512<float> r8 = r4 * r4;
2846+
2847+
Vector512<float> poly = (Vector512.Create(C9) * r + Vector512.Create(C8)) * r8
2848+
+ (((Vector512.Create(C7) * r + Vector512.Create(C6)) * r2
2849+
+ (Vector512.Create(C5) * r + Vector512.Create(C4))) * r4
2850+
+ ((Vector512.Create(C3) * r + Vector512.Create(C2)) * r2
2851+
+ (Vector512.Create(C1) * r + Vector512.Create(C0))));
2852+
2853+
return Vector512.ConditionalSelect(
2854+
specialMask.AsSingle(),
2855+
specialResult,
2856+
n + poly
2857+
);
2858+
}
2859+
#endif
2860+
}
2861+
25822862
private interface IUnaryOperator
25832863
{
25842864
static abstract float Invoke(float x);

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ private static float Aggregate<TLoad, TAggregate>(
9797

9898
float result;
9999

100-
if (Vector.IsHardwareAccelerated && x.Length >= Vector<float>.Count)
100+
if (Vector.IsHardwareAccelerated && load.CanVectorize && x.Length >= Vector<float>.Count)
101101
{
102102
ref float xRef = ref MemoryMarshal.GetReference(x);
103103

@@ -304,7 +304,7 @@ private static void InvokeSpanIntoSpan<TUnaryOperator>(
304304
ref float dRef = ref MemoryMarshal.GetReference(destination);
305305
int i = 0, oneVectorFromEnd;
306306

307-
if (Vector.IsHardwareAccelerated)
307+
if (Vector.IsHardwareAccelerated && op.CanVectorize)
308308
{
309309
oneVectorFromEnd = x.Length - Vector<float>.Count;
310310
if (oneVectorFromEnd >= 0)
@@ -885,6 +885,7 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
885885

886886
private readonly struct NegateOperator : IUnaryOperator
887887
{
888+
public bool CanVectorize => true;
888889
public float Invoke(float x) => -x;
889890
public Vector<float> Invoke(Vector<float> x) => -x;
890891
}
@@ -903,24 +904,41 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
903904

904905
private readonly struct IdentityOperator : IUnaryOperator
905906
{
907+
public bool CanVectorize => true;
906908
public float Invoke(float x) => x;
907909
public Vector<float> Invoke(Vector<float> x) => x;
908910
}
909911

910912
private readonly struct SquaredOperator : IUnaryOperator
911913
{
914+
public bool CanVectorize => true;
912915
public float Invoke(float x) => x * x;
913916
public Vector<float> Invoke(Vector<float> x) => x * x;
914917
}
915918

916919
private readonly struct AbsoluteOperator : IUnaryOperator
917920
{
921+
public bool CanVectorize => true;
918922
public float Invoke(float x) => MathF.Abs(x);
919923
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
920924
}
921925

926+
private readonly struct Log2Operator : IUnaryOperator
927+
{
928+
public bool CanVectorize => false;
929+
930+
public float Invoke(float x) => Log2(x);
931+
932+
public Vector<float> Invoke(Vector<float> x)
933+
{
934+
// Vectorizing requires shift right support, which is .NET 7 or later
935+
throw new NotImplementedException();
936+
}
937+
}
938+
922939
private interface IUnaryOperator
923940
{
941+
bool CanVectorize { get; }
924942
float Invoke(float x);
925943
Vector<float> Invoke(Vector<float> x);
926944
}

0 commit comments

Comments
 (0)