Skip to content

Commit facca73

Browse files
committed
Fix Contains within SQL Server aggregate functions
Fixes #32374
1 parent 17fa62f commit facca73

File tree

10 files changed

+603
-7
lines changed

10 files changed

+603
-7
lines changed

src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1423,7 +1423,14 @@ private StructuralTypeReferenceExpression BindComplexProperty(
14231423
}
14241424
}
14251425

1426-
private bool TryTranslateAggregateMethodCall(
1426+
/// <summary>
1427+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
1428+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
1429+
/// any release. You should only use it directly in your code with extreme caution and knowing that
1430+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
1431+
/// </summary>
1432+
[EntityFrameworkInternal]
1433+
protected virtual bool TryTranslateAggregateMethodCall(
14271434
MethodCallExpression methodCallExpression,
14281435
[NotNullWhen(true)] out SqlExpression? translation)
14291436
{

src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,16 @@ public override bool IsBuffering
3939
=> base.IsBuffering
4040
|| (QuerySplittingBehavior == EntityFrameworkCore.QuerySplittingBehavior.SplitQuery
4141
&& !_multipleActiveResultSetsEnabled);
42+
43+
/// <summary>
44+
/// Tracks whether translation is currently within the argument of an aggregate method (e.g. MAX, COUNT); SQL Server does not
45+
/// allow subqueries and aggregates in that context.
46+
/// </summary>
47+
/// <remarks>
48+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
49+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
50+
/// any release. You should only use it directly in your code with extreme caution and knowing that
51+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
52+
/// </remarks>
53+
public virtual bool InAggregateFunction { get; set; }
4254
}

src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
1818
/// </summary>
1919
public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQueryableMethodTranslatingExpressionVisitor
2020
{
21-
private readonly QueryCompilationContext _queryCompilationContext;
21+
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
2222
private readonly IRelationalTypeMappingSource _typeMappingSource;
2323
private readonly ISqlExpressionFactory _sqlExpressionFactory;
2424
private readonly int _sqlServerCompatibilityLevel;
@@ -34,7 +34,7 @@ public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQu
3434
public SqlServerQueryableMethodTranslatingExpressionVisitor(
3535
QueryableMethodTranslatingExpressionVisitorDependencies dependencies,
3636
RelationalQueryableMethodTranslatingExpressionVisitorDependencies relationalDependencies,
37-
QueryCompilationContext queryCompilationContext,
37+
SqlServerQueryCompilationContext queryCompilationContext,
3838
ISqlServerSingletonOptions sqlServerSingletonOptions)
3939
: base(dependencies, relationalDependencies, queryCompilationContext)
4040
{
@@ -121,6 +121,103 @@ protected override Expression VisitExtension(Expression extensionExpression)
121121
return base.VisitExtension(extensionExpression);
122122
}
123123

124+
#region Aggregate functions
125+
126+
// We override these for SQL Server to add tracking whether we're inside an aggregate function context, since SQL Server doesn't
127+
// support subqueries (or aggregates) within them.
128+
129+
/// <summary>
130+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
131+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
132+
/// any release. You should only use it directly in your code with extreme caution and knowing that
133+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
134+
/// </summary>
135+
protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
136+
{
137+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
138+
_queryCompilationContext.InAggregateFunction = true;
139+
var result = base.TranslateAverage(source, selector, resultType);
140+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
141+
return result;
142+
}
143+
144+
/// <summary>
145+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
146+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
147+
/// any release. You should only use it directly in your code with extreme caution and knowing that
148+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
149+
/// </summary>
150+
protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
151+
{
152+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
153+
_queryCompilationContext.InAggregateFunction = true;
154+
var result = base.TranslateSum(source, selector, resultType);
155+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
156+
return result;
157+
}
158+
159+
/// <summary>
160+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
161+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
162+
/// any release. You should only use it directly in your code with extreme caution and knowing that
163+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
164+
/// </summary>
165+
protected override ShapedQueryExpression? TranslateCount(ShapedQueryExpression source, LambdaExpression? predicate)
166+
{
167+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
168+
_queryCompilationContext.InAggregateFunction = true;
169+
var result = base.TranslateCount(source, predicate);
170+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
171+
return result;
172+
}
173+
174+
/// <summary>
175+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
176+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
177+
/// any release. You should only use it directly in your code with extreme caution and knowing that
178+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
179+
/// </summary>
180+
protected override ShapedQueryExpression? TranslateLongCount(ShapedQueryExpression source, LambdaExpression? predicate)
181+
{
182+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
183+
_queryCompilationContext.InAggregateFunction = true;
184+
var result = base.TranslateLongCount(source, predicate);
185+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
186+
return result;
187+
}
188+
189+
/// <summary>
190+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
191+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
192+
/// any release. You should only use it directly in your code with extreme caution and knowing that
193+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
194+
/// </summary>
195+
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
196+
{
197+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
198+
_queryCompilationContext.InAggregateFunction = true;
199+
var result = base.TranslateMax(source, selector, resultType);
200+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
201+
return result;
202+
}
203+
204+
/// <summary>
205+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
206+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
207+
/// any release. You should only use it directly in your code with extreme caution and knowing that
208+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
209+
/// </summary>
210+
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
211+
{
212+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
213+
_queryCompilationContext.InAggregateFunction = true;
214+
var result = base.TranslateMin(source, selector, resultType);
215+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
216+
return result;
217+
}
218+
219+
#endregion Aggregate functions
220+
124221
/// <summary>
125222
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
126223
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
@@ -315,6 +412,47 @@ static IEnumerable<INavigation> GetAllNavigationsInHierarchy(IEntityType entityT
315412
.SelectMany(t => t.GetDeclaredNavigations());
316413
}
317414

415+
/// <summary>
416+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
417+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
418+
/// any release. You should only use it directly in your code with extreme caution and knowing that
419+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
420+
/// </summary>
421+
protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item)
422+
{
423+
var translatedSource = base.TranslateContains(source, item);
424+
425+
// SQL Server does not support subqueries inside aggregate functions (e.g. COUNT(SELECT * FROM OPENJSON(@p)...)).
426+
// As a result, we track whether we're within an aggregate function; if we are, and we see the regular Contains translation
427+
// (which uses IN with an OPENJSON subquery - incompatible), we transform it to the old-style IN+constants translation (as if a
428+
// low SQL Server compatibility level were defined)
429+
if (_queryCompilationContext.InAggregateFunction
430+
&& translatedSource is not null
431+
&& TryGetProjection(translatedSource, out var projection)
432+
&& projection is InExpression
433+
{
434+
Item: var translatedItem,
435+
Subquery:
436+
{
437+
Tables: [SqlServerOpenJsonExpression { Arguments: [SqlParameterExpression parameter] } openJsonExpression],
438+
GroupBy: [],
439+
Having: null,
440+
IsDistinct: false,
441+
Limit: null,
442+
Offset: null,
443+
Orderings: [],
444+
Projection: [{ Expression: ColumnExpression { Name: "value", Table: var projectionColumnTable } }]
445+
}
446+
}
447+
&& projectionColumnTable == openJsonExpression)
448+
{
449+
var newInExpression = _sqlExpressionFactory.In(translatedItem, parameter);
450+
return source.UpdateQueryExpression(_sqlExpressionFactory.Select(newInExpression));
451+
}
452+
453+
return translatedSource;
454+
}
455+
318456
/// <summary>
319457
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
320458
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
@@ -504,6 +642,29 @@ protected override bool IsValidSelectExpressionForExecuteUpdate(
504642
return false;
505643
}
506644

645+
private bool TryGetProjection(ShapedQueryExpression shapedQueryExpression, [NotNullWhen(true)] out SqlExpression? projection)
646+
{
647+
var shaperExpression = shapedQueryExpression.ShaperExpression;
648+
// No need to check ConvertChecked since this is convert node which we may have added during projection
649+
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
650+
&& unaryExpression.Operand.Type.IsNullableType()
651+
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
652+
{
653+
shaperExpression = unaryExpression.Operand;
654+
}
655+
656+
if (shapedQueryExpression.QueryExpression is SelectExpression selectExpression
657+
&& shaperExpression is ProjectionBindingExpression projectionBindingExpression
658+
&& selectExpression.GetProjection(projectionBindingExpression) is SqlExpression sqlExpression)
659+
{
660+
projection = sqlExpression;
661+
return true;
662+
}
663+
664+
projection = null;
665+
return false;
666+
}
667+
507668
private sealed class TemporalAnnotationApplyingExpressionVisitor : ExpressionVisitor
508669
{
509670
private readonly Func<TableExpression, TableExpressionBase> _annotationApplyingFunc;

src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitorFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,5 @@ public SqlServerQueryableMethodTranslatingExpressionVisitorFactory(
4949
/// </summary>
5050
public virtual QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext)
5151
=> new SqlServerQueryableMethodTranslatingExpressionVisitor(
52-
Dependencies, RelationalDependencies, queryCompilationContext, _sqlServerSingletonOptions);
52+
Dependencies, RelationalDependencies, (SqlServerQueryCompilationContext)queryCompilationContext, _sqlServerSingletonOptions);
5353
}

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
1717
/// </summary>
1818
public class SqlServerSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExpressionVisitor
1919
{
20-
private readonly QueryCompilationContext _queryCompilationContext;
20+
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
2121
private readonly ISqlExpressionFactory _sqlExpressionFactory;
2222

2323
private static readonly HashSet<string> DateTimeDataTypes
@@ -73,7 +73,7 @@ private static readonly MethodInfo StringContainsMethodInfo
7373
/// </summary>
7474
public SqlServerSqlTranslatingExpressionVisitor(
7575
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
76-
QueryCompilationContext queryCompilationContext,
76+
SqlServerQueryCompilationContext queryCompilationContext,
7777
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
7878
: base(dependencies, queryCompilationContext, queryableMethodTranslatingExpressionVisitor)
7979
{
@@ -432,6 +432,28 @@ private static string EscapeLikePattern(string pattern)
432432
return builder.ToString();
433433
}
434434

435+
/// <summary>
436+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
437+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
438+
/// any release. You should only use it directly in your code with extreme caution and knowing that
439+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
440+
/// </summary>
441+
protected override bool TryTranslateAggregateMethodCall(
442+
MethodCallExpression methodCallExpression,
443+
[NotNullWhen(true)] out SqlExpression? translation)
444+
{
445+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
446+
_queryCompilationContext.InAggregateFunction = true;
447+
448+
#pragma warning disable EF1001 // Internal EF Core API usage.
449+
var result = base.TryTranslateAggregateMethodCall(methodCallExpression, out translation);
450+
#pragma warning restore EF1001 // Internal EF Core API usage.
451+
452+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
453+
454+
return result;
455+
}
456+
435457
private Expression TranslateByteArrayElementAccess(Expression array, Expression index, Type resultType)
436458
{
437459
var visitedArray = Visit(array);

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitorFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ public virtual RelationalSqlTranslatingExpressionVisitor Create(
3939
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
4040
=> new SqlServerSqlTranslatingExpressionVisitor(
4141
Dependencies,
42-
queryCompilationContext,
42+
(SqlServerQueryCompilationContext)queryCompilationContext,
4343
queryableMethodTranslatingExpressionVisitor);
4444
}

test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2249,6 +2249,86 @@ public override async Task Not_Any_false(bool async)
22492249
AssertSql();
22502250
}
22512251

2252+
public override async Task Contains_inside_aggregate_function_with_GroupBy(bool async)
2253+
{
2254+
// GroupBy. Issue #17313.
2255+
await AssertTranslationFailed(() => base.Contains_inside_aggregate_function_with_GroupBy(async));
2256+
2257+
AssertSql();
2258+
}
2259+
2260+
public override async Task Contains_inside_Average_without_GroupBy(bool async)
2261+
{
2262+
await base.Contains_inside_Average_without_GroupBy(async);
2263+
2264+
AssertSql(
2265+
"""
2266+
SELECT AVG((c["City"] IN ("London", "Berlin") ? 1.0 : 0.0)) AS c
2267+
FROM root c
2268+
WHERE (c["Discriminator"] = "Customer")
2269+
""");
2270+
}
2271+
2272+
public override async Task Contains_inside_Sum_without_GroupBy(bool async)
2273+
{
2274+
await base.Contains_inside_Sum_without_GroupBy(async);
2275+
2276+
AssertSql(
2277+
"""
2278+
SELECT SUM((c["City"] IN ("London", "Berlin") ? 1 : 0)) AS c
2279+
FROM root c
2280+
WHERE (c["Discriminator"] = "Customer")
2281+
""");
2282+
}
2283+
2284+
public override async Task Contains_inside_Count_without_GroupBy(bool async)
2285+
{
2286+
await base.Contains_inside_Count_without_GroupBy(async);
2287+
2288+
AssertSql(
2289+
"""
2290+
SELECT COUNT(1) AS c
2291+
FROM root c
2292+
WHERE ((c["Discriminator"] = "Customer") AND c["City"] IN ("London", "Berlin"))
2293+
""");
2294+
}
2295+
2296+
public override async Task Contains_inside_LongCount_without_GroupBy(bool async)
2297+
{
2298+
await base.Contains_inside_LongCount_without_GroupBy(async);
2299+
2300+
AssertSql(
2301+
"""
2302+
SELECT COUNT(1) AS c
2303+
FROM root c
2304+
WHERE ((c["Discriminator"] = "Customer") AND c["City"] IN ("London", "Berlin"))
2305+
""");
2306+
}
2307+
2308+
public override async Task Contains_inside_Max_without_GroupBy(bool async)
2309+
{
2310+
await base.Contains_inside_Max_without_GroupBy(async);
2311+
2312+
AssertSql(
2313+
"""
2314+
SELECT MAX((c["City"] IN ("London", "Berlin") ? 1 : 0)) AS c
2315+
FROM root c
2316+
WHERE (c["Discriminator"] = "Customer")
2317+
""");
2318+
}
2319+
2320+
public override async Task Contains_inside_Min_without_GroupBy(bool async)
2321+
{
2322+
await base.Contains_inside_Min_without_GroupBy(async);
2323+
2324+
AssertSql(
2325+
"""
2326+
SELECT MIN((c["City"] IN ("London", "Berlin") ? 1 : 0)) AS c
2327+
FROM root c
2328+
WHERE (c["Discriminator"] = "Customer")
2329+
""");
2330+
}
2331+
22522332
private void AssertSql(params string[] expected)
22532333
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
22542334

0 commit comments

Comments
 (0)