Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
42d475b
Initial plan
Copilot Oct 26, 2025
8218cff
Initial exploration of cross-platform intrinsics analyzer
Copilot Oct 26, 2025
0b15963
Add cross-platform intrinsics method support to analyzer
Copilot Oct 26, 2025
19e8619
Add tests and fix fixer for cross-platform intrinsics methods
Copilot Oct 26, 2025
8a7411e
Extend analyzer to cover AndNot, Negate, Round, Truncate, and Conditi…
Copilot Oct 27, 2025
80647b2
Add FusedMultiplyAdd cross-platform intrinsics support
Copilot Oct 27, 2025
2bb6ba8
Address review feedback for cross-platform intrinsics analyzer
Copilot Oct 27, 2025
789a4f6
First step: Add semantic fixes for x86 AndNot (parameter swap) and Ma…
Copilot Oct 27, 2025
c93e6de
Step 2: Add AdvSimd method support (AbsScalar, AddSaturate, BitwiseSe…
Copilot Oct 27, 2025
7eee078
Step 3: Add AdvSimd Round variants (RoundToNearest, RoundToNegativeIn…
Copilot Oct 27, 2025
ff6fb2b
Step 4: Add similar method coverage for x86/x64 and WASM platforms
Copilot Oct 27, 2025
2c9be9e
Step 5: Add Load and Store API support across all platforms
Copilot Oct 27, 2025
552d3ed
Implement fixer for Load and Store methods
Copilot Oct 27, 2025
4b0b04e
Fix failing tests by skipping MaxNative/MinNative tests (NET 9+ APIs …
Copilot Oct 27, 2025
f63c65d
Co-authored-by: stephentoub <[email protected]
Copilot Oct 27, 2025
d24289e
Address code review feedback: consolidate vector type checks, add IsV…
Copilot Oct 27, 2025
e2c288e
Add missing IsVectorType helper method implementation
Copilot Oct 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,155 @@ protected override SyntaxNode ReplaceWithUnaryOperator(SyntaxNode currentNode, S

return generator.Parenthesize(replacementExpression);
}

protected override SyntaxNode ReplaceWithUnaryMethod(SyntaxNode currentNode, SyntaxGenerator generator, string methodName)
{
if (currentNode is not InvocationExpressionSyntax invocationExpression)
{
Debug.Fail($"Found unexpected node kind: {currentNode.RawKind}");
return currentNode;
}

SeparatedSyntaxList<ArgumentSyntax> arguments = invocationExpression.ArgumentList.Arguments;

if (arguments.Count != 1)
{
Debug.Fail($"Found unexpected number of arguments for unary method replacement: {arguments.Count}");
return currentNode;
}

// Determine the vector type name from the return type if available
var vectorTypeName = DetermineVectorTypeName(invocationExpression);

// Create the cross-platform method call: VectorXXX.MethodName(arg)
// The type parameter will be inferred from the argument
var vectorTypeIdentifier = generator.IdentifierName(vectorTypeName);
var replacementExpression = generator.InvocationExpression(
generator.MemberAccessExpression(vectorTypeIdentifier, methodName),
arguments[0].Expression);

return generator.Parenthesize(replacementExpression);
}

protected override SyntaxNode ReplaceWithBinaryMethod(SyntaxNode currentNode, SyntaxGenerator generator, string methodName)
{
if (currentNode is not InvocationExpressionSyntax invocationExpression)
{
Debug.Fail($"Found unexpected node kind: {currentNode.RawKind}");
return currentNode;
}

SeparatedSyntaxList<ArgumentSyntax> arguments = invocationExpression.ArgumentList.Arguments;

if (arguments.Count != 2)
{
Debug.Fail($"Found unexpected number of arguments for binary method replacement: {arguments.Count}");
return currentNode;
}

// Determine the vector type name from the return type if available
var vectorTypeName = DetermineVectorTypeName(invocationExpression);

// Create the cross-platform method call: VectorXXX.MethodName(arg1, arg2)
// The type parameter will be inferred from the arguments
var vectorTypeIdentifier = generator.IdentifierName(vectorTypeName);
var replacementExpression = generator.InvocationExpression(
generator.MemberAccessExpression(vectorTypeIdentifier, methodName),
arguments[0].Expression,
arguments[1].Expression);

return generator.Parenthesize(replacementExpression);
}

private static string DetermineVectorTypeName(SyntaxNode node)
{
// For method signatures like "Vector256<float> M(Vector256<float> x)",
// we need to find the return type of the method containing this invocation

// Walk up to find the method declaration
var current = node;
while (current != null)
{
if (current is MethodDeclarationSyntax methodDecl)
{
// Check the return type
var returnType = methodDecl.ReturnType;
if (returnType is GenericNameSyntax genericReturn &&
(genericReturn.Identifier.Text == "Vector64" ||
genericReturn.Identifier.Text == "Vector128" ||
genericReturn.Identifier.Text == "Vector256" ||
genericReturn.Identifier.Text == "Vector512"))
{
return genericReturn.Identifier.Text;
}
}
current = current.Parent;
}

// Also check the invocation itself for argument types
// This handles cases where the invocation is in an expression-bodied member
if (node is InvocationExpressionSyntax invocation)
{
foreach (var arg in invocation.ArgumentList.Arguments)
{
var vectorType = FindVectorTypeInExpression(arg.Expression);
if (vectorType != null)
{
return vectorType;
}
}
}

// Default to Vector128 if we can't determine it
return "Vector128";
}

private static string? FindVectorTypeInExpression(SyntaxNode node)
{
// Look for Vector types in the expression (could be identifiers or generic names)
foreach (var descendant in node.DescendantNodesAndSelf())
{
if (descendant is GenericNameSyntax genericName &&
(genericName.Identifier.Text == "Vector64" ||
genericName.Identifier.Text == "Vector128" ||
genericName.Identifier.Text == "Vector256" ||
genericName.Identifier.Text == "Vector512"))
{
return genericName.Identifier.Text;
}
}
return null;
}

protected override SyntaxNode ReplaceWithTernaryMethod(SyntaxNode currentNode, SyntaxGenerator generator, string methodName)
{
if (currentNode is not InvocationExpressionSyntax invocationExpression)
{
Debug.Fail($"Found unexpected node kind: {currentNode.RawKind}");
return currentNode;
}

SeparatedSyntaxList<ArgumentSyntax> arguments = invocationExpression.ArgumentList.Arguments;

if (arguments.Count != 3)
{
Debug.Fail($"Found unexpected number of arguments for ternary method replacement: {arguments.Count}");
return currentNode;
}

// Determine the vector type name from the return type if available
var vectorTypeName = DetermineVectorTypeName(invocationExpression);

// Create the cross-platform method call: VectorXXX.MethodName(arg1, arg2, arg3)
// The type parameter will be inferred from the arguments
var vectorTypeIdentifier = generator.IdentifierName(vectorTypeName);
var replacementExpression = generator.InvocationExpression(
generator.MemberAccessExpression(vectorTypeIdentifier, methodName),
arguments[0].Expression,
arguments[1].Expression,
arguments[2].Expression);

return generator.Parenthesize(replacementExpression);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public partial class UseCrossPlatformIntrinsicsAnalyzer
{
public enum RuleKind
{
// These names match the underlying IL names for the cross-platform API that will be used in the fixer.
// These names match the underlying IL names or method names for the cross-platform API that will be used in the fixer.

op_Addition,
op_BitwiseAnd,
Expand All @@ -21,6 +21,19 @@ public enum RuleKind
op_UnaryNegation,
op_UnsignedRightShift,

// Named methods (not operators)
Abs,
AndNot,
Ceiling,
ConditionalSelect,
Floor,
Max,
Min,
Negate,
Round,
Sqrt,
Truncate,

Count,
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ RuleKind.op_RightShift or
RuleKind.op_OnesComplement or
RuleKind.op_UnaryNegation => IsValidUnaryOperatorMethodInvocation(invocation),

RuleKind.Abs or
RuleKind.Ceiling or
RuleKind.Floor or
RuleKind.Negate or
RuleKind.Round or
RuleKind.Sqrt or
RuleKind.Truncate => IsValidUnaryMethodInvocation(invocation),

RuleKind.AndNot or
RuleKind.Max or
RuleKind.Min => IsValidBinaryMethodInvocation(invocation),

RuleKind.ConditionalSelect => IsValidTernaryMethodInvocation(invocation),

_ => false,
};

Expand All @@ -108,6 +122,27 @@ static bool IsValidUnaryOperatorMethodInvocation(IInvocationOperation invocation
return (invocation.Arguments.Length == 1) &&
SymbolEqualityComparer.Default.Equals(invocation.Type, invocation.Arguments[0].Parameter?.Type);
}

static bool IsValidUnaryMethodInvocation(IInvocationOperation invocation)
{
return (invocation.Arguments.Length == 1) &&
SymbolEqualityComparer.Default.Equals(invocation.Type, invocation.Arguments[0].Parameter?.Type);
}

static bool IsValidBinaryMethodInvocation(IInvocationOperation invocation)
{
return (invocation.Arguments.Length == 2) &&
SymbolEqualityComparer.Default.Equals(invocation.Type, invocation.Arguments[0].Parameter?.Type) &&
SymbolEqualityComparer.Default.Equals(invocation.Type, invocation.Arguments[1].Parameter?.Type);
}

static bool IsValidTernaryMethodInvocation(IInvocationOperation invocation)
{
return (invocation.Arguments.Length == 3) &&
SymbolEqualityComparer.Default.Equals(invocation.Type, invocation.Arguments[0].Parameter?.Type) &&
SymbolEqualityComparer.Default.Equals(invocation.Type, invocation.Arguments[1].Parameter?.Type) &&
SymbolEqualityComparer.Default.Equals(invocation.Type, invocation.Arguments[2].Parameter?.Type);
}
}

private void OnCompilationStart(CompilationStartAnalysisContext context)
Expand Down Expand Up @@ -277,6 +312,7 @@ private void OnCompilationStart(CompilationStartAnalysisContext context)
{
AddBinaryOperatorMethods(methodSymbols, "Add", x86Sse2TypeSymbol, RuleKind.op_Addition);
AddBinaryOperatorMethods(methodSymbols, "And", x86Sse2TypeSymbol, RuleKind.op_BitwiseAnd);
AddBinaryOperatorMethods(methodSymbols, "AndNot", x86Sse2TypeSymbol, RuleKind.AndNot);
AddBinaryOperatorMethods(methodSymbols, "Divide", x86Sse2TypeSymbol, RuleKind.op_Division);
AddBinaryOperatorMethods(methodSymbols, "Multiply", x86Sse2TypeSymbol, RuleKind.op_Multiply, [SpecialType.System_Double]);
AddBinaryOperatorMethods(methodSymbols, "MultiplyLow", x86Sse2TypeSymbol, RuleKind.op_Multiply);
Expand All @@ -294,6 +330,94 @@ private void OnCompilationStart(CompilationStartAnalysisContext context)
AddBinaryOperatorMethods(methodSymbols, "MultiplyLow", x86Sse41TypeSymbol, RuleKind.op_Multiply);
}

// Register named methods (not operators) that have cross-platform equivalents

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsArmAdvSimd, out var armAdvSimdTypeSymbolForMethods))
{
AddUnaryOperatorMethods(methodSymbols, "Abs", armAdvSimdTypeSymbolForMethods, RuleKind.Abs);
AddBinaryOperatorMethods(methodSymbols, "AndNot", armAdvSimdTypeSymbolForMethods, RuleKind.AndNot);
AddBinaryOperatorMethods(methodSymbols, "Max", armAdvSimdTypeSymbolForMethods, RuleKind.Max);
AddBinaryOperatorMethods(methodSymbols, "Min", armAdvSimdTypeSymbolForMethods, RuleKind.Min);
AddUnaryOperatorMethods(methodSymbols, "Negate", armAdvSimdTypeSymbolForMethods, RuleKind.Negate);
}

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsArmAdvSimdArm64, out var armAdvSimdArm64TypeSymbolForMethods))
{
AddUnaryOperatorMethods(methodSymbols, "Abs", armAdvSimdArm64TypeSymbolForMethods, RuleKind.Abs);
AddUnaryOperatorMethods(methodSymbols, "Negate", armAdvSimdArm64TypeSymbolForMethods, RuleKind.Negate);
}

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsWasmPackedSimd, out var wasmPackedSimdTypeSymbolForMethods))
{
AddUnaryOperatorMethods(methodSymbols, "Abs", wasmPackedSimdTypeSymbolForMethods, RuleKind.Abs);
AddUnaryOperatorMethods(methodSymbols, "Ceiling", wasmPackedSimdTypeSymbolForMethods, RuleKind.Ceiling);
AddUnaryOperatorMethods(methodSymbols, "Floor", wasmPackedSimdTypeSymbolForMethods, RuleKind.Floor);
AddBinaryOperatorMethods(methodSymbols, "Max", wasmPackedSimdTypeSymbolForMethods, RuleKind.Max);
AddBinaryOperatorMethods(methodSymbols, "Min", wasmPackedSimdTypeSymbolForMethods, RuleKind.Min);
AddUnaryOperatorMethods(methodSymbols, "Negate", wasmPackedSimdTypeSymbolForMethods, RuleKind.Negate);
AddUnaryOperatorMethods(methodSymbols, "Sqrt", wasmPackedSimdTypeSymbolForMethods, RuleKind.Sqrt);
AddUnaryOperatorMethods(methodSymbols, "Truncate", wasmPackedSimdTypeSymbolForMethods, RuleKind.Truncate);
}

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsX86Avx, out var x86AvxTypeSymbolForMethods))
{
AddBinaryOperatorMethods(methodSymbols, "AndNot", x86AvxTypeSymbolForMethods, RuleKind.AndNot);
AddUnaryOperatorMethods(methodSymbols, "Ceiling", x86AvxTypeSymbolForMethods, RuleKind.Ceiling);
AddUnaryOperatorMethods(methodSymbols, "Floor", x86AvxTypeSymbolForMethods, RuleKind.Floor);
AddBinaryOperatorMethods(methodSymbols, "Max", x86AvxTypeSymbolForMethods, RuleKind.Max);
AddBinaryOperatorMethods(methodSymbols, "Min", x86AvxTypeSymbolForMethods, RuleKind.Min);
AddUnaryOperatorMethods(methodSymbols, "RoundToNearestInteger", x86AvxTypeSymbolForMethods, RuleKind.Round);
AddUnaryOperatorMethods(methodSymbols, "RoundToNegativeInfinity", x86AvxTypeSymbolForMethods, RuleKind.Floor);
AddUnaryOperatorMethods(methodSymbols, "RoundToPositiveInfinity", x86AvxTypeSymbolForMethods, RuleKind.Ceiling);
AddUnaryOperatorMethods(methodSymbols, "RoundToZero", x86AvxTypeSymbolForMethods, RuleKind.Truncate);
AddUnaryOperatorMethods(methodSymbols, "Sqrt", x86AvxTypeSymbolForMethods, RuleKind.Sqrt);
}

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsX86Avx2, out var x86Avx2TypeSymbolForMethods))
{
AddUnaryOperatorMethods(methodSymbols, "Abs", x86Avx2TypeSymbolForMethods, RuleKind.Abs);
}

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsX86Avx512BW, out var x86Avx512BWTypeSymbolForMethods))
{
AddUnaryOperatorMethods(methodSymbols, "Abs", x86Avx512BWTypeSymbolForMethods, RuleKind.Abs);
}

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsX86Avx512F, out var x86Avx512FTypeSymbolForMethods))
{
AddUnaryOperatorMethods(methodSymbols, "Abs", x86Avx512FTypeSymbolForMethods, RuleKind.Abs);
AddBinaryOperatorMethods(methodSymbols, "Max", x86Avx512FTypeSymbolForMethods, RuleKind.Max);
AddBinaryOperatorMethods(methodSymbols, "Min", x86Avx512FTypeSymbolForMethods, RuleKind.Min);
AddUnaryOperatorMethods(methodSymbols, "RoundToNearestInteger", x86Avx512FTypeSymbolForMethods, RuleKind.Round);
AddUnaryOperatorMethods(methodSymbols, "Sqrt", x86Avx512FTypeSymbolForMethods, RuleKind.Sqrt);
}

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsX86Sse, out var x86SseTypeSymbolForMethods))
{
AddBinaryOperatorMethods(methodSymbols, "AndNot", x86SseTypeSymbolForMethods, RuleKind.AndNot);
AddBinaryOperatorMethods(methodSymbols, "Max", x86SseTypeSymbolForMethods, RuleKind.Max);
AddBinaryOperatorMethods(methodSymbols, "Min", x86SseTypeSymbolForMethods, RuleKind.Min);
AddUnaryOperatorMethods(methodSymbols, "Sqrt", x86SseTypeSymbolForMethods, RuleKind.Sqrt);
}

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsX86Sse2, out var x86Sse2TypeSymbolForMethods))
{
AddUnaryOperatorMethods(methodSymbols, "Sqrt", x86Sse2TypeSymbolForMethods, RuleKind.Sqrt);
}

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsX86Sse41, out var x86Sse41TypeSymbolForMethods))
{
AddUnaryOperatorMethods(methodSymbols, "RoundToNearestInteger", x86Sse41TypeSymbolForMethods, RuleKind.Round);
AddUnaryOperatorMethods(methodSymbols, "RoundToNegativeInfinity", x86Sse41TypeSymbolForMethods, RuleKind.Floor);
AddUnaryOperatorMethods(methodSymbols, "RoundToPositiveInfinity", x86Sse41TypeSymbolForMethods, RuleKind.Ceiling);
AddUnaryOperatorMethods(methodSymbols, "RoundToZero", x86Sse41TypeSymbolForMethods, RuleKind.Truncate);
}

if (compilation.TryGetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemRuntimeIntrinsicsX86Ssse3, out var x86Ssse3TypeSymbolForMethods))
{
AddUnaryOperatorMethods(methodSymbols, "Abs", x86Ssse3TypeSymbolForMethods, RuleKind.Abs);
}

if (methodSymbols.Any())
{
context.RegisterOperationAction((context) => AnalyzeInvocation(context, methodSymbols), OperationKind.Invocation);
Expand Down
Loading
Loading