Skip to content

Commit 94c989d

Browse files
authored
[ComInterfaceGenerator] Warn if StringMarshalling doesn't match base and warn if base interface cannot be generated (#86467)
1 parent 164ba0c commit 94c989d

29 files changed

+1006
-178
lines changed
Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System;
45
using System.Collections.Generic;
56
using System.Collections.Immutable;
7+
using System.Diagnostics;
68
using System.Threading;
9+
using Microsoft.CodeAnalysis;
10+
using Microsoft.CodeAnalysis.CSharp.Syntax;
711

812
namespace Microsoft.Interop
913
{
@@ -12,44 +16,67 @@ internal sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceCo
1216
/// <summary>
1317
/// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext.
1418
/// </summary>
15-
public static ImmutableArray<ComInterfaceContext> GetContexts(ImmutableArray<ComInterfaceInfo> data, CancellationToken _)
19+
public static ImmutableArray<(ComInterfaceContext? Context, Diagnostic? Diagnostic)> GetContexts(ImmutableArray<ComInterfaceInfo> data, CancellationToken _)
1620
{
17-
Dictionary<string, ComInterfaceInfo> symbolToInterfaceInfoMap = new();
18-
var accumulator = ImmutableArray.CreateBuilder<ComInterfaceContext>(data.Length);
21+
Dictionary<string, ComInterfaceInfo> nameToInterfaceInfoMap = new();
22+
var accumulator = ImmutableArray.CreateBuilder<(ComInterfaceContext? Context, Diagnostic? Diagnostic)>(data.Length);
1923
foreach (var iface in data)
2024
{
21-
symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface);
25+
nameToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface);
2226
}
23-
Dictionary<string, ComInterfaceContext> symbolToContextMap = new();
27+
Dictionary<string, (ComInterfaceContext? Context, Diagnostic? Diagnostic)> nameToContextCache = new();
2428

2529
foreach (var iface in data)
2630
{
2731
accumulator.Add(AddContext(iface));
2832
}
2933
return accumulator.MoveToImmutable();
3034

31-
ComInterfaceContext AddContext(ComInterfaceInfo iface)
35+
(ComInterfaceContext? Context, Diagnostic? Diagnostic) AddContext(ComInterfaceInfo iface)
3236
{
33-
if (symbolToContextMap.TryGetValue(iface.ThisInterfaceKey, out var cachedValue))
37+
if (nameToContextCache.TryGetValue(iface.ThisInterfaceKey, out var cachedValue))
3438
{
3539
return cachedValue;
3640
}
3741

3842
if (iface.BaseInterfaceKey is null)
3943
{
4044
var baselessCtx = new ComInterfaceContext(iface, null);
41-
symbolToContextMap[iface.ThisInterfaceKey] = baselessCtx;
42-
return baselessCtx;
45+
nameToContextCache[iface.ThisInterfaceKey] = (baselessCtx, null);
46+
return (baselessCtx, null);
4347
}
4448

45-
if (!symbolToContextMap.TryGetValue(iface.BaseInterfaceKey, out var baseContext))
49+
if (
50+
// Cached base info has a diagnostic - failure
51+
(nameToContextCache.TryGetValue(iface.BaseInterfaceKey, out var basePair) && basePair.Diagnostic is not null)
52+
// Cannot find base ComInterfaceInfo - failure (failed ComInterfaceInfo creation)
53+
|| !nameToInterfaceInfoMap.TryGetValue(iface.BaseInterfaceKey, out var baseInfo)
54+
// Newly calculated base context pair has a diagnostic - failure
55+
|| (AddContext(baseInfo) is { } baseReturnPair && baseReturnPair.Diagnostic is not null))
4656
{
47-
baseContext = AddContext(symbolToInterfaceInfoMap[iface.BaseInterfaceKey]);
57+
// The base has failed generation at some point, so this interface cannot be generated
58+
(ComInterfaceContext, Diagnostic?) diagnosticPair = (null,
59+
Diagnostic.Create(
60+
GeneratorDiagnostics.BaseInterfaceIsNotGenerated,
61+
iface.DiagnosticLocation.AsLocation(), iface.ThisInterfaceKey, iface.BaseInterfaceKey));
62+
nameToContextCache[iface.ThisInterfaceKey] = diagnosticPair;
63+
return diagnosticPair;
4864
}
65+
var baseContext = basePair.Context ?? baseReturnPair.Context;
66+
Debug.Assert(baseContext != null);
4967
var ctx = new ComInterfaceContext(iface, baseContext);
50-
symbolToContextMap[iface.ThisInterfaceKey] = ctx;
51-
return ctx;
68+
(ComInterfaceContext, Diagnostic?) contextPair = (ctx, null);
69+
nameToContextCache[iface.ThisInterfaceKey] = contextPair;
70+
return contextPair;
5271
}
5372
}
73+
74+
internal ComInterfaceContext GetTopLevelBase()
75+
{
76+
var currBase = Base;
77+
while (currBase is not null)
78+
currBase = currBase.Base;
79+
return currBase;
80+
}
5481
}
5582
}

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Collections.Generic;
56
using System.Collections.Immutable;
67
using System.IO;
78
using System.Linq;
@@ -52,13 +53,21 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
5253

5354
var interfaceSymbolsWithoutDiagnostics = interfaceSymbolAndDiagnostics
5455
.Where(data => data.Diagnostic is null)
55-
.Select((data, ct) =>
56-
(data.InterfaceInfo, data.Symbol));
56+
.Select((data, ct) => (data.InterfaceInfo, data.Symbol));
5757

58-
var interfaceContexts = interfaceSymbolsWithoutDiagnostics
58+
var interfaceContextsAndDiagnostics = interfaceSymbolsWithoutDiagnostics
5959
.Select((data, ct) => data.InterfaceInfo!)
6060
.Collect()
6161
.SelectMany(ComInterfaceContext.GetContexts);
62+
context.RegisterDiagnostics(interfaceContextsAndDiagnostics.Select((data, ct) => data.Diagnostic));
63+
var interfaceContexts = interfaceContextsAndDiagnostics
64+
.Where(data => data.Context is not null)
65+
.Select((data, ct) => data.Context!);
66+
// Filter down interface symbols to remove those with diagnostics from GetContexts
67+
interfaceSymbolsWithoutDiagnostics = interfaceSymbolsWithoutDiagnostics
68+
.Zip(interfaceContextsAndDiagnostics)
69+
.Where(data => data.Right.Diagnostic is null)
70+
.Select((data, ct) => data.Left);
6271

6372
var comMethodsAndSymbolsAndDiagnostics = interfaceSymbolsWithoutDiagnostics.Select(ComMethodInfo.GetMethodsFromInterface);
6473
context.RegisterDiagnostics(comMethodsAndSymbolsAndDiagnostics.SelectMany(static (methodList, ct) => methodList.Select(m => m.Diagnostic)));
@@ -77,6 +86,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
7786
.Zip(methodInfosGroupedByInterface)
7887
.Collect()
7988
.SelectMany(static (data, ct) =>
89+
{
90+
return data.GroupBy(data => data.Left.GetTopLevelBase());
91+
})
92+
.SelectMany(static (data, ct) =>
8093
{
8194
return ComMethodContext.CalculateAllMethods(data, ct);
8295
});
@@ -239,16 +252,6 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
239252
}
240253
}
241254

242-
AttributeData? generatedComAttribute = null;
243-
foreach (var attr in symbol.ContainingType.GetAttributes())
244-
{
245-
if (generatedComAttribute is null
246-
&& attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)
247-
{
248-
generatedComAttribute = attr;
249-
}
250-
}
251-
252255
var generatorDiagnostics = new GeneratorDiagnostics();
253256

254257
if (lcidConversionAttr is not null)
@@ -257,12 +260,8 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
257260
generatorDiagnostics.ReportConfigurationNotSupported(lcidConversionAttr, nameof(TypeNames.LCIDConversionAttribute));
258261
}
259262

260-
var generatedComInterfaceAttributeData = new InteropAttributeCompilationData();
261-
if (generatedComAttribute is not null)
262-
{
263-
var args = generatedComAttribute.NamedArguments.ToImmutableDictionary();
264-
generatedComInterfaceAttributeData = generatedComInterfaceAttributeData.WithValuesFromNamedArguments(args);
265-
}
263+
GeneratedComInterfaceCompilationData.TryGetGeneratedComInterfaceAttributeFromInterface(symbol.ContainingType, out var generatedComAttribute);
264+
var generatedComInterfaceAttributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComAttribute);
266265
// Create the stub.
267266
var signatureContext = SignatureContext.Create(
268267
symbol,

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Diagnostics;
56
using System.Diagnostics.CodeAnalysis;
67
using System.Linq;
78
using Microsoft.CodeAnalysis;
@@ -20,7 +21,8 @@ internal sealed record ComInterfaceInfo(
2021
InterfaceDeclarationSyntax Declaration,
2122
ContainingSyntaxContext TypeDefinitionContext,
2223
ContainingSyntax ContainingSyntax,
23-
Guid InterfaceId)
24+
Guid InterfaceId,
25+
LocationInfo DiagnosticLocation)
2426
{
2527
public static (ComInterfaceInfo? Info, Diagnostic? Diagnostic) From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax)
2628
{
@@ -58,14 +60,63 @@ public static (ComInterfaceInfo? Info, Diagnostic? Diagnostic) From(INamedTypeSy
5860
if (!TryGetBaseComInterface(symbol, syntax, out INamedTypeSymbol? baseSymbol, out Diagnostic? baseDiagnostic))
5961
return (null, baseDiagnostic);
6062

61-
return (new ComInterfaceInfo(
62-
ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol),
63-
symbol.ToDisplayString(),
64-
baseSymbol?.ToDisplayString(),
65-
syntax,
66-
new ContainingSyntaxContext(syntax),
67-
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
68-
guid ?? Guid.Empty), null);
63+
if (!StringMarshallingIsValid(symbol, syntax, baseSymbol, out Diagnostic? stringMarshallingDiagnostic))
64+
return (null, stringMarshallingDiagnostic);
65+
66+
return (
67+
new ComInterfaceInfo(
68+
ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol),
69+
symbol.ToDisplayString(),
70+
baseSymbol?.ToDisplayString(),
71+
syntax,
72+
new ContainingSyntaxContext(syntax),
73+
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
74+
guid ?? Guid.Empty,
75+
LocationInfo.From(symbol)),
76+
null);
77+
}
78+
79+
private static bool StringMarshallingIsValid(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax, INamedTypeSymbol? baseSymbol, [NotNullWhen(false)] out Diagnostic? stringMarshallingDiagnostic)
80+
{
81+
var attrInfo = GeneratedComInterfaceData.From(GeneratedComInterfaceCompilationData.GetAttributeDataFromInterfaceSymbol(symbol));
82+
if (attrInfo.IsUserDefined.HasFlag(InteropAttributeMember.StringMarshalling) || attrInfo.IsUserDefined.HasFlag(InteropAttributeMember.StringMarshallingCustomType))
83+
{
84+
if (attrInfo.StringMarshalling is StringMarshalling.Custom && attrInfo.StringMarshallingCustomType is null)
85+
{
86+
stringMarshallingDiagnostic = Diagnostic.Create(
87+
GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnInterface,
88+
syntax.Identifier.GetLocation(),
89+
symbol.ToDisplayString(),
90+
SR.InvalidStringMarshallingConfigurationMissingCustomType);
91+
return false;
92+
}
93+
if (attrInfo.StringMarshalling is not StringMarshalling.Custom && attrInfo.StringMarshallingCustomType is not null)
94+
{
95+
stringMarshallingDiagnostic = Diagnostic.Create(
96+
GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnInterface,
97+
syntax.Identifier.GetLocation(),
98+
symbol.ToDisplayString(),
99+
SR.InvalidStringMarshallingConfigurationNotCustom);
100+
return false;
101+
}
102+
}
103+
if (baseSymbol is not null)
104+
{
105+
var baseAttrInfo = GeneratedComInterfaceData.From(GeneratedComInterfaceCompilationData.GetAttributeDataFromInterfaceSymbol(baseSymbol));
106+
// The base can be undefined string marshalling
107+
if ((baseAttrInfo.IsUserDefined.HasFlag(InteropAttributeMember.StringMarshalling) || baseAttrInfo.IsUserDefined.HasFlag(InteropAttributeMember.StringMarshallingCustomType))
108+
&& baseAttrInfo != attrInfo)
109+
{
110+
stringMarshallingDiagnostic = Diagnostic.Create(
111+
GeneratorDiagnostics.InvalidStringMarshallingMismatchBetweenBaseAndDerived,
112+
syntax.Identifier.GetLocation(),
113+
symbol.ToDisplayString(),
114+
SR.GeneratedComInterfaceStringMarshallingMustMatchBase);
115+
return false;
116+
}
117+
}
118+
stringMarshallingDiagnostic = null;
119+
return true;
69120
}
70121

71122
/// <summary>

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ private MethodDeclarationSyntax CreateUnreachableExceptionStub()
109109
.WithAttributeLists(List<AttributeListSyntax>())
110110
.WithExplicitInterfaceSpecifier(ExplicitInterfaceSpecifier(
111111
ParseName(OriginalDeclaringInterface.Info.Type.FullTypeName)))
112+
.WithParameterList(ParameterList(SeparatedList(GenerationContext.SignatureContext.StubParameters)))
112113
.WithExpressionBody(ArrowExpressionClause(
113114
ThrowExpression(
114115
ObjectCreationExpression(
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Collections.Immutable;
7+
using System.Diagnostics;
8+
using System.Diagnostics.CodeAnalysis;
9+
using System.Text;
10+
using Microsoft.CodeAnalysis;
11+
12+
namespace Microsoft.Interop
13+
{
14+
/// <summary>
15+
/// Contains the data related to a GeneratedComInterfaceAttribute, without references to Roslyn symbols.
16+
/// See <seealso cref="GeneratedComInterfaceCompilationData"/> for a type with a reference to the StringMarshallingCustomType
17+
/// </summary>
18+
internal sealed record GeneratedComInterfaceData : InteropAttributeData
19+
{
20+
public static GeneratedComInterfaceData From(GeneratedComInterfaceCompilationData generatedComInterfaceAttr)
21+
=> new GeneratedComInterfaceData() with
22+
{
23+
IsUserDefined = generatedComInterfaceAttr.IsUserDefined,
24+
SetLastError = generatedComInterfaceAttr.SetLastError,
25+
StringMarshalling = generatedComInterfaceAttr.StringMarshalling,
26+
StringMarshallingCustomType = generatedComInterfaceAttr.StringMarshallingCustomType is not null
27+
? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(generatedComInterfaceAttr.StringMarshallingCustomType)
28+
: null
29+
};
30+
}
31+
32+
/// <summary>
33+
/// Contains the data related to a GeneratedComInterfaceAttribute, with references to Roslyn symbols.
34+
/// Use <seealso cref="GeneratedComInterfaceData"/> instead when using for incremental compilation state to avoid keeping a compilation alive
35+
/// </summary>
36+
internal sealed record GeneratedComInterfaceCompilationData : InteropAttributeCompilationData
37+
{
38+
public static bool TryGetGeneratedComInterfaceAttributeFromInterface(INamedTypeSymbol interfaceSymbol, [NotNullWhen(true)] out AttributeData? generatedComInterfaceAttribute)
39+
{
40+
generatedComInterfaceAttribute = null;
41+
foreach (var attr in interfaceSymbol.GetAttributes())
42+
{
43+
if (generatedComInterfaceAttribute is null
44+
&& attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)
45+
{
46+
generatedComInterfaceAttribute = attr;
47+
}
48+
}
49+
return generatedComInterfaceAttribute is not null;
50+
}
51+
52+
public static GeneratedComInterfaceCompilationData GetAttributeDataFromInterfaceSymbol(INamedTypeSymbol interfaceSymbol)
53+
{
54+
bool found = TryGetGeneratedComInterfaceAttributeFromInterface(interfaceSymbol, out var attr);
55+
Debug.Assert(found);
56+
return GetDataFromAttribute(attr);
57+
}
58+
59+
public static GeneratedComInterfaceCompilationData GetDataFromAttribute(AttributeData attr)
60+
{
61+
Debug.Assert(attr.AttributeClass.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute);
62+
var generatedComInterfaceAttributeData = new GeneratedComInterfaceCompilationData();
63+
var args = attr.NamedArguments.ToImmutableDictionary();
64+
generatedComInterfaceAttributeData = generatedComInterfaceAttributeData.WithValuesFromNamedArguments(args);
65+
return generatedComInterfaceAttributeData;
66+
}
67+
}
68+
}

0 commit comments

Comments
 (0)