Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ private static async Task<Document> RefactorAsync(

IMethodSymbol methodSymbol = semanticModel.GetDeclaredSymbol(methodDeclaration, cancellationToken);

UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol);
UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol, semanticModel, node.SpanStart);
var newBody = (BlockSyntax)rewriter.VisitBlock(newNode.Body);

newNode = newNode
Expand All @@ -138,7 +138,7 @@ private static async Task<Document> RefactorAsync(

IMethodSymbol methodSymbol = semanticModel.GetDeclaredSymbol(localFunction, cancellationToken);

UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol);
UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol, semanticModel, node.SpanStart);
var newBody = (BlockSyntax)rewriter.VisitBlock(newNode.Body);

newNode = newNode
Expand All @@ -156,7 +156,7 @@ private static async Task<Document> RefactorAsync(

var methodSymbol = (IMethodSymbol)semanticModel.GetSymbol(lambdaExpression, cancellationToken);

UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol);
UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol, semanticModel, node.SpanStart);
var newBody = (BlockSyntax)rewriter.VisitBlock((BlockSyntax)newNode.Body);

newNode = newNode
Expand All @@ -174,7 +174,7 @@ private static async Task<Document> RefactorAsync(

var methodSymbol = (IMethodSymbol)semanticModel.GetSymbol(anonymousMethod, cancellationToken);

UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol);
UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol, semanticModel, node.SpanStart);
var newBody = (BlockSyntax)rewriter.VisitBlock((BlockSyntax)newNode.Body);

newNode = newNode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ private static async Task<Document> RefactorAsync(
{
IMethodSymbol methodSymbol = semanticModel.GetDeclaredSymbol(methodDeclaration, cancellationToken);

UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol);
UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol, semanticModel, node.SpanStart);

var newNode = (MethodDeclarationSyntax)rewriter.VisitMethodDeclaration(methodDeclaration);

Expand All @@ -78,7 +78,7 @@ private static async Task<Document> RefactorAsync(
{
IMethodSymbol methodSymbol = semanticModel.GetDeclaredSymbol(localFunction, cancellationToken);

UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol);
UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol, semanticModel, node.SpanStart);

var newBody = (BlockSyntax)rewriter.VisitBlock(localFunction.Body);

Expand All @@ -92,7 +92,7 @@ private static async Task<Document> RefactorAsync(
{
var methodSymbol = (IMethodSymbol)semanticModel.GetSymbol(lambda, cancellationToken);

UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol);
UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol, semanticModel, node.SpanStart);

var newBody = (BlockSyntax)rewriter.VisitBlock((BlockSyntax)lambda.Body);

Expand All @@ -106,7 +106,7 @@ private static async Task<Document> RefactorAsync(
{
var methodSymbol = (IMethodSymbol)semanticModel.GetSymbol(lambda, cancellationToken);

UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol);
UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol, semanticModel, node.SpanStart);

var newBody = (BlockSyntax)rewriter.VisitBlock((BlockSyntax)lambda.Body);

Expand All @@ -120,7 +120,7 @@ private static async Task<Document> RefactorAsync(
{
var methodSymbol = (IMethodSymbol)semanticModel.GetSymbol(anonymousMethod, cancellationToken);

UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol);
UseAsyncAwaitRewriter rewriter = UseAsyncAwaitRewriter.Create(methodSymbol, semanticModel, node.SpanStart);

var newBody = (BlockSyntax)rewriter.VisitBlock((BlockSyntax)anonymousMethod.Body);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ private UseAsyncAwaitRewriter(bool keepReturnStatement)

public bool KeepReturnStatement { get; }

public static UseAsyncAwaitRewriter Create(IMethodSymbol methodSymbol)
public static UseAsyncAwaitRewriter Create(IMethodSymbol methodSymbol, SemanticModel semanticModel, int position)
{
ITypeSymbol returnType = methodSymbol.ReturnType.OriginalDefinition;

var keepReturnStatement = false;

if (returnType.EqualsOrInheritsFrom(MetadataNames.System_Threading_Tasks_ValueTask_T)
|| returnType.EqualsOrInheritsFrom(MetadataNames.System_Threading_Tasks_Task_T))
if (returnType is INamedTypeSymbol { Arity: 1 }
&& returnType.IsAwaitable(semanticModel, position))
{
keepReturnStatement = true;
}
Expand Down
66 changes: 41 additions & 25 deletions src/Analyzers/CSharp/Analysis/ConfigureAwaitAnalyzer.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation and Contributors. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
Expand Down Expand Up @@ -65,7 +66,10 @@ private static void AddCallToConfigureAwait(SyntaxNodeAnalysisContext context)
if (typeSymbol is null)
return;

if (!SymbolUtility.IsAwaitable(typeSymbol))
if (!typeSymbol.IsAwaitable(context.SemanticModel, expression.SpanStart))
return;

if (!IsConfigureAwaitable(typeSymbol, context.SemanticModel, expression.SpanStart))
return;

DiagnosticHelpers.ReportDiagnostic(context, DiagnosticRules.ConfigureAwait, awaitExpression.Expression, "Add");
Expand All @@ -75,39 +79,43 @@ private static void RemoveCallToConfigureAwait(SyntaxNodeAnalysisContext context
{
var awaitExpression = (AwaitExpressionSyntax)context.Node;

// await (expr).ConfigureAwait(false);
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ExpressionSyntax expression = awaitExpression.Expression;

// await (expr).ConfigureAwait(false);
// ^^^^^^^^^^^^^^^^^^^^^^
SimpleMemberInvocationExpressionInfo invocationInfo = SyntaxInfo.SimpleMemberInvocationExpressionInfo(expression);

if (!IsConfigureAwait(expression))
if (!IsConfigureAwait(invocationInfo))
return;

ITypeSymbol typeSymbol = context.SemanticModel.GetTypeSymbol(expression, context.CancellationToken);
ITypeSymbol awaitedType = context.SemanticModel.GetTypeSymbol(expression, context.CancellationToken);

if (typeSymbol is null)
if (awaitedType is null)
return;

switch (typeSymbol.MetadataName)
{
case "ConfiguredTaskAwaitable":
case "ConfiguredTaskAwaitable`1":
case "ConfiguredValueTaskAwaitable":
case "ConfiguredValueTaskAwaitable`1":
{
if (typeSymbol.ContainingNamespace.HasMetadataName(MetadataNames.System_Runtime_CompilerServices))
{
DiagnosticHelpers.ReportDiagnostic(
context,
DiagnosticRules.ConfigureAwait,
Location.Create(
awaitExpression.SyntaxTree,
TextSpan.FromBounds(invocationInfo.OperatorToken.SpanStart, expression.Span.End)),
"Remove");
}

break;
}
}
if (!awaitedType.IsAwaitable(context.SemanticModel, expression.SpanStart))
return;

// await (expr).ConfigureAwait(false);
// ^^^^
// This expression may not be awaitable, in which case removing ConfigureAwait is not possible.
ITypeSymbol configuredType = context.SemanticModel.GetTypeSymbol(invocationInfo.Expression, context.CancellationToken);

if (configuredType is null)
return;

if (!configuredType.IsAwaitable(context.SemanticModel, invocationInfo.Expression.SpanStart))
return;

DiagnosticHelpers.ReportDiagnostic(
context,
DiagnosticRules.ConfigureAwait,
Location.Create(
awaitExpression.SyntaxTree,
TextSpan.FromBounds(invocationInfo.OperatorToken.SpanStart, expression.Span.End)),
"Remove");
}

public static bool IsConfigureAwait(ExpressionSyntax expression)
Expand All @@ -124,4 +132,12 @@ private static bool IsConfigureAwait(SimpleMemberInvocationExpressionInfo invoca
&& string.Equals(invocationInfo.NameText, "ConfigureAwait")
&& invocationInfo.Arguments.Count == 1;
}

private static bool IsConfigureAwaitable(ITypeSymbol typeSymbol, SemanticModel semanticModel, int position)
{
return semanticModel.LookupSymbols(position, typeSymbol, "ConfigureAwait", includeReducedExtensionMethods: true)
.OfType<IMethodSymbol>()
.Any(method => method.ReturnType.IsAwaitable(semanticModel, position)
&& method.HasSingleParameter(SpecialType.System_Boolean));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ private static void Analyze(
: context.SemanticModel.GetSymbol(containingMethod, context.CancellationToken)) as IMethodSymbol;

if (methodSymbol?.IsErrorType() == false
&& SymbolUtility.IsAwaitable(methodSymbol.ReturnType))
&& methodSymbol.ReturnType.IsTaskType()
&& methodSymbol.ReturnType.IsAwaitable(context.SemanticModel, context.Node.SpanStart))
{
ReportDiagnostic(context, usingKeyword);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ public override void Initialize(AnalysisContext context)

context.RegisterCompilationStartAction(startContext =>
{
INamedTypeSymbol asyncAction = startContext.Compilation.GetTypeByMetadataName("Windows.Foundation.IAsyncAction");

bool shouldCheckWindowsRuntimeTypes = asyncAction is not null;

startContext.RegisterSyntaxNodeAction(
c =>
{
Expand All @@ -50,14 +46,14 @@ public override void Initialize(AnalysisContext context)
DiagnosticRules.AsynchronousMethodNameShouldEndWithAsync,
DiagnosticRules.NonAsynchronousMethodNameShouldNotEndWithAsync))
{
AnalyzeMethodDeclaration(c, shouldCheckWindowsRuntimeTypes);
AnalyzeMethodDeclaration(c);
}
},
SyntaxKind.MethodDeclaration);
});
}

private static void AnalyzeMethodDeclaration(SyntaxNodeAnalysisContext context, bool shouldCheckWindowsRuntimeTypes)
private static void AnalyzeMethodDeclaration(SyntaxNodeAnalysisContext context)
{
var methodDeclaration = (MethodDeclarationSyntax)context.Node;

Expand All @@ -74,7 +70,7 @@ private static void AnalyzeMethodDeclaration(SyntaxNodeAnalysisContext context,
if (!methodSymbol.Name.EndsWith("Async", StringComparison.Ordinal))
return;

if (SymbolUtility.IsAwaitable(methodSymbol.ReturnType, shouldCheckWindowsRuntimeTypes)
if (methodSymbol.ReturnType.IsAwaitable(context.SemanticModel, methodDeclaration.SpanStart)
|| IsAsyncEnumerableLike(methodSymbol.ReturnType.OriginalDefinition))
{
return;
Expand Down Expand Up @@ -105,7 +101,7 @@ private static void AnalyzeMethodDeclaration(SyntaxNodeAnalysisContext context,
if (methodSymbol.ImplementsInterfaceMember(allInterfaces: true))
return;

if (!SymbolUtility.IsAwaitable(methodSymbol.ReturnType, shouldCheckWindowsRuntimeTypes)
if (!methodSymbol.ReturnType.IsAwaitable(context.SemanticModel, methodDeclaration.SpanStart)
&& !methodSymbol.ReturnType.OriginalDefinition.HasMetadataName(in MetadataNames.System_Collections_Generic_IAsyncEnumerable_T))
{
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void ReportAwaitAndConfigureAwait(AwaitExpressionSyntax awaitExpression)

ITypeSymbol typeSymbol = context.SemanticModel.GetTypeSymbol(expression, context.CancellationToken);

if (typeSymbol?.OriginalDefinition.HasMetadataName(MetadataNames.System_Runtime_CompilerServices_ConfiguredTaskAwaitable_T) == true
if (typeSymbol?.OriginalDefinition.IsAwaitable(context.SemanticModel, expression.SpanStart) == true
&& (expression is InvocationExpressionSyntax invocation))
{
var memberAccess = invocation.Expression as MemberAccessExpressionSyntax;
Expand Down
10 changes: 5 additions & 5 deletions src/Analyzers/CSharp/Analysis/UseAsyncAwaitAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ private static void AnalyzeMethodDeclaration(SyntaxNodeAnalysisContext context)

IMethodSymbol methodSymbol = context.SemanticModel.GetDeclaredSymbol(methodDeclaration, context.CancellationToken);

if (!SymbolUtility.IsAwaitable(methodSymbol.ReturnType))
if (!methodSymbol.ReturnType.IsTaskType() || !methodSymbol.ReturnType.IsAwaitable(context.SemanticModel, body.SpanStart))
return;

if (IsFixable(body, context))
Expand All @@ -81,7 +81,7 @@ private static void AnalyzeLocalFunctionStatement(SyntaxNodeAnalysisContext cont

IMethodSymbol methodSymbol = context.SemanticModel.GetDeclaredSymbol(localFunction, context.CancellationToken);

if (!SymbolUtility.IsAwaitable(methodSymbol.ReturnType))
if (!methodSymbol.ReturnType.IsTaskType() || !methodSymbol.ReturnType.IsAwaitable(context.SemanticModel, body.SpanStart))
return;

if (IsFixable(body, context))
Expand All @@ -101,7 +101,7 @@ private static void AnalyzeSimpleLambdaExpression(SyntaxNodeAnalysisContext cont
if (context.SemanticModel.GetSymbol(simpleLambda, context.CancellationToken) is not IMethodSymbol methodSymbol)
return;

if (!SymbolUtility.IsAwaitable(methodSymbol.ReturnType))
if (!methodSymbol.ReturnType.IsTaskType() || !methodSymbol.ReturnType.IsAwaitable(context.SemanticModel, body.SpanStart))
return;

if (IsFixable(body, context))
Expand All @@ -121,7 +121,7 @@ private static void AnalyzeParenthesizedLambdaExpression(SyntaxNodeAnalysisConte
if (context.SemanticModel.GetSymbol(parenthesizedLambda, context.CancellationToken) is not IMethodSymbol methodSymbol)
return;

if (!SymbolUtility.IsAwaitable(methodSymbol.ReturnType))
if (!methodSymbol.ReturnType.IsTaskType() || !methodSymbol.ReturnType.IsAwaitable(context.SemanticModel, body.SpanStart))
return;

if (IsFixable(body, context))
Expand All @@ -143,7 +143,7 @@ private static void AnalyzeAnonymousMethodExpression(SyntaxNodeAnalysisContext c
if (context.SemanticModel.GetSymbol(anonymousMethod, context.CancellationToken) is not IMethodSymbol methodSymbol)
return;

if (!SymbolUtility.IsAwaitable(methodSymbol.ReturnType))
if (!methodSymbol.ReturnType.IsTaskType() || !methodSymbol.ReturnType.IsAwaitable(context.SemanticModel, body.SpanStart))
return;

if (IsFixable(body, context))
Expand Down
8 changes: 4 additions & 4 deletions src/Common/CSharp/Analysis/RemoveAsyncAwaitAnalysis.cs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ private static bool VerifyTypes(

ITypeSymbol returnType = methodSymbol.ReturnType;

if (returnType?.OriginalDefinition.EqualsOrInheritsFrom(MetadataNames.System_Threading_Tasks_Task_T) != true)
if (returnType?.OriginalDefinition.IsAwaitable(semanticModel, node.SpanStart) != true)
return false;

ITypeSymbol typeArgument = ((INamedTypeSymbol)returnType).TypeArguments.SingleOrDefault(shouldThrow: false);
Expand Down Expand Up @@ -394,7 +394,7 @@ private static bool VerifyTypes(

ITypeSymbol returnType = methodSymbol.ReturnType;

if (returnType?.OriginalDefinition.EqualsOrInheritsFrom(MetadataNames.System_Threading_Tasks_Task_T) != true)
if (returnType?.OriginalDefinition.IsAwaitable(semanticModel, node.SpanStart) != true)
return false;

ITypeSymbol typeArgument = ((INamedTypeSymbol)returnType).TypeArguments.SingleOrDefault(shouldThrow: false);
Expand All @@ -417,15 +417,15 @@ private static bool VerifyAwaitType(AwaitExpressionSyntax awaitExpression, IType
if (expressionTypeSymbol is null)
return false;

if (expressionTypeSymbol.OriginalDefinition.EqualsOrInheritsFrom(MetadataNames.System_Threading_Tasks_Task_T))
if (expressionTypeSymbol.OriginalDefinition.IsAwaitable(semanticModel, expression.SpanStart))
return true;

SimpleMemberInvocationExpressionInfo invocationInfo = SyntaxInfo.SimpleMemberInvocationExpressionInfo(expression);

return invocationInfo.Success
&& invocationInfo.Arguments.Count == 1
&& invocationInfo.NameText == "ConfigureAwait"
&& expressionTypeSymbol.OriginalDefinition.HasMetadataName(MetadataNames.System_Runtime_CompilerServices_ConfiguredTaskAwaitable_T);
&& expressionTypeSymbol.OriginalDefinition.IsAwaitable(semanticModel, expression.SpanStart);
}

private static IMethodSymbol GetMethodSymbol(
Expand Down
2 changes: 2 additions & 0 deletions src/Core/MetadataNames.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ internal static class MetadataNames
public static readonly MetadataName System_ReadOnlySpan_T = MetadataName.Parse("System.ReadOnlySpan`1");
public static readonly MetadataName System_Reflection = MetadataName.Parse("System.Reflection");
public static readonly MetadataName System_Runtime_CompilerServices = MetadataName.Parse("System.Runtime.CompilerServices");
public static readonly MetadataName System_Runtime_CompilerServices_AsyncMethodBuilderAttribute = MetadataName.Parse("System.Runtime.CompilerServices.AsyncMethodBuilderAttribute");
public static readonly MetadataName System_Runtime_CompilerServices_CollectionBuilderAttribute = MetadataName.Parse("System.Runtime.CompilerServices.CollectionBuilderAttribute");
public static readonly MetadataName System_Runtime_CompilerServices_ConfiguredTaskAwaitable = MetadataName.Parse("System.Runtime.CompilerServices.ConfiguredTaskAwaitable");
public static readonly MetadataName System_Runtime_CompilerServices_ConfiguredTaskAwaitable_T = MetadataName.Parse("System.Runtime.CompilerServices.ConfiguredTaskAwaitable`1");
public static readonly MetadataName System_Runtime_CompilerServices_INotifyCompletion = MetadataName.Parse("System.Runtime.CompilerServices.INotifyCompletion");
public static readonly MetadataName System_Runtime_InteropServices_LayoutKind = MetadataName.Parse("System.Runtime.InteropServices.LayoutKind");
public static readonly MetadataName System_Runtime_InteropServices_StructLayoutAttribute = MetadataName.Parse("System.Runtime.InteropServices.StructLayoutAttribute");
public static readonly MetadataName System_Runtime_Serialization_DataMemberAttribute = MetadataName.Parse("System.Runtime.Serialization.DataMemberAttribute");
Expand Down
Loading