Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using Microsoft.CodeAnalysis.CSharp;

namespace Microsoft.Interop
{
[Generator]
public class ComClassGenerator : IIncrementalGenerator
{
private sealed record ComClassInfo(string ClassName, ContainingSyntaxContext ContainingSyntaxContext, ContainingSyntax ClassSyntax, SequenceEqualImmutableArray<string> ImplementedInterfacesNames);
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Get all types with the [GeneratedComClassAttribute] attribute.
var attributedClasses = context.SyntaxProvider
.ForAttributeWithMetadataName(
TypeNames.GeneratedComClassAttribute,
static (node, ct) => node is ClassDeclarationSyntax,
static (context, ct) =>
{
var type = (INamedTypeSymbol)context.TargetSymbol;
var syntax = (ClassDeclarationSyntax)context.TargetNode;
ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
foreach (INamedTypeSymbol iface in type.AllInterfaces)
{
if (iface.GetAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute))
{
names.Add(iface.ToDisplayString());
}
}
return new ComClassInfo(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to warn / bail if there are no interfaces with GeneratedComInterface?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's still a valid scenario, but I could see us adding a warning for it. I'll file a follow-up issue for that.

type.ToDisplayString(),
new ContainingSyntaxContext(syntax),
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
new(names.ToImmutable()));
});

var classInfoType = attributedClasses.Select(static (info, ct) => GenerateClassInfoType(info)).SelectNormalized();

var attribute = attributedClasses.Select(static (info, ct) => GenerateClassInfoAttributeOnUserType(info)).SelectNormalized();

context.RegisterSourceOutput(attributedClasses.Zip(classInfoType).Zip(attribute), static (context, classInfo) =>
{
var ((comClassInfo, classInfoType), attribute) = classInfo;
StringWriter writer = new();
writer.WriteLine(classInfoType.ToFullString());
writer.WriteLine();
writer.WriteLine(attribute);
context.AddSource(comClassInfo.ClassName, writer.ToString());
});
}

private const string ClassInfoTypeName = "ComClassInformation";

private static readonly AttributeSyntax s_comExposedClassAttributeTemplate =
Attribute(
GenericName(TypeNames.ComExposedClassAttribute)
.AddTypeArgumentListArguments(
IdentifierName(ClassInfoTypeName)));
private static MemberDeclarationSyntax GenerateClassInfoAttributeOnUserType(ComClassInfo info) =>
info.ContainingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(
TypeDeclaration(info.ClassSyntax.TypeKind, info.ClassSyntax.Identifier)
.WithModifiers(info.ClassSyntax.Modifiers)
.WithTypeParameterList(info.ClassSyntax.TypeParameters)
.AddAttributeLists(AttributeList(SingletonSeparatedList(s_comExposedClassAttributeTemplate))));
private static ClassDeclarationSyntax GenerateClassInfoType(ComClassInfo info)
{
const string vtablesField = "s_vtables";
const string vtablesLocal = "vtables";
const string detailsTempLocal = "details";
const string countIdentifier = "count";
var typeDeclaration = ClassDeclaration(ClassInfoTypeName)
.AddModifiers(
Token(SyntaxKind.FileKeyword),
Token(SyntaxKind.SealedKeyword),
Token(SyntaxKind.UnsafeKeyword))
.AddBaseListTypes(SimpleBaseType(ParseTypeName(TypeNames.IComExposedClass)))
.AddMembers(
FieldDeclaration(
VariableDeclaration(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
SingletonSeparatedList(VariableDeclarator(vtablesField))))
.AddModifiers(
Token(SyntaxKind.PrivateKeyword),
Token(SyntaxKind.StaticKeyword),
Token(SyntaxKind.VolatileKeyword)));
List<StatementSyntax> vtableInitializationBlock = new()
{
// ComInterfaceEntry* vtables = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(<className>), sizeof(ComInterfaceEntry) * <numInterfaces>);
LocalDeclarationStatement(
VariableDeclaration(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
SingletonSeparatedList(
VariableDeclarator(vtablesLocal)
.WithInitializer(EqualsValueClause(
CastExpression(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.System_Runtime_CompilerServices_RuntimeHelpers),
IdentifierName("AllocateTypeAssociatedMemory")))
.AddArgumentListArguments(
Argument(TypeOfExpression(ParseTypeName(info.ClassName))),
Argument(
BinaryExpression(
SyntaxKind.MultiplyExpression,
SizeOfExpression(ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
LiteralExpression(
SyntaxKind.NumericLiteralExpression,
Literal(info.ImplementedInterfacesNames.Array.Length))))))))))),
// IIUnknownDerivedDetails details;
LocalDeclarationStatement(
VariableDeclaration(
ParseTypeName(TypeNames.IIUnknownDerivedDetails),
SingletonSeparatedList(
VariableDeclarator(detailsTempLocal))))
};
for (int i = 0; i < info.ImplementedInterfacesNames.Array.Length; i++)
{
string ifaceName = info.ImplementedInterfacesNames.Array[i];

// details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<ifaceName>).TypeHandle);
vtableInitializationBlock.Add(
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(detailsTempLocal),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.StrategyBasedComWrappers),
IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
IdentifierName("GetIUnknownDerivedDetails")),
ArgumentList(
SingletonSeparatedList(
Argument(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
TypeOfExpression(ParseName(ifaceName)),
IdentifierName("TypeHandle")))))))));
// vtable[i] = new() { IID = details.Iid, Vtable = details.ManagedVirtualMethodTable };
vtableInitializationBlock.Add(
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
ElementAccessExpression(
IdentifierName(vtablesLocal),
BracketedArgumentList(
SingletonSeparatedList(
Argument(
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(i)))))),
ImplicitObjectCreationExpression(
ArgumentList(),
InitializerExpression(SyntaxKind.ObjectInitializerExpression,
SeparatedList(
new ExpressionSyntax[]
{
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName("IID"),
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("details"),
IdentifierName("Iid"))),
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName("Vtable"),
CastExpression(
IdentifierName("nint"),
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("details"),
IdentifierName("ManagedVirtualMethodTable"))))
}))))));
}

// s_vtable = vtable;
vtableInitializationBlock.Add(
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(vtablesField),
IdentifierName(vtablesLocal))));

BlockSyntax getComInterfaceEntriesMethodBody = Block(
// count = <count>;
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(countIdentifier),
LiteralExpression(SyntaxKind.NumericLiteralExpression,
Literal(info.ImplementedInterfacesNames.Array.Length)))),
// if (s_vtable == null)
// { initializer block }
IfStatement(
BinaryExpression(SyntaxKind.EqualsExpression,
IdentifierName(vtablesField),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
Block(vtableInitializationBlock)),
// return s_vtable;
ReturnStatement(IdentifierName(vtablesField)));

typeDeclaration = typeDeclaration.AddMembers(
// public static unsafe ComWrappers.ComInterfaceDispatch* GetComInterfaceEntries(out int count)
// { body }
MethodDeclaration(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
"GetComInterfaceEntries")
.AddParameterListParameters(
Parameter(Identifier(countIdentifier))
.WithType(PredefinedType(Token(SyntaxKind.IntKeyword)))
.AddModifiers(Token(SyntaxKind.OutKeyword)))
.WithBody(getComInterfaceEntriesMethodBody)
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)));

return typeDeclaration;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static class StepNames

public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Get all methods with the [GeneratedComInterface] attribute.
// Get all types with the [GeneratedComInterface] attribute.
var attributedInterfaces = context.SyntaxProvider
.ForAttributeWithMetadataName(
TypeNames.GeneratedComInterfaceAttribute,
Expand All @@ -62,7 +62,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return new { data.Syntax, data.Symbol, Diagnostic = diagnostic };
});

// Split the methods we want to generate and the ones we don't into two separate groups.
// Split the types we want to generate and the ones we don't into two separate groups.
var interfacesToGenerate = interfacesWithDiagnostics.Where(static data => data.Diagnostic is null);
var invalidTypeDiagnostics = interfacesWithDiagnostics.Where(static data => data.Diagnostic is not null);

Expand Down Expand Up @@ -726,7 +726,7 @@ private static ClassDeclarationSyntax GenerateInterfaceInformation(ComInterfaceC
.WithExpressionBody(
ArrowExpressionClause(
ConditionalExpression(
BinaryExpression(SyntaxKind.EqualsExpression,
BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(vtableFieldName),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
IdentifierName(vtableFieldName),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,22 @@ public static string MarshalEx(InteropGenerationOptions options)

public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceDispatch = "System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch";

public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry = "System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry";

public const string StrategyBasedComWrappers = "System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers";

public const string IIUnknownInterfaceType = "System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType";
public const string IUnknownDerivedAttribute = "System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute";
public const string IIUnknownDerivedDetails = "System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails";

public const string ComWrappersUnwrapper = "System.Runtime.InteropServices.Marshalling.ComWrappersUnwrapper";
public const string UnmanagedObjectUnwrapperAttribute = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapperAttribute`1";

public const string IUnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.IUnmanagedObjectUnwrapper";
public const string UnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper";

public const string GeneratedComClassAttribute = "System.Runtime.InteropServices.Marshalling.GeneratedComClassAttribute";
public const string ComExposedClassAttribute = "System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute";
public const string IComExposedClass = "System.Runtime.InteropServices.Marshalling.IComExposedClass";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace System.Runtime.InteropServices.Marshalling
{
[AttributeUsage(AttributeTargets.Class, Inherited = false)]
public sealed class ComExposedClassAttribute<T> : Attribute, IComExposedDetails
where T : IComExposedClass
{
public unsafe ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) => T.GetComInterfaceEntries(out count);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private bool LookUpVTableInfo(RuntimeTypeHandle handle, out IIUnknownCacheStrate
qiHResult = 0;
if (!CacheStrategy.TryGetTableInfo(handle, out result))
{
IUnknownDerivedDetails? details = InterfaceDetailsStrategy.GetIUnknownDerivedDetails(handle);
IIUnknownDerivedDetails? details = InterfaceDetailsStrategy.GetIUnknownDerivedDetails(handle);
if (details is null)
{
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ internal sealed unsafe class DefaultCaching : IIUnknownCacheStrategy
// [TODO] Implement some smart/thread-safe caching
private readonly Dictionary<RuntimeTypeHandle, IIUnknownCacheStrategy.TableInfo> _cache = new();

IIUnknownCacheStrategy.TableInfo IIUnknownCacheStrategy.ConstructTableInfo(RuntimeTypeHandle handle, IUnknownDerivedDetails details, void* ptr)
IIUnknownCacheStrategy.TableInfo IIUnknownCacheStrategy.ConstructTableInfo(RuntimeTypeHandle handle, IIUnknownDerivedDetails details, void* ptr)
{
var obj = (void***)ptr;
return new IIUnknownCacheStrategy.TableInfo()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@ internal sealed class DefaultIUnknownInterfaceDetailsStrategy : IIUnknownInterfa
{
public static readonly IIUnknownInterfaceDetailsStrategy Instance = new DefaultIUnknownInterfaceDetailsStrategy();

public IUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type)
public IComExposedDetails? GetComExposedTypeDetails(RuntimeTypeHandle type)
{
return IUnknownDerivedDetails.GetFromAttribute(type);
return IComExposedDetails.GetFromAttribute(type);
}

public IIUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type)
{
return IIUnknownDerivedDetails.GetFromAttribute(type);
}
}
}
Loading