Skip to content

Commit 08ee676

Browse files
authored
Visit arguments in QueryableMethodNormalizingExpressionVisitor after converting List.Contains (#32219)
Fixes #32215 Fixes #32218
1 parent 338b76a commit 08ee676

File tree

7 files changed

+106
-10
lines changed

7 files changed

+106
-10
lines changed

src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
288288
// Server), we need to fall back to the previous IN translation.
289289
if (method.IsGenericMethod
290290
&& method.GetGenericMethodDefinition() == QueryableMethods.Contains
291-
&& methodCallExpression.Arguments[0] is ParameterQueryRootExpression parameterSource
291+
&& UnwrapAsQueryable(methodCallExpression.Arguments[0]) is ParameterQueryRootExpression parameterSource
292292
&& TranslateExpression(methodCallExpression.Arguments[1]) is SqlExpression item
293293
&& _sqlTranslator.Visit(parameterSource.ParameterExpression) is SqlParameterExpression sqlParameterExpression)
294294
{
@@ -300,6 +300,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
300300
.UpdateResultCardinality(ResultCardinality.Single);
301301
return shapedQueryExpression;
302302
}
303+
304+
static Expression UnwrapAsQueryable(Expression expression)
305+
=> expression is MethodCallExpression { Method: { IsGenericMethod: true } method } methodCall
306+
&& method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable
307+
? methodCall.Arguments[0]
308+
: expression;
303309
}
304310

305311
return translated;

src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -435,12 +435,13 @@ private Expression TryConvertListContainsToQueryableContains(MethodCallExpressio
435435

436436
var sourceType = methodCallExpression.Method.DeclaringType!.GetGenericArguments()[0];
437437

438-
return Expression.Call(
439-
QueryableMethods.Contains.MakeGenericMethod(sourceType),
438+
return VisitMethodCall(
440439
Expression.Call(
441-
QueryableMethods.AsQueryable.MakeGenericMethod(sourceType),
442-
methodCallExpression.Object!),
443-
methodCallExpression.Arguments[0]);
440+
QueryableMethods.Contains.MakeGenericMethod(sourceType),
441+
Expression.Call(
442+
QueryableMethods.AsQueryable.MakeGenericMethod(sourceType),
443+
methodCallExpression.Object!),
444+
methodCallExpression.Arguments[0]));
444445
}
445446

446447
private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type queryableType)

test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ public virtual Task Project_primitive_collections_element(bool async)
807807
},
808808
assertOrder: true);
809809

810-
[ConditionalTheory] // #32208
810+
[ConditionalTheory] // #32208, #32215
811811
[MemberData(nameof(IsAsyncData))]
812812
public virtual Task Nested_contains_with_Lists_and_no_inferred_type_mapping(bool async)
813813
{
@@ -821,6 +821,20 @@ public virtual Task Nested_contains_with_Lists_and_no_inferred_type_mapping(bool
821821
ss => ss.Set<PrimitiveCollectionsEntity>().Where(e => strings.Contains(ints.Contains(e.Int) ? "one" : "two")));
822822
}
823823

824+
[ConditionalTheory] // #32208, #32215
825+
[MemberData(nameof(IsAsyncData))]
826+
public virtual Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
827+
{
828+
var ints = new[] { 1, 2, 3 };
829+
var strings = new[] { "one", "two", "three" };
830+
831+
// Note that in this query, the outer Contains really has no type mapping, neither for its source (collection parameter), nor
832+
// for its item (the conditional expression returns constants). The default type mapping must be applied.
833+
return AssertQuery(
834+
async,
835+
ss => ss.Set<PrimitiveCollectionsEntity>().Where(e => strings.Contains(ints.Contains(e.Int) ? "one" : "two")));
836+
}
837+
824838
public abstract class PrimitiveCollectionsQueryFixtureBase : SharedStoreFixtureBase<PrimitiveCollectionsContext>, IQueryFixtureBase
825839
{
826840
private PrimitiveArrayData? _expectedData;

test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQueryOldSqlServerTest.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,21 @@ END IN (N'one', N'two', N'three')
625625
""");
626626
}
627627

628+
public override async Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
629+
{
630+
await base.Nested_contains_with_arrays_and_no_inferred_type_mapping(async);
631+
632+
AssertSql(
633+
"""
634+
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
635+
FROM [PrimitiveCollectionsEntity] AS [p]
636+
WHERE CASE
637+
WHEN [p].[Int] IN (1, 2, 3) THEN N'one'
638+
ELSE N'two'
639+
END IN (N'one', N'two', N'three')
640+
""");
641+
}
642+
628643
[ConditionalFact]
629644
public virtual void Check_all_tests_overridden()
630645
=> TestHelpers.AssertAllMethodsOverridden(GetType());

test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1233,12 +1233,40 @@ public override async Task Nested_contains_with_Lists_and_no_inferred_type_mappi
12331233

12341234
AssertSql(
12351235
"""
1236+
@__ints_1='[1,2,3]' (Size = 4000)
12361237
@__strings_0='["one","two","three"]' (Size = 4000)
12371238
12381239
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
12391240
FROM [PrimitiveCollectionsEntity] AS [p]
12401241
WHERE CASE
1241-
WHEN [p].[Int] IN (1, 2, 3) THEN N'one'
1242+
WHEN [p].[Int] IN (
1243+
SELECT [i].[value]
1244+
FROM OPENJSON(@__ints_1) WITH ([value] int '$') AS [i]
1245+
) THEN N'one'
1246+
ELSE N'two'
1247+
END IN (
1248+
SELECT [s].[value]
1249+
FROM OPENJSON(@__strings_0) WITH ([value] nvarchar(max) '$') AS [s]
1250+
)
1251+
""");
1252+
}
1253+
1254+
public override async Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
1255+
{
1256+
await base.Nested_contains_with_arrays_and_no_inferred_type_mapping(async);
1257+
1258+
AssertSql(
1259+
"""
1260+
@__ints_1='[1,2,3]' (Size = 4000)
1261+
@__strings_0='["one","two","three"]' (Size = 4000)
1262+
1263+
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
1264+
FROM [PrimitiveCollectionsEntity] AS [p]
1265+
WHERE CASE
1266+
WHEN [p].[Int] IN (
1267+
SELECT [i].[value]
1268+
FROM OPENJSON(@__ints_1) WITH ([value] int '$') AS [i]
1269+
) THEN N'one'
12421270
ELSE N'two'
12431271
END IN (
12441272
SELECT [s].[value]

test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3982,13 +3982,17 @@ public virtual async Task Nested_contains_with_enum()
39823982

39833983
AssertSql(
39843984
"""
3985+
@__todoTypes_1='[0]' (Size = 4000)
39853986
@__key_2='5f221fb9-66f4-442a-92c9-d97ed5989cc7'
39863987
@__keys_0='["0a47bcb7-a1cb-4345-8944-c58f82d6aac7","5f221fb9-66f4-442a-92c9-d97ed5989cc7"]' (Size = 4000)
39873988
39883989
SELECT [t].[Id], [t].[Type]
39893990
FROM [Todos] AS [t]
39903991
WHERE CASE
3991-
WHEN [t].[Type] = 0 THEN @__key_2
3992+
WHEN [t].[Type] IN (
3993+
SELECT [t0].[value]
3994+
FROM OPENJSON(@__todoTypes_1) WITH ([value] int '$') AS [t0]
3995+
) THEN @__key_2
39923996
ELSE @__key_2
39933997
END IN (
39943998
SELECT [k].[value]

test/EFCore.Sqlite.FunctionalTests/Query/PrimitiveCollectionsQuerySqliteTest.cs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1115,12 +1115,16 @@ public override async Task Nested_contains_with_Lists_and_no_inferred_type_mappi
11151115

11161116
AssertSql(
11171117
"""
1118+
@__ints_1='[1,2,3]' (Size = 7)
11181119
@__strings_0='["one","two","three"]' (Size = 21)
11191120
11201121
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
11211122
FROM "PrimitiveCollectionsEntity" AS "p"
11221123
WHERE CASE
1123-
WHEN "p"."Int" IN (1, 2, 3) THEN 'one'
1124+
WHEN "p"."Int" IN (
1125+
SELECT "i"."value"
1126+
FROM json_each(@__ints_1) AS "i"
1127+
) THEN 'one'
11241128
ELSE 'two'
11251129
END IN (
11261130
SELECT "s"."value"
@@ -1129,6 +1133,30 @@ FROM json_each(@__strings_0) AS "s"
11291133
""");
11301134
}
11311135

1136+
public override async Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
1137+
{
1138+
await base.Nested_contains_with_arrays_and_no_inferred_type_mapping(async);
1139+
1140+
AssertSql(
1141+
"""
1142+
@__ints_1='[1,2,3]' (Size = 7)
1143+
@__strings_0='["one","two","three"]' (Size = 21)
1144+
1145+
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
1146+
FROM "PrimitiveCollectionsEntity" AS "p"
1147+
WHERE CASE
1148+
WHEN "p"."Int" IN (
1149+
SELECT "i"."value"
1150+
FROM json_each(@__ints_1) AS "i"
1151+
) THEN 'one'
1152+
ELSE 'two'
1153+
END IN (
1154+
SELECT "s"."value"
1155+
FROM json_each(@__strings_0) AS "s"
1156+
)
1157+
""");
1158+
}
1159+
11321160
[ConditionalFact]
11331161
public virtual void Check_all_tests_overridden()
11341162
=> TestHelpers.AssertAllMethodsOverridden(GetType());

0 commit comments

Comments
 (0)