@@ -44,6 +44,46 @@ public static Tensor<T> Resize<T>(Tensor<T> input, ReadOnlySpan<nint> shape)
#endregion

#region Broadcast
/// <summary>
/// Broadcast the data from <paramref name="left"/> to the smallest broadcastable shape compatible with <paramref name="right"/>. Creates a new <see cref="Tensor{T}"/> and allocates new memory.
/// If the shapes are not broadcast compatible, <see langword="false"/> is returned.
/// </summary>
/// <param name="left">Input <see cref="Tensor{T}"/>.</param>
/// <param name="right">Other <see cref="Tensor{T}"/> to make shapes broadcastable.</param>
/// <param name="destination">Destination <see cref="Tensor{T}"/>.</param>
/// <returns><see langword="true"/> if the broadcast succeeded and <paramref name="destination"/> was populated; otherwise, <see langword="false"/>.</returns>
public static bool TryBroadcastTo<T>(Tensor<T> left, Tensor<T> right, out Tensor<T> destination)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>
{
try
{
nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths);

Tensor<T> intermediate = BroadcastTo(left, newSize);
destination = Tensor.Create(intermediate.ToArray(), intermediate.Lengths);
return true;
}
catch
Review comment (Member):
There's no way to implement this without using exceptions for control flow? This largely defeats the intended purpose of the Try pattern, which is to avoid the expense of exceptions for the condition represented by the Try.
{
destination = Tensor<T>.Empty;
return false;
}
}
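
Regarding the review comment above: one way to satisfy the Try pattern without exceptions for control flow would be to validate the shapes up front. The following is only an illustrative sketch written against the helpers touched elsewhere in this PR (TensorHelpers.IsBroadcastableTo and TensorHelpers.GetSmallestBroadcastableSize), not the code as merged:

// Sketch only: the failure path never throws because compatibility is checked first.
public static bool TryBroadcastTo<T>(Tensor<T> left, Tensor<T> right, out Tensor<T> destination)
    where T : IEquatable<T>, IEqualityOperators<T, T, bool>
{
    if (!TensorHelpers.IsBroadcastableTo(left.Lengths, right.Lengths))
    {
        destination = Tensor<T>.Empty;
        return false;
    }

    // Same success path as above; GetSmallestBroadcastableSize cannot throw here
    // because the shapes were already verified to be broadcast compatible.
    nint[] newSize = TensorHelpers.GetSmallestBroadcastableSize(left.Lengths, right.Lengths);
    Tensor<T> intermediate = BroadcastTo(left, newSize);
    destination = Tensor.Create(intermediate.ToArray(), intermediate.Lengths);
    return true;
}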

/// <summary>
/// Broadcast the data from <paramref name="input"/> to the new shape <paramref name="shape"/>. Creates a new <see cref="Tensor{T}"/> and allocates new memory.
/// If the shape of the <paramref name="input"/> is not compatible with the new shape, an exception is thrown.
/// </summary>
/// <param name="input">Input <see cref="Tensor{T}"/>.</param>
/// <param name="shape"><see cref="ReadOnlySpan{T}"/> of the desired new shape.</param>
/// <returns>A new <see cref="Tensor{T}"/> with the data from <paramref name="input"/> broadcast to <paramref name="shape"/>.</returns>
/// <exception cref="ArgumentException">Thrown when the shapes are not broadcast compatible.</exception>
public static Tensor<T> TryBroadcastTo<T>(Tensor<T> input, ReadOnlySpan<nint> shape)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>
{
Tensor<T> intermediate = BroadcastTo(input, shape);
return Tensor.Create(intermediate.ToArray(), intermediate.Lengths);
}
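
For context, a hypothetical usage sketch of the new TryBroadcastTo(left, right, out destination) overload, assuming it is exposed on the static Tensor class alongside the Tensor.Create factory used above; the shapes and values are invented for illustration:

// A [3,1] column broadcast against a [3,4] tensor: each value stretches across the 4 columns.
Tensor<float> column = Tensor.Create(new float[] { 1f, 2f, 3f }, new nint[] { 3, 1 });
Tensor<float> grid = Tensor.Create(new float[12], new nint[] { 3, 4 });

if (Tensor.TryBroadcastTo(column, grid, out Tensor<float> broadcasted))
{
    // broadcasted.Lengths is now [3, 4]; the data was copied into newly allocated memory.
}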

/// <summary>
/// Broadcast the data from <paramref name="left"/> to the smallest broadcastable shape compatible with <paramref name="right"/>. Creates a new <see cref="Tensor{T}"/> and allocates new memory.
/// </summary>
@@ -87,7 +127,7 @@ internal static Tensor<T> BroadcastTo<T>(Tensor<T> input, ReadOnlySpan<nint> sha
if (input.Lengths.SequenceEqual(shape))
return new Tensor<T>(input._values, shape, false);

-if (!TensorHelpers.AreShapesBroadcastCompatible(input.Lengths, shape))
+if (!TensorHelpers.IsBroadcastableTo(input.Lengths, shape))
ThrowHelper.ThrowArgument_ShapesNotBroadcastCompatible();

nint newSize = TensorSpanHelpers.CalculateTotalLength(shape);
@@ -28,31 +28,31 @@ public static nint CountTrueElements(Tensor<bool> filter)
return count;
}

-internal static bool AreShapesBroadcastCompatible<T>(Tensor<T> tensor1, Tensor<T> tensor2)
-where T : IEquatable<T>, IEqualityOperators<T, T, bool> => AreShapesBroadcastCompatible(tensor1.Lengths, tensor2.Lengths);
+internal static bool IsBroadcastableTo<T>(Tensor<T> tensor1, Tensor<T> tensor2)
+where T : IEquatable<T>, IEqualityOperators<T, T, bool> => IsBroadcastableTo(tensor1.Lengths, tensor2.Lengths);

-internal static bool AreShapesBroadcastCompatible(ReadOnlySpan<nint> shape1, ReadOnlySpan<nint> shape2)
+internal static bool IsBroadcastableTo(ReadOnlySpan<nint> lengths1, ReadOnlySpan<nint> lengths2)
{
-int shape1Index = shape1.Length - 1;
-int shape2Index = shape2.Length - 1;
+int lengths1Index = lengths1.Length - 1;
+int lengths2Index = lengths2.Length - 1;

bool areCompatible = true;

nint s1;
nint s2;

-while (shape1Index >= 0 || shape2Index >= 0)
+while (lengths1Index >= 0 || lengths2Index >= 0)
{
// if a dimension is missing in one of the shapes, it is considered to be 1
-if (shape1Index < 0)
+if (lengths1Index < 0)
s1 = 1;
else
-s1 = shape1[shape1Index--];
+s1 = lengths1[lengths1Index--];

-if (shape2Index < 0)
+if (lengths2Index < 0)
s2 = 1;
else
-s2 = shape2[shape2Index--];
+s2 = lengths2[lengths2Index--];

if (s1 == s2 || (s1 == 1 && s2 != 1) || (s2 == 1 && s1 != 1)) { }
else
@@ -67,7 +67,7 @@ internal static bool AreShapesBroadcastCompatible(ReadOnlySpan<nint> shape1, Rea

internal static nint[] GetSmallestBroadcastableSize(ReadOnlySpan<nint> shape1, ReadOnlySpan<nint> shape2)
{
-if (!AreShapesBroadcastCompatible(shape1, shape2))
+if (!IsBroadcastableTo(shape1, shape2))
throw new Exception("Shapes are not broadcast compatible");

nint[] intermediateShape = GetIntermediateShape(shape1, shape2.Length);
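
To make the right-to-left comparison rule in IsBroadcastableTo concrete, here are a few illustrative shape pairs and the expected results (the helper is internal, so these are expectations rather than public API calls):

// Dimensions are compared from the trailing end; a missing or 1-sized dimension is compatible with anything.
IsBroadcastableTo(new nint[] { 3, 1 }, new nint[] { 3, 4 });    // true:  1 broadcasts against 4
IsBroadcastableTo(new nint[] { 4 }, new nint[] { 2, 3, 4 });    // true:  [4] is treated as [1, 1, 4]
IsBroadcastableTo(new nint[] { 2, 3 }, new nint[] { 3, 2 });    // false: 3 vs 2 mismatch, and neither dimension is 1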