Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public LinearCollectionElementMarshallingCodeContext(
_managedSpanIdentifier = managedSpanIdentifier;
_nativeSpanIdentifier = nativeSpanIdentifier;
ParentContext = parentContext;
Direction = ParentContext.Direction;
}

public override (TargetFramework framework, Version version) GetTargetFramework()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,22 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo
if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax, isLinearCollectionMarshalling: false);

if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
FreeStrategy freeStrategy = GetFreeStrategy(info, context);

if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
}

if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free))
{
marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax);
}

if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
}
}

IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);
Expand Down Expand Up @@ -311,19 +325,42 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
marshallingStrategy = new StatefulCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax);
}

FreeStrategy freeStrategy = GetFreeStrategy(info, context);
IElementsMarshallingCollectionSource collectionSource = new StatefulLinearCollectionSource();
IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource);

marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling);
if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
}

marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling, freeStrategy != FreeStrategy.NoFree);

if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
}

if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
{
marshallingStrategy = new StatefulFreeMarshalling(marshallingStrategy);
}
}
else
{
marshallingStrategy = new StatelessLinearCollectionSpaceAllocator(marshallerTypeSyntax, nativeType, marshallerData.Shape, numElementsExpression);

FreeStrategy freeStrategy = GetFreeStrategy(info, context);

IElementsMarshallingCollectionSource collectionSource = new StatelessLinearCollectionSource(marshallerTypeSyntax);
if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
}

IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource);

marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape);
marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape, freeStrategy != FreeStrategy.NoFree);

if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
{
Expand All @@ -334,8 +371,15 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax, isLinearCollectionMarshalling: true);
}

if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free))
{
marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax);
}

if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
}
}

IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator(
Expand All @@ -351,6 +395,48 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
return marshallingGenerator;
}

private enum FreeStrategy
{
/// <summary>
/// Free the unmanaged value stored in the native identifier.
/// </summary>
FreeNative,
/// <summary>
/// Free the unmanaged value originally passed into the stub.
/// </summary>
FreeOriginal,
/// <summary>
/// Do not free the unmanaged value, we don't own it.
/// </summary>
NoFree
}

private static FreeStrategy GetFreeStrategy(TypePositionInfo info, StubCodeContext context)
{
// When marshalling from managed to unmanaged, we always own the value in the native identifier.
if (context.Direction == MarshalDirection.ManagedToUnmanaged)
{
return FreeStrategy.FreeNative;
}

// When we're in a case where we don't have state across stages, the parent stub context that can track the state
// will only call our Cleanup stage when we own the value in the native identifier.
if (!context.AdditionalTemporaryStateLivesAcrossStages)
{
return FreeStrategy.FreeNative;
}

// In an unmanaged-to-managed stub where a value is passed by 'ref',
// we own the original value once we replace it with the new value we're passing out to the caller.
if (info.RefKind == RefKind.Ref)
{
return FreeStrategy.FreeOriginal;
}

// In an unmanaged-to-managed stub, we don't take ownership of the value when it isn't passed by 'ref'.
return FreeStrategy.NoFree;
}

private static IElementsMarshalling CreateElementsMarshalling(CustomTypeMarshallerData marshallerData, TypePositionInfo elementInfo, IMarshallingGenerator elementMarshaller, TypeSyntax unmanagedElementType, IElementsMarshallingCollectionSource collectionSource)
{
IElementsMarshalling elementsMarshalling;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ internal interface IElementsMarshallingCollectionSource

internal interface IElementsMarshalling
{
StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context);
}
Expand All @@ -45,7 +45,7 @@ public BlittableElementsMarshalling(TypeSyntax managedElementType, TypeSyntax un
_collectionSource = collectionSource;
}

public StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
{
// If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
// We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
Expand Down Expand Up @@ -73,7 +73,7 @@ public StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeC
Argument(destination)));
}

public StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
{
ExpressionSyntax source = CastToManagedIfNecessary(_collectionSource.GetUnmanagedValuesDestination(info, context));

Expand Down Expand Up @@ -175,7 +175,7 @@ public NonBlittableElementsMarshalling(
_collectionSource = collectionSource;
}

public StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
{
// If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
// We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
Expand Down Expand Up @@ -259,7 +259,7 @@ public StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCod
StubCodeContext.Stage.Unmarshal));
}

public StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
{
// Use ManagedSource and NativeDestination spans for by-value marshalling since we're just marshalling back the contents,
// not the array itself.
Expand Down Expand Up @@ -356,7 +356,9 @@ public StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, St
VariableDeclarator(
Identifier(nativeSpanIdentifier))
.WithInitializer(EqualsValueClause(
_collectionSource.GetUnmanagedValuesDestination(info, context)))))),
context.Direction == MarshalDirection.ManagedToUnmanaged
? _collectionSource.GetUnmanagedValuesDestination(info, context)
: _collectionSource.GetUnmanagedValuesSource(info, context)))))),
contentsCleanupStatements);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public IMarshallingGenerator Create(
return s_delegate;

case { MarshallingAttributeInfo: SafeHandleMarshallingInfo(_, bool isAbstract) }:
if (!context.AdditionalTemporaryStateLivesAcrossStages)
if (!context.AdditionalTemporaryStateLivesAcrossStages || context.Direction != MarshalDirection.ManagedToUnmanaged)
{
throw new MarshallingNotSupportedException(info, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,40 +372,36 @@ internal sealed class StatefulLinearCollectionMarshalling : ICustomTypeMarshalli
private readonly MarshallerShape _shape;
private readonly ExpressionSyntax _numElementsExpression;
private readonly IElementsMarshalling _elementsMarshalling;
private readonly bool _cleanupElements;

public StatefulLinearCollectionMarshalling(
ICustomTypeMarshallingStrategy innerMarshaller,
MarshallerShape shape,
ExpressionSyntax numElementsExpression,
IElementsMarshalling elementsMarshalling)
IElementsMarshalling elementsMarshalling,
bool cleanupElements)
{
_innerMarshaller = innerMarshaller;
_shape = shape;
_numElementsExpression = numElementsExpression;
_elementsMarshalling = elementsMarshalling;
_cleanupElements = cleanupElements;
}

public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
if (!_cleanupElements)
{
yield break;
}

StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context);

if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
{
yield return elementCleanup;
}

if (!_shape.HasFlag(MarshallerShape.Free))
yield break;

string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
// <marshaller>.Free();
yield return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(marshaller),
IdentifierName(ShapeMemberNames.Free)),
ArgumentList()));
}
public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);

Expand All @@ -419,9 +415,9 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i
yield return statement;
}

if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
{
yield return _elementsMarshalling.GenerateByValueOutMarshalStatement(info, context);
yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutMarshalStatement(info, context);
yield break;
}

Expand All @@ -437,9 +433,10 @@ public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo
{
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);

if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
{
yield return _elementsMarshalling.GenerateByValueOutUnmarshalStatement(info, context);
yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutUnmarshalStatement(info, context);
yield break;
}

if (!_shape.HasFlag(MarshallerShape.ToManaged))
Expand Down Expand Up @@ -469,4 +466,50 @@ public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo

public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true;
}

/// <summary>
/// Marshaller that enables calling the Free method on a stateful marshaller.
/// </summary>
internal sealed class StatefulFreeMarshalling : ICustomTypeMarshallingStrategy
{
private readonly ICustomTypeMarshallingStrategy _innerMarshaller;

public StatefulFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller)
{
_innerMarshaller = innerMarshaller;
}

public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);

public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
foreach (var statement in _innerMarshaller.GenerateCleanupStatements(info, context))
{
yield return statement;
}

string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
// <marshaller>.Free();
yield return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(marshaller),
IdentifierName(ShapeMemberNames.Free)),
ArgumentList()));
}
public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);

public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context);

public IEnumerable<StatementSyntax> GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
public IEnumerable<StatementSyntax> GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context);

public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context);

public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context);

public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
}
}
Loading