Skip to content

Commit ec9762c

Browse files
tannergooding authored and michaelgsharp committed
Adding a vectorized implementation of TensorPrimitives.Log (dotnet#92960)
* Adding a vectorized implementation of TensorPrimitives.Log * Make sure to hit Ctrl+S
1 parent 2091662 commit ec9762c

File tree

4 files changed

+335
-14
lines changed

4 files changed

+335
-14
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -563,20 +563,8 @@ public static unsafe int IndexOfMinMagnitude(ReadOnlySpan<float> x)
563563
/// operating systems or architectures.
564564
/// </para>
565565
/// </remarks>
566-
public static void Log(ReadOnlySpan<float> x, Span<float> destination)
567-
{
568-
if (x.Length > destination.Length)
569-
{
570-
ThrowHelper.ThrowArgument_DestinationTooShort();
571-
}
572-
573-
ValidateInputOutputSpanNonOverlapping(x, destination);
574-
575-
for (int i = 0; i < x.Length; i++)
576-
{
577-
destination[i] = MathF.Log(x[i]);
578-
}
579-
}
566+
public static void Log(ReadOnlySpan<float> x, Span<float> destination) =>
567+
InvokeSpanIntoSpan<LogOperator>(x, destination); // LogOperator supplies the per-element log; NOTE(review): the removed implementation checked destination length and span overlap inline — confirm InvokeSpanIntoSpan still performs both
580568

581569
/// <summary>Computes the element-wise base 2 logarithm of single-precision floating-point numbers in the specified tensor.</summary>
582570
/// <param name="x">The tensor, represented as a span.</param>

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2579,6 +2579,291 @@ public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y)
25792579
#endif
25802580
}
25812581

2582+
private readonly struct LogOperator : IUnaryOperator
2583+
{
2584+
// This code is based on `vrs4_logf` from amd/aocl-libm-ose
2585+
// Copyright (C) 2018-2019 Advanced Micro Devices, Inc. All rights reserved.
2586+
//
2587+
// Licensed under the BSD 3-Clause "New" or "Revised" License
2588+
// See THIRD-PARTY-NOTICES.TXT for the full license text
2589+
2590+
// Spec:
2591+
// logf(x)
2592+
// = logf(x) if x ∈ F and x > 0
2593+
// = x if x = qNaN
2594+
// = 0 if x = 1
2595+
// = -inf if x ∈ {-0, 0}
2596+
// = NaN otherwise
2597+
//
2598+
// Assumptions/Expectations
2599+
// - ULP is derived to be << 4 (always)
2600+
// - Some FPU Exceptions may not be available
2601+
// - Performance is at least 3x
2602+
//
2603+
// Implementation Notes:
2604+
// 1. Range Reduction:
2605+
// x = 2^n*(1+f) .... (1)
2606+
// where n is exponent and is an integer
2607+
// (1+f) is mantissa ∈ [1,2). i.e., 1 ≤ 1+f < 2 .... (2)
2608+
//
2609+
// From (1), taking log on both sides
2610+
// log(x) = log(2^n * (1+f))
2611+
// = log(2^n) + log(1+f)
2612+
// = n*log(2) + log(1+f) .... (3)
2613+
//
2614+
// let z = 1 + f
2615+
// log(z) = log(k) + log(z) - log(k)
2616+
// log(z) = log(kz) - log(k)
2617+
//
2618+
// From (2), range of z is [1, 2)
2619+
// by simply dividing range by 'k', z is in [1/k, 2/k) .... (4)
2620+
// Best choice of k is the one which gives equal and opposite values
2621+
// at extrema          +-       -+
2622+
//      1              |  2      |
2623+
//     --- - 1   =   - | --- - 1 |
2624+
//      k              |  k      |      .... (5)
2625+
//                     +-       -+
2626+
//
2627+
// Solving for k, k = 3/2,
2628+
// From (4), using 'k' value, range is therefore [-0.3333, 0.3333]
2629+
//
2630+
// 2. Polynomial Approximation:
2631+
// More information refer to tools/sollya/vrs4_logf.sollya
2632+
//
2633+
// 7th Deg - Error abs: 0x1.04c4ac98p-22 rel: 0x1.2216e6f8p-19
2634+
// 6th Deg - Error abs: 0x1.179e97d8p-19 rel: 0x1.db676c1p-17
2635+
2636+
private const uint V_MIN = 0x00800000;
2637+
private const uint V_MAX = 0x7F800000;
2638+
private const uint V_MASK = 0x007FFFFF;
2639+
private const uint V_OFF = 0x3F2AAAAB;
2640+
2641+
private const float V_LN2 = 0.6931472f;
2642+
2643+
private const float C0 = 0.0f;
2644+
private const float C1 = 1.0f;
2645+
private const float C2 = -0.5000001f;
2646+
private const float C3 = 0.33332965f;
2647+
private const float C4 = -0.24999046f;
2648+
private const float C5 = 0.20018855f;
2649+
private const float C6 = -0.16700386f;
2650+
private const float C7 = 0.13902695f;
2651+
private const float C8 = -0.1197452f;
2652+
private const float C9 = 0.14401625f;
2653+
private const float C10 = -0.13657966f;
2654+
2655+
public static float Invoke(float x) => MathF.Log(x); // scalar overload: defers to MathF.Log rather than using the polynomial below
2656+
2657+
public static Vector128<float> Invoke(Vector128<float> x) // element-wise natural logarithm of x, computed lane by lane
2658+
{
2659+
Vector128<float> specialResult = x; // lanes that are NaN or +infinity are returned as x itself
2660+
2661+
// x is zero, subnormal, negative, infinity, or NaN: the unsigned compare flags every bit pattern outside the positive normal range [V_MIN, V_MAX)
2662+
Vector128<uint> specialMask = Vector128.GreaterThanOrEqual(x.AsUInt32() - Vector128.Create(V_MIN), Vector128.Create(V_MAX - V_MIN));
2663+
2664+
if (specialMask != Vector128<uint>.Zero)
2665+
{
2666+
// (x == 0) ? float.NegativeInfinity : x   (the floating-point compare matches both +0 and -0)
2667+
Vector128<float> zeroMask = Vector128.Equals(x, Vector128<float>.Zero);
2668+
2669+
specialResult = Vector128.ConditionalSelect(
2670+
zeroMask,
2671+
Vector128.Create(float.NegativeInfinity),
2672+
specialResult
2673+
);
2674+
2675+
// (x < 0) ? float.NaN : x   (includes -infinity and negative subnormals)
2676+
Vector128<float> lessThanZeroMask = Vector128.LessThan(x, Vector128<float>.Zero);
2677+
2678+
specialResult = Vector128.ConditionalSelect(
2679+
lessThanZeroMask,
2680+
Vector128.Create(float.NaN),
2681+
specialResult
2682+
);
2683+
2684+
// (x == 0) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x): lanes whose final answer is specialResult
2685+
Vector128<float> temp = zeroMask
2686+
| lessThanZeroMask
2687+
| ~Vector128.Equals(x, x)
2688+
| Vector128.Equals(x, Vector128.Create(float.PositiveInfinity));
2689+
2690+
// subnormal: lanes flagged special that match none of the cases collected in temp
2691+
Vector128<float> subnormalMask = Vector128.AndNot(specialMask.AsSingle(), temp);
2692+
2693+
x = Vector128.ConditionalSelect(
2694+
subnormalMask,
2695+
((x * 8388608.0f).AsUInt32() - Vector128.Create(23u << 23)).AsSingle(), // scale by 2^23 to normalize, then subtract 23 from the exponent bits to compensate; the result is only consumed by the integer exponent/mantissa split below
2696+
x
2697+
);
2698+
2699+
specialMask = temp.AsUInt32(); // subnormal lanes now go through the regular computation
2700+
}
2701+
2702+
Vector128<uint> vx = x.AsUInt32() - Vector128.Create(V_OFF); // rebase the bit pattern on V_OFF (the bits of ~2/3)
2703+
Vector128<float> n = Vector128.ConvertToSingle(Vector128.ShiftRightArithmetic(vx.AsInt32(), 23)); // signed exponent n; the arithmetic shift keeps the sign
2704+
2705+
vx = (vx & Vector128.Create(V_MASK)) + Vector128.Create(V_OFF); // keep the low 23 bits and add V_OFF back: reduced mantissa with exponent n removed
2706+
2707+
Vector128<float> r = vx.AsSingle() - Vector128.Create(1.0f); // r = reduced mantissa - 1, the polynomial argument
2708+
2709+
Vector128<float> r2 = r * r;
2710+
Vector128<float> r4 = r2 * r2;
2711+
Vector128<float> r8 = r4 * r4;
2712+
2713+
Vector128<float> q = (Vector128.Create(C10) * r2 + (Vector128.Create(C9) * r + Vector128.Create(C8))) // q = C0 + C1*r + ... + C10*r^10, built from coefficient pairs combined with r2, r4 and r8
2714+
* r8 + (((Vector128.Create(C7) * r + Vector128.Create(C6))
2715+
* r2 + (Vector128.Create(C5) * r + Vector128.Create(C4)))
2716+
* r4 + ((Vector128.Create(C3) * r + Vector128.Create(C2))
2717+
* r2 + (Vector128.Create(C1) * r + Vector128.Create(C0))));
2718+
2719+
return Vector128.ConditionalSelect(
2720+
specialMask.AsSingle(),
2721+
specialResult,
2722+
n * Vector128.Create(V_LN2) + q // log(x) = n * ln(2) + q for the non-special lanes
2723+
);
2724+
}
2725+
2726+
public static Vector256<float> Invoke(Vector256<float> x) // element-wise natural logarithm of x, computed lane by lane
2727+
{
2728+
Vector256<float> specialResult = x; // lanes that are NaN or +infinity are returned as x itself
2729+
2730+
// x is zero, subnormal, negative, infinity, or NaN: the unsigned compare flags every bit pattern outside the positive normal range [V_MIN, V_MAX)
2731+
Vector256<uint> specialMask = Vector256.GreaterThanOrEqual(x.AsUInt32() - Vector256.Create(V_MIN), Vector256.Create(V_MAX - V_MIN));
2732+
2733+
if (specialMask != Vector256<uint>.Zero)
2734+
{
2735+
// (x == 0) ? float.NegativeInfinity : x   (the floating-point compare matches both +0 and -0)
2736+
Vector256<float> zeroMask = Vector256.Equals(x, Vector256<float>.Zero);
2737+
2738+
specialResult = Vector256.ConditionalSelect(
2739+
zeroMask,
2740+
Vector256.Create(float.NegativeInfinity),
2741+
specialResult
2742+
);
2743+
2744+
// (x < 0) ? float.NaN : x   (includes -infinity and negative subnormals)
2745+
Vector256<float> lessThanZeroMask = Vector256.LessThan(x, Vector256<float>.Zero);
2746+
2747+
specialResult = Vector256.ConditionalSelect(
2748+
lessThanZeroMask,
2749+
Vector256.Create(float.NaN),
2750+
specialResult
2751+
);
2752+
2753+
// (x == 0) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x): lanes whose final answer is specialResult
2754+
Vector256<float> temp = zeroMask
2755+
| lessThanZeroMask
2756+
| ~Vector256.Equals(x, x)
2757+
| Vector256.Equals(x, Vector256.Create(float.PositiveInfinity));
2758+
2759+
// subnormal: lanes flagged special that match none of the cases collected in temp
2760+
Vector256<float> subnormalMask = Vector256.AndNot(specialMask.AsSingle(), temp);
2761+
2762+
x = Vector256.ConditionalSelect(
2763+
subnormalMask,
2764+
((x * 8388608.0f).AsUInt32() - Vector256.Create(23u << 23)).AsSingle(), // scale by 2^23 to normalize, then subtract 23 from the exponent bits to compensate; the result is only consumed by the integer exponent/mantissa split below
2765+
x
2766+
);
2767+
2768+
specialMask = temp.AsUInt32(); // subnormal lanes now go through the regular computation
2769+
}
2770+
2771+
Vector256<uint> vx = x.AsUInt32() - Vector256.Create(V_OFF); // rebase the bit pattern on V_OFF (the bits of ~2/3)
2772+
Vector256<float> n = Vector256.ConvertToSingle(Vector256.ShiftRightArithmetic(vx.AsInt32(), 23)); // signed exponent n; the arithmetic shift keeps the sign
2773+
2774+
vx = (vx & Vector256.Create(V_MASK)) + Vector256.Create(V_OFF); // keep the low 23 bits and add V_OFF back: reduced mantissa with exponent n removed
2775+
2776+
Vector256<float> r = vx.AsSingle() - Vector256.Create(1.0f); // r = reduced mantissa - 1, the polynomial argument
2777+
2778+
Vector256<float> r2 = r * r;
2779+
Vector256<float> r4 = r2 * r2;
2780+
Vector256<float> r8 = r4 * r4;
2781+
2782+
Vector256<float> q = (Vector256.Create(C10) * r2 + (Vector256.Create(C9) * r + Vector256.Create(C8))) // q = C0 + C1*r + ... + C10*r^10, built from coefficient pairs combined with r2, r4 and r8
2783+
* r8 + (((Vector256.Create(C7) * r + Vector256.Create(C6))
2784+
* r2 + (Vector256.Create(C5) * r + Vector256.Create(C4)))
2785+
* r4 + ((Vector256.Create(C3) * r + Vector256.Create(C2))
2786+
* r2 + (Vector256.Create(C1) * r + Vector256.Create(C0))));
2787+
2788+
return Vector256.ConditionalSelect(
2789+
specialMask.AsSingle(),
2790+
specialResult,
2791+
n * Vector256.Create(V_LN2) + q // log(x) = n * ln(2) + q for the non-special lanes
2792+
);
2793+
}
2794+
2795+
#if NET8_0_OR_GREATER
2796+
public static Vector512<float> Invoke(Vector512<float> x) // element-wise natural logarithm of x, computed lane by lane
2797+
{
2798+
Vector512<float> specialResult = x; // lanes that are NaN or +infinity are returned as x itself
2799+
2800+
// x is zero, subnormal, negative, infinity, or NaN: the unsigned compare flags every bit pattern outside the positive normal range [V_MIN, V_MAX)
2801+
Vector512<uint> specialMask = Vector512.GreaterThanOrEqual(x.AsUInt32() - Vector512.Create(V_MIN), Vector512.Create(V_MAX - V_MIN));
2802+
2803+
if (specialMask != Vector512<uint>.Zero)
2804+
{
2805+
// (x == 0) ? float.NegativeInfinity : x   (the floating-point compare matches both +0 and -0)
2806+
Vector512<float> zeroMask = Vector512.Equals(x, Vector512<float>.Zero);
2807+
2808+
specialResult = Vector512.ConditionalSelect(
2809+
zeroMask,
2810+
Vector512.Create(float.NegativeInfinity),
2811+
specialResult
2812+
);
2813+
2814+
// (x < 0) ? float.NaN : x   (includes -infinity and negative subnormals)
2815+
Vector512<float> lessThanZeroMask = Vector512.LessThan(x, Vector512<float>.Zero);
2816+
2817+
specialResult = Vector512.ConditionalSelect(
2818+
lessThanZeroMask,
2819+
Vector512.Create(float.NaN),
2820+
specialResult
2821+
);
2822+
2823+
// (x == 0) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x): lanes whose final answer is specialResult
2824+
Vector512<float> temp = zeroMask
2825+
| lessThanZeroMask
2826+
| ~Vector512.Equals(x, x)
2827+
| Vector512.Equals(x, Vector512.Create(float.PositiveInfinity));
2828+
2829+
// subnormal: lanes flagged special that match none of the cases collected in temp
2830+
Vector512<float> subnormalMask = Vector512.AndNot(specialMask.AsSingle(), temp);
2831+
2832+
x = Vector512.ConditionalSelect(
2833+
subnormalMask,
2834+
((x * 8388608.0f).AsUInt32() - Vector512.Create(23u << 23)).AsSingle(), // scale by 2^23 to normalize, then subtract 23 from the exponent bits to compensate; the result is only consumed by the integer exponent/mantissa split below
2835+
x
2836+
);
2837+
2838+
specialMask = temp.AsUInt32(); // subnormal lanes now go through the regular computation
2839+
}
2840+
2841+
Vector512<uint> vx = x.AsUInt32() - Vector512.Create(V_OFF); // rebase the bit pattern on V_OFF (the bits of ~2/3)
2842+
Vector512<float> n = Vector512.ConvertToSingle(Vector512.ShiftRightArithmetic(vx.AsInt32(), 23)); // signed exponent n; the arithmetic shift keeps the sign
2843+
2844+
vx = (vx & Vector512.Create(V_MASK)) + Vector512.Create(V_OFF); // keep the low 23 bits and add V_OFF back: reduced mantissa with exponent n removed
2845+
2846+
Vector512<float> r = vx.AsSingle() - Vector512.Create(1.0f); // r = reduced mantissa - 1, the polynomial argument
2847+
2848+
Vector512<float> r2 = r * r;
2849+
Vector512<float> r4 = r2 * r2;
2850+
Vector512<float> r8 = r4 * r4;
2851+
2852+
Vector512<float> q = (Vector512.Create(C10) * r2 + (Vector512.Create(C9) * r + Vector512.Create(C8))) // q = C0 + C1*r + ... + C10*r^10, built from coefficient pairs combined with r2, r4 and r8
2853+
* r8 + (((Vector512.Create(C7) * r + Vector512.Create(C6))
2854+
* r2 + (Vector512.Create(C5) * r + Vector512.Create(C4)))
2855+
* r4 + ((Vector512.Create(C3) * r + Vector512.Create(C2))
2856+
* r2 + (Vector512.Create(C1) * r + Vector512.Create(C0))));
2857+
2858+
return Vector512.ConditionalSelect(
2859+
specialMask.AsSingle(),
2860+
specialResult,
2861+
n * Vector512.Create(V_LN2) + q // log(x) = n * ln(2) + q for the non-special lanes
2862+
);
2863+
}
2864+
#endif
2865+
}
2866+
25822867
private readonly struct Log2Operator : IUnaryOperator
25832868
{
25842869
// This code is based on `vrs4_log2f` from amd/aocl-libm-ose

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,19 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
923923
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
924924
}
925925

926+
private readonly struct LogOperator : IUnaryOperator // natural-log operator for the netstandard build; scalar only
927+
{
928+
public bool CanVectorize => false; // presumably makes callers skip the Vector<float> overload below — confirm against InvokeSpanIntoSpan
929+
930+
public float Invoke(float x) => MathF.Log(x); // per-element natural logarithm
931+
932+
public Vector<float> Invoke(Vector<float> x)
933+
{
934+
// Vectorizing requires shift right support, which is .NET 7 or later, so this overload is intentionally unimplemented and always throws
935+
throw new NotImplementedException();
936+
}
937+
}
938+
926939
private readonly struct Log2Operator : IUnaryOperator
927940
{
928941
public bool CanVectorize => false;

src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,41 @@ public static void Log_InPlace(int tensorLength)
10061006
}
10071007
}
10081008

1009+
[Theory]
1010+
[MemberData(nameof(TensorLengths))]
1011+
public static void Log_SpecialValues(int tensorLength) // checks TensorPrimitives.Log against MathF.Log on tensors seeded with special values
1012+
{
1013+
using BoundedMemory<float> x = CreateAndFillTensor(tensorLength);
1014+
using BoundedMemory<float> destination = CreateTensor(tensorLength);
1015+
1016+
// NaN
1017+
// NOTE(review): every special value below is written at an independently drawn index, so a later write can overwrite an earlier one
1018+
// and a single run is not guaranteed to exercise all seven values (assuming s_random is a System.Random — confirm)
1019+
x[s_random.Next(x.Length)] = float.NaN;
1018+
1019+
// +Infinity
1020+
x[s_random.Next(x.Length)] = float.PositiveInfinity;
1021+
1022+
// -Infinity
1023+
x[s_random.Next(x.Length)] = float.NegativeInfinity;
1024+
1025+
// +Zero
1026+
x[s_random.Next(x.Length)] = +0.0f;
1027+
1028+
// -Zero
1029+
x[s_random.Next(x.Length)] = -0.0f;
1030+
1031+
// +Epsilon (smallest positive subnormal)
1032+
x[s_random.Next(x.Length)] = +float.Epsilon;
1033+
1034+
// -Epsilon (negative subnormal)
1035+
x[s_random.Next(x.Length)] = -float.Epsilon;
1036+
1037+
TensorPrimitives.Log(x, destination);
1038+
for (int i = 0; i < tensorLength; i++)
1039+
{
1040+
Assert.Equal(MathF.Log(x[i]), destination[i], Tolerance); // scalar MathF.Log is the reference for every element
1041+
}
1042+
}
1043+
10091044
[Theory]
10101045
[MemberData(nameof(TensorLengths))]
10111046
public static void Log_ThrowsForTooShortDestination(int tensorLength)

0 commit comments

Comments
 (0)