Skip to content

Implement ToImmutable*Async extension methods #1545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 28, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,21 @@ private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(Ge

private static string GenerateOverloads(AsyncMethodGrouping grouping, GenerationOptions options)
{
var usings = grouping.SyntaxTree.GetRoot() is CompilationUnitSyntax compilationUnit
? compilationUnit.Usings.ToString()
: string.Empty;

var overloads = new StringBuilder();
overloads.AppendLine("#nullable enable");
overloads.AppendLine(usings);
overloads.AppendLine("namespace System.Linq");
overloads.AppendLine("{");
overloads.AppendLine(" partial class AsyncEnumerable");
overloads.AppendLine(" {");

foreach (var method in grouping.Methods)
overloads.AppendLine(GenerateOverload(method, options));

overloads.AppendLine(" }");
overloads.AppendLine("}");

return overloads.ToString();
var compilationRoot = grouping.SyntaxTree.GetCompilationUnitRoot();
var namespaceDeclaration = compilationRoot.ChildNodes().OfType<NamespaceDeclarationSyntax>().Single();
var classDeclaration = namespaceDeclaration.ChildNodes().OfType<ClassDeclarationSyntax>().Single();

return CompilationUnit()
.WithUsings(List(compilationRoot.Usings.Select(@using => @using.WithoutTrivia())))
.AddMembers(NamespaceDeclaration(namespaceDeclaration.Name)
.AddMembers(ClassDeclaration(classDeclaration.Identifier)
.AddModifiers(Token(SyntaxKind.PartialKeyword))
.WithMembers(List(grouping.Methods.Select(method => GenerateOverload(method, options))))))
.NormalizeWhitespace()
.ToFullString();
}

private static string GenerateOverload(AsyncMethod method, GenerationOptions options)
private static MemberDeclarationSyntax GenerateOverload(AsyncMethod method, GenerationOptions options)
=> MethodDeclaration(method.Syntax.ReturnType, GetMethodName(method.Symbol, options))
.WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(method.Syntax.TypeParameterList)
Expand All @@ -87,9 +80,7 @@ private static string GenerateOverload(AsyncMethod method, GenerationOptions opt
method.Syntax.ParameterList.Parameters
.Select(p => Argument(IdentifierName(p.Identifier))))))))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
.WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => t.GetStructure() is not DirectiveTriviaSyntax))
.NormalizeWhitespace()
.ToFullString();
.WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => t.GetStructure() is not DirectiveTriviaSyntax));

private static INamedTypeSymbol GetAsyncOverloadAttributeSymbol(GeneratorExecutionContext context)
=> context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute") ?? throw new InvalidOperationException();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Tests
{
public class ToImmutableArray : AsyncEnumerableTests
{
[Fact]
public async Task ToImmutableArray_Null()
{
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableArrayAsyncEnumerableExtensions.ToImmutableArrayAsync<int>(default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableArrayAsyncEnumerableExtensions.ToImmutableArrayAsync<int>(default, CancellationToken.None).AsTask());
}

[Fact]
public async Task ToImmutableArray_IAsyncIListProvider_Simple()
{
var xs = new[] { 42, 25, 39 };
var res = xs.ToAsyncEnumerable().ToImmutableArrayAsync();
Assert.True((await res).SequenceEqual(xs));
}

[Fact]
public async Task ToImmutableArray_IAsyncIListProvider_Empty1()
{
var xs = new int[0];
var res = xs.ToAsyncEnumerable().ToImmutableArrayAsync();
Assert.True((await res).SequenceEqual(xs));
}

[Fact]
public async Task ToImmutableArray_IAsyncIListProvider_Empty2()
{
var xs = new HashSet<int>();
var res = xs.ToAsyncEnumerable().ToImmutableArrayAsync();
Assert.True((await res).SequenceEqual(xs));
}

[Fact]
public async Task ToImmutableArray_Empty()
{
var xs = AsyncEnumerable.Empty<int>();
var res = xs.ToImmutableArrayAsync();
Assert.True((await res).Length == 0);
}

[Fact]
public async Task ToImmutableArray_Throw()
{
var ex = new Exception("Bang!");
var res = Throw<int>(ex).ToImmutableArrayAsync();
await AssertThrowsAsync(res, ex);
}

[Fact]
public async Task ToImmutableArray_Query()
{
var xs = await AsyncEnumerable.Range(5, 50).Take(10).ToImmutableArrayAsync();
var ex = new[] { 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 };

Assert.True(ex.SequenceEqual(xs));
}

[Fact]
public async Task ToImmutableArray_Set()
{
var res = new[] { 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 };
var xs = new HashSet<int>(res);

var arr = await xs.ToAsyncEnumerable().ToImmutableArrayAsync();

Assert.True(res.SequenceEqual(arr));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Tests
{
public class ToImmutableDictionary : AsyncEnumerableTests
{
[Fact]
public async Task ToImmutableDictionary_Null()
{
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int>(default, x => 0).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default(Func<int, int>)).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int>(default, x => 0, EqualityComparer<int>.Default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default, EqualityComparer<int>.Default).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(default, x => 0, x => 0).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, default, x => 0).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, x => 0, default).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(default, x => 0, x => 0, EqualityComparer<int>.Default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default, x => 0, EqualityComparer<int>.Default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, x => 0, default, EqualityComparer<int>.Default).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int>(default, x => 0, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default(Func<int, int>), CancellationToken.None).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int>(default, x => 0, EqualityComparer<int>.Default, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default, EqualityComparer<int>.Default, CancellationToken.None).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(default, x => 0, x => 0, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, default, x => 0, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, x => 0, default, CancellationToken.None).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(default, x => 0, x => 0, EqualityComparer<int>.Default, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default, x => 0, EqualityComparer<int>.Default, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, x => 0, default, EqualityComparer<int>.Default, CancellationToken.None).AsTask());
}

[Fact]
public async Task ToImmutableDictionary1Async()
{
var xs = new[] { 1, 4 }.ToAsyncEnumerable();
var res = await xs.ToImmutableDictionaryAsync(x => x % 2);
Assert.True(res[0] == 4);
Assert.True(res[1] == 1);
}

[Fact]
public async Task ToImmutableDictionary2Async()
{
var xs = new[] { 1, 4, 2 }.ToAsyncEnumerable();
await AssertThrowsAsync<ArgumentException>(xs.ToImmutableDictionaryAsync(x => x % 2).AsTask());
}

[Fact]
public async Task ToImmutableDictionary3Async()
{
var xs = new[] { 1, 4 }.ToAsyncEnumerable();
var res = await xs.ToImmutableDictionaryAsync(x => x % 2, x => x + 1);
Assert.True(res[0] == 5);
Assert.True(res[1] == 2);
}

[Fact]
public async Task ToImmutableDictionary4Async()
{
var xs = new[] { 1, 4, 2 }.ToAsyncEnumerable();
await AssertThrowsAsync<ArgumentException>(xs.ToImmutableDictionaryAsync(x => x % 2, x => x + 1).AsTask());
}

[Fact]
public async Task ToImmutableDictionary5Async()
{
var xs = new[] { 1, 4 }.ToAsyncEnumerable();
var res = await xs.ToImmutableDictionaryAsync(x => x % 2, new Eq());
Assert.True(res[0] == 4);
Assert.True(res[1] == 1);
}

[Fact]
public async Task ToImmutableDictionary6Async()
{
var xs = new[] { 1, 4, 2 }.ToAsyncEnumerable();
await AssertThrowsAsync<ArgumentException>(xs.ToImmutableDictionaryAsync(x => x % 2, new Eq()).AsTask());
}

[Fact]
public async Task ToImmutableDictionary7Async()
{
var xs = new[] { 1, 4 }.ToAsyncEnumerable();
var res = await xs.ToImmutableDictionaryAsync(x => x % 2, x => x, new Eq());
Assert.True(res[0] == 4);
Assert.True(res[1] == 1);
}

private sealed class Eq : IEqualityComparer<int>
{
public bool Equals(int x, int y) => EqualityComparer<int>.Default.Equals(Math.Abs(x), Math.Abs(y));

public int GetHashCode(int obj) => EqualityComparer<int>.Default.GetHashCode(Math.Abs(obj));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Tests
{
public class ToImmutableHashSet : AsyncEnumerableTests
{
[Fact]
public async Task ToImmutableHashSet_Null()
{
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableHashSetAsyncEnumerableExtensions.ToImmutableHashSetAsync<int>(default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableHashSetAsyncEnumerableExtensions.ToImmutableHashSetAsync<int>(default, CancellationToken.None).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableHashSetAsyncEnumerableExtensions.ToImmutableHashSetAsync(default, EqualityComparer<int>.Default, CancellationToken.None).AsTask());
}

[Fact]
public async Task ToImmutableHashSet_Simple()
{
var xs = new[] { 1, 2, 1, 2, 3, 4, 1, 2, 3, 4 };
var res = xs.ToAsyncEnumerable().ToImmutableHashSetAsync();
Assert.True((await res).OrderBy(x => x).SequenceEqual(new[] { 1, 2, 3, 4 }));
}

[Fact]
public async Task ToImmutableHashSet_Comparer()
{
var xs = new[] { 1, 12, 11, 2, 3, 14, 1, 12, 13, 4 };
var res = xs.ToAsyncEnumerable().ToImmutableHashSetAsync(new Eq());
Assert.True((await res).OrderBy(x => x).SequenceEqual(new[] { 1, 3, 12, 14 }));
}

private class Eq : IEqualityComparer<int>
{
public bool Equals(int x, int y) => x % 10 == y % 10;

public int GetHashCode(int obj) => obj % 10;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Tests
{
public class ToImmutableList : AsyncEnumerableTests
{
[Fact]
public async Task ToImmutableList_Null()
{
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableListAsyncEnumerableExtensions.ToImmutableListAsync<int>(default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableListAsyncEnumerableExtensions.ToImmutableListAsync<int>(default, CancellationToken.None).AsTask());
}

[Fact]
public async Task ToImmutableList_Simple()
{
var xs = new[] { 42, 25, 39 };
var res = xs.ToAsyncEnumerable().ToImmutableListAsync();
Assert.True((await res).SequenceEqual(xs));
}

[Fact]
public async Task ToImmutableList_Empty()
{
var xs = AsyncEnumerable.Empty<int>();
var res = xs.ToImmutableListAsync();
Assert.True((await res).Count == 0);
}

[Fact]
public async Task ToImmutableList_Throw()
{
var ex = new Exception("Bang!");
var res = Throw<int>(ex).ToImmutableListAsync();
await AssertThrowsAsync(res, ex);
}
}
}
Loading