diff --git a/src/DelegateDecompiler.Tests/NestedExpressionsTests.cs b/src/DelegateDecompiler.Tests/NestedExpressionsTests.cs index d19e5dd5..e786cea1 100644 --- a/src/DelegateDecompiler.Tests/NestedExpressionsTests.cs +++ b/src/DelegateDecompiler.Tests/NestedExpressionsTests.cs @@ -14,6 +14,21 @@ public class NestedExpressionsTests : DecompilerTestsBase int M1() => 0; static int M2() => 0; + readonly IQueryable fQref1 = Enumerable.Empty().AsQueryable(); + static IQueryable fQref2 = Enumerable.Empty().AsQueryable(); + [Decompile] + IQueryable pQref1 => Enumerable.Empty().AsQueryable(); + [Decompile] + static IQueryable pQref2 => Enumerable.Empty().AsQueryable(); + [Decompile] + IQueryable MQref1() => Enumerable.Empty().AsQueryable(); + [Decompile] + static IQueryable MQref2() => Enumerable.Empty().AsQueryable(); + [Decompile] + IQueryable ParamedMQref1(int floor) => Enumerable.Empty().AsQueryable().Where(x => x >= floor); + [Decompile] + static IQueryable ParamedMQref2(int floor) => Enumerable.Empty().AsQueryable().Where(x => x >= floor); + [Test] public void TestNestedExpression() { @@ -178,5 +193,107 @@ public void TestFuncWithStaticMethodClosure() ints => ints.SingleOrDefault(i => i == M2()) ); } + + [Test] + public void TestQueryableBoundAsVariable() + { + IQueryable query = Enumerable.Empty().AsQueryable().Where(i => i >= 0); + Test, IQueryable>>( + ints => query, + ints => query + ); + } + + [Test] + public void TestQueryableRefFromField() + { + Test, IQueryable>>( + ints => fQref1.Where(i => i >= 0), + ints => fQref1.Where(i => i >= 0) + ); + } + + [Test] + public void TestQueryableRefFromStaticField() + { + Test, IQueryable>>( + ints => fQref2.Where(i => i >= 0), + ints => fQref2.Where(i => i >= 0) + ); + } + + [Test] + public void TestQueryableRefFromProperty() + { + Test, IQueryable>>( + ints => Enumerable.Empty().AsQueryable().Where(i => i >= 0), + ints => pQref1.Where(i => i >= 0) + ); + } + + [Test] + public void TestQueryableRefFromStaticProperty() + { + Test, IQueryable>>( + ints => Enumerable.Empty().AsQueryable().Where(i => i >= 0), + ints => pQref2.Where(i => i >= 0) + ); + } + + [Test] + public void TestQueryableRefFromMethod() + { + Test, IQueryable>>( + ints => Enumerable.Empty().AsQueryable().Where(i => i >= 0), + ints => MQref1().Where(i => i >= 0) + ); + } + + [Test] + public void TestQueryableRefFromStaticMethod() + { + Test, IQueryable>>( + ints => Enumerable.Empty().AsQueryable().Where(i => i >= 0), + ints => MQref2().Where(i => i >= 0) + ); + } + + [Test] + public void TestQueryableRefFromMethodWithBoundParameters() + { + var floor = 10; + Test, IQueryable>>( + ints => Enumerable.Empty().AsQueryable().Where(x => x >= floor).Where(i => i >= 0), + ints => ParamedMQref1(floor).Where(i => i >= 0) + ); + } + + [Test] + public void TestQueryableRefFromStaticMethoWithBoundParameters() + { + var floor = 10; + Test, IQueryable>>( + ints => Enumerable.Empty().AsQueryable().Where(x => x >= floor).Where(i => i >= 0), + ints => ParamedMQref1(floor).Where(i => i >= 0) + ); + } + + [Test] + public void TestQueryableRefFromMethodWithUnboundParameters() + { + Test, int, IQueryable>>( + (ints, floor) => Enumerable.Empty().AsQueryable().Where(x => x >= floor).Where(i => i >= 0), + (ints, floor) => ParamedMQref1(floor).Where(i => i >= 0) + ); + } + + [Test] + public void TestQueryableRefFromStaticMethoWithUnboundParameters() + { + Test, int, IQueryable>>( + (ints, floor) => Enumerable.Empty().AsQueryable().Where(x => x >= floor).Where(i => i >= 0), + (ints, floor) => ParamedMQref1(floor).Where(i => i >= 0) + ); + } } } diff --git a/src/DelegateDecompiler/DecompileExtensions.cs b/src/DelegateDecompiler/DecompileExtensions.cs index 6b688d33..c7dacc89 100644 --- a/src/DelegateDecompiler/DecompileExtensions.cs +++ b/src/DelegateDecompiler/DecompileExtensions.cs @@ -35,7 +35,9 @@ public static LambdaExpression Decompile(this MethodInfo method) public static LambdaExpression Decompile(this MethodInfo method, Type declaringType) { - return Cache.GetOrAdd(Tuple.Create(declaringType, method), DecompileDelegate).Value; + return (LambdaExpression)Cache + .GetOrAdd(Tuple.Create(declaringType, method), DecompileDelegate) + .Value.Decompile(); } public static IQueryable Decompile(this IQueryable self)