Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions src/Compilers/CSharp/Portable/CodeGen/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ internal sealed partial class CodeGenerator
// There are scenarios where rvalues need to be passed to ref/in parameters
// in such cases the values must be spilled into temps and retained for the entirety of
// the most encompassing expression.
// If a ref to the temp could escape, it is retained for the most encompassing block.
private ArrayBuilder<LocalDefinition> _expressionTemps;
private ArrayBuilder<LocalDefinition> _blockTemps;
private bool _tempRefsMightEscape;

// not 0 when in a protected region with a handler.
private int _tryNestingLevel;
Expand Down Expand Up @@ -504,35 +507,47 @@ private TextSpan EmitSequencePoint(SyntaxTree syntaxTree, TextSpan span)

private void AddExpressionTemp(LocalDefinition temp)
{
// in some cases like stack locals, there is no slot allocated.
if (temp == null)
if (_tempRefsMightEscape)
{
return;
AddBlockTemp(temp);
}
else
{
AddTemp(ref _expressionTemps, temp);
}
}

private void ReleaseExpressionTemps() => ReleaseTemps(_expressionTemps);

ArrayBuilder<LocalDefinition> exprTemps = _expressionTemps;
if (exprTemps == null)
private void AddBlockTemp(LocalDefinition temp) => AddTemp(ref _blockTemps, temp);

private void ReleaseBlockTemps() => ReleaseTemps(_blockTemps);

private static void AddTemp(ref ArrayBuilder<LocalDefinition> temps, LocalDefinition temp)
{
if (temp == null)
{
exprTemps = ArrayBuilder<LocalDefinition>.GetInstance();
_expressionTemps = exprTemps;
return;
}

Debug.Assert(!exprTemps.Contains(temp));
exprTemps.Add(temp);
temps ??= ArrayBuilder<LocalDefinition>.GetInstance();

Debug.Assert(!temps.Contains(temp));
temps.Add(temp);
}

private void ReleaseExpressionTemps()
private void ReleaseTemps(ArrayBuilder<LocalDefinition> temps)
{
if (_expressionTemps?.Count > 0)
if (temps?.Count > 0)
{
// release in reverse order to keep same temps on top of the temp stack if possible
for (int i = _expressionTemps.Count - 1; i >= 0; i--)
for (int i = temps.Count - 1; i >= 0; i--)
{
var temp = _expressionTemps[i];
var temp = temps[i];
FreeTemp(temp);
}

_expressionTemps.Clear();
temps.Clear();
}
}
}
Expand Down
158 changes: 158 additions & 0 deletions src/Compilers/CSharp/Portable/CodeGen/CodeGenerator_RefSafety.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis.CSharp.Symbols;

namespace Microsoft.CodeAnalysis.CSharp.CodeGen;

internal partial class CodeGenerator
{
private static bool MightEscapeTemporaryRefs(BoundCall node, bool used, AddressKind? receiverAddressKind)
{
return MightEscapeTemporaryRefs(
used: used,
returnType: node.Type,
returnRefKind: node.Method.RefKind,
receiverType: !node.Method.RequiresInstanceReceiver ? null : node.ReceiverOpt?.Type,
receiverScope: node.Method.TryGetThisParameter(out var thisParameter) ? thisParameter?.EffectiveScope : null,
receiverAddressKind: receiverAddressKind,
isReceiverReadOnly: node.Method.IsEffectivelyReadOnly,
parameters: node.Method.Parameters,
arguments: node.Arguments,
argsToParamsOpt: node.ArgsToParamsOpt,
expanded: node.Expanded);
}

private static bool MightEscapeTemporaryRefs(BoundObjectCreationExpression node, bool used)
{
return MightEscapeTemporaryRefs(
used: used,
returnType: node.Type,
returnRefKind: RefKind.None,
receiverType: null,
receiverScope: null,
receiverAddressKind: null,
isReceiverReadOnly: false,
parameters: node.Constructor.Parameters,
arguments: node.Arguments,
argsToParamsOpt: node.ArgsToParamsOpt,
expanded: node.Expanded);
}

private static bool MightEscapeTemporaryRefs(BoundFunctionPointerInvocation node, bool used)
{
FunctionPointerMethodSymbol method = node.FunctionPointer.Signature;
return MightEscapeTemporaryRefs(
used: used,
returnType: node.Type,
returnRefKind: method.RefKind,
receiverType: null,
receiverScope: null,
receiverAddressKind: null,
isReceiverReadOnly: false,
parameters: method.Parameters,
arguments: node.Arguments,
argsToParamsOpt: default,
expanded: false);
}

private static bool MightEscapeTemporaryRefs(
bool used,
TypeSymbol returnType,
RefKind returnRefKind,
TypeSymbol? receiverType,
ScopedKind? receiverScope,
AddressKind? receiverAddressKind,
bool isReceiverReadOnly,
ImmutableArray<ParameterSymbol> parameters,
ImmutableArray<BoundExpression> arguments,
ImmutableArray<int> argsToParamsOpt,
bool expanded)
{
Debug.Assert(receiverAddressKind is null || receiverType is not null);

int writableRefs = 0;
int readonlyRefs = 0;

if (used && (returnRefKind != RefKind.None || returnType.IsRefLikeOrAllowsRefLikeType()))
{
writableRefs++;
}

if (receiverType is not null)
{
receiverScope ??= ScopedKind.None;
if (receiverAddressKind is { } a && !IsAnyReadOnly(a) && receiverScope == ScopedKind.None)
{
writableRefs++;
}
else if (receiverType.IsRefLikeOrAllowsRefLikeType() && receiverScope != ScopedKind.ScopedValue)
{
if (isReceiverReadOnly || receiverType.IsReadOnly)
{
readonlyRefs++;
}
else
{
writableRefs++;
}
}
else if (receiverAddressKind != null && receiverScope == ScopedKind.None)
{
readonlyRefs++;
}
}

if (shouldReturnTrue(writableRefs, readonlyRefs))
{
return true;
}

for (var arg = 0; arg < arguments.Length; arg++)
{
var parameter = Binder.GetCorrespondingParameter(
arg,
parameters,
argsToParamsOpt,
expanded);

if (parameter is not null)
{
if (parameter.RefKind.IsWritableReference() && parameter.EffectiveScope == ScopedKind.None)
{
writableRefs++;
}
else if (parameter.Type.IsRefLikeOrAllowsRefLikeType() && parameter.EffectiveScope != ScopedKind.ScopedValue)
{
if (parameter.Type.IsReadOnly || !parameter.RefKind.IsWritableReference())
{
readonlyRefs++;
}
else
{
writableRefs++;
}
}
else if (parameter.RefKind != RefKind.None && parameter.EffectiveScope == ScopedKind.None)
{
readonlyRefs++;
}
}

if (shouldReturnTrue(writableRefs, readonlyRefs))
{
return true;
}
}

return false;

static bool shouldReturnTrue(int writableRefs, int readonlyRefs)
{
return writableRefs > 0 && (writableRefs + readonlyRefs) > 1;
}
}
}
Loading