Skip to content

Commit 3bf40a3

Browse files
authored
Update TensorPrimitives aggregations to vectorize handling of remaining elements (#92672)
* Update TensorPrimitives.CosineSimilarity to vectorize handling of remaining elements * Vectorize remainder handling for Aggregate helpers
1 parent dc1f86a commit 3bf40a3

File tree

3 files changed

+443
-205
lines changed

3 files changed

+443
-205
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ public static float Distance(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
126126
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
127127
}
128128

129-
return MathF.Sqrt(Aggregate<SubtractSquaredOperator, AddOperator>(0f, x, y));
129+
return MathF.Sqrt(Aggregate<SubtractSquaredOperator, AddOperator>(x, y));
130130
}
131131

132132
/// <summary>Computes the element-wise result of: <c><paramref name="x" /> / <paramref name="y" /></c>.</summary>
@@ -162,7 +162,7 @@ public static float Dot(ReadOnlySpan<float> x, ReadOnlySpan<float> y) // BLAS1:
162162
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
163163
}
164164

165-
return Aggregate<MultiplyOperator, AddOperator>(0f, x, y);
165+
return Aggregate<MultiplyOperator, AddOperator>(x, y);
166166
}
167167

168168
/// <summary>Computes the element-wise result of: <c>pow(e, <paramref name="x" />)</c>.</summary>
@@ -545,7 +545,7 @@ public static void Negate(ReadOnlySpan<float> x, Span<float> destination) =>
545545
/// <param name="x">The first tensor, represented as a span.</param>
546546
/// <returns>The L2 norm.</returns>
547547
public static float Norm(ReadOnlySpan<float> x) => // BLAS1: nrm2
548-
MathF.Sqrt(Aggregate<SquaredOperator, AddOperator>(0f, x));
548+
MathF.Sqrt(Aggregate<SquaredOperator, AddOperator>(x));
549549

550550
/// <summary>Computes the product of all elements in <paramref name="x"/>.</summary>
551551
/// <param name="x">The tensor, represented as a span.</param>
@@ -558,7 +558,7 @@ public static float Product(ReadOnlySpan<float> x)
558558
ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
559559
}
560560

561-
return Aggregate<IdentityOperator, MultiplyOperator>(1.0f, x);
561+
return Aggregate<IdentityOperator, MultiplyOperator>(x);
562562
}
563563

564564
/// <summary>Computes the product of the element-wise result of: <c><paramref name="x" /> - <paramref name="y" /></c>.</summary>
@@ -580,7 +580,7 @@ public static float ProductOfDifferences(ReadOnlySpan<float> x, ReadOnlySpan<flo
580580
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
581581
}
582582

583-
return Aggregate<SubtractOperator, MultiplyOperator>(1.0f, x, y);
583+
return Aggregate<SubtractOperator, MultiplyOperator>(x, y);
584584
}
585585

586586
/// <summary>Computes the product of the element-wise result of: <c><paramref name="x" /> + <paramref name="y" /></c>.</summary>
@@ -602,7 +602,7 @@ public static float ProductOfSums(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
602602
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
603603
}
604604

605-
return Aggregate<AddOperator, MultiplyOperator>(1.0f, x, y);
605+
return Aggregate<AddOperator, MultiplyOperator>(x, y);
606606
}
607607

608608
/// <summary>
@@ -703,7 +703,7 @@ public static void Subtract(ReadOnlySpan<float> x, float y, Span<float> destinat
703703
/// <param name="x">The tensor, represented as a span.</param>
704704
/// <returns>The result of adding all elements in <paramref name="x"/>, or zero if <paramref name="x"/> is empty.</returns>
705705
public static float Sum(ReadOnlySpan<float> x) =>
706-
Aggregate<IdentityOperator, AddOperator>(0f, x);
706+
Aggregate<IdentityOperator, AddOperator>(x);
707707

708708
/// <summary>Computes the sum of the absolute values of every element in <paramref name="x"/>.</summary>
709709
/// <param name="x">The tensor, represented as a span.</param>
@@ -713,14 +713,14 @@ public static float Sum(ReadOnlySpan<float> x) =>
713713
/// <para>This method corresponds to the <c>asum</c> method defined by <c>BLAS1</c>.</para>
714714
/// </remarks>
715715
public static float SumOfMagnitudes(ReadOnlySpan<float> x) =>
716-
Aggregate<AbsoluteOperator, AddOperator>(0f, x);
716+
Aggregate<AbsoluteOperator, AddOperator>(x);
717717

718718
/// <summary>Computes the sum of the squares of every element in <paramref name="x"/>.</summary>
719719
/// <param name="x">The tensor, represented as a span.</param>
720720
/// <returns>The result of adding every element in <paramref name="x"/> multiplied by itself, or zero if <paramref name="x"/> is empty.</returns>
721721
/// <remarks>This method effectively does <c><see cref="TensorPrimitives" />.Sum(<see cref="TensorPrimitives" />.Multiply(<paramref name="x" />, <paramref name="x" />))</c>.</remarks>
722722
public static float SumOfSquares(ReadOnlySpan<float> x) =>
723-
Aggregate<SquaredOperator, AddOperator>(0f, x);
723+
Aggregate<SquaredOperator, AddOperator>(x);
724724

725725
/// <summary>Computes the element-wise result of: <c>tanh(<paramref name="x" />)</c>.</summary>
726726
/// <param name="x">The tensor, represented as a span.</param>
@@ -739,5 +739,31 @@ public static void Tanh(ReadOnlySpan<float> x, Span<float> destination)
739739
destination[i] = MathF.Tanh(x[i]);
740740
}
741741
}
742+
743+
/// <summary>Mask used to handle remaining elements after vectorized handling of the input.</summary>
744+
/// <remarks>
745+
/// Logically 16 rows of 16 uints. The Nth row should be used to handle N remaining elements at the
746+
/// end of the input, where elements in the vector prior to that will be zero'd.
747+
/// </remarks>
748+
private static ReadOnlySpan<uint> RemainderUInt32Mask_16x16 => new uint[]
749+
{
750+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
751+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF,
752+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
753+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
754+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
755+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
756+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
757+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
758+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
759+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
760+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
761+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
762+
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
763+
0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
764+
0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
765+
0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
766+
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
767+
};
742768
}
743769
}

0 commit comments

Comments
 (0)