Skip to content

Commit 0eefa4b

Browse files
committed
Fix Contains within SQL Server aggregate functions
Fixes #32374
1 parent 17fa62f commit 0eefa4b

8 files changed

+386
-7
lines changed

src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1423,7 +1423,14 @@ private StructuralTypeReferenceExpression BindComplexProperty(
14231423
}
14241424
}
14251425

1426-
private bool TryTranslateAggregateMethodCall(
1426+
/// <summary>
1427+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
1428+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
1429+
/// any release. You should only use it directly in your code with extreme caution and knowing that
1430+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
1431+
/// </summary>
1432+
[EntityFrameworkInternal]
1433+
protected virtual bool TryTranslateAggregateMethodCall(
14271434
MethodCallExpression methodCallExpression,
14281435
[NotNullWhen(true)] out SqlExpression? translation)
14291436
{

src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,16 @@ public override bool IsBuffering
3939
=> base.IsBuffering
4040
|| (QuerySplittingBehavior == EntityFrameworkCore.QuerySplittingBehavior.SplitQuery
4141
&& !_multipleActiveResultSetsEnabled);
42+
43+
/// <summary>
44+
/// Tracks whether translation is currently within the argument of an aggregate method (e.g. MAX, COUNT); SQL Server does not
45+
/// allow subqueries and aggregates in that context.
46+
/// </summary>
47+
/// <remarks>
48+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
49+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
50+
/// any release. You should only use it directly in your code with extreme caution and knowing that
51+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
52+
/// </remarks>
53+
public bool InAggregateFunction { get; set; }
4254
}

src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
1818
/// </summary>
1919
public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQueryableMethodTranslatingExpressionVisitor
2020
{
21-
private readonly QueryCompilationContext _queryCompilationContext;
21+
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
2222
private readonly IRelationalTypeMappingSource _typeMappingSource;
2323
private readonly ISqlExpressionFactory _sqlExpressionFactory;
2424
private readonly int _sqlServerCompatibilityLevel;
@@ -34,7 +34,7 @@ public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQu
3434
public SqlServerQueryableMethodTranslatingExpressionVisitor(
3535
QueryableMethodTranslatingExpressionVisitorDependencies dependencies,
3636
RelationalQueryableMethodTranslatingExpressionVisitorDependencies relationalDependencies,
37-
QueryCompilationContext queryCompilationContext,
37+
SqlServerQueryCompilationContext queryCompilationContext,
3838
ISqlServerSingletonOptions sqlServerSingletonOptions)
3939
: base(dependencies, relationalDependencies, queryCompilationContext)
4040
{
@@ -121,6 +121,103 @@ protected override Expression VisitExtension(Expression extensionExpression)
121121
return base.VisitExtension(extensionExpression);
122122
}
123123

124+
#region Aggregate functions
125+
126+
// We override these for SQL Server to add tracking whether we're inside an aggregate function context, since SQL Server doesn't
127+
// support subqueries (or aggregates) within them.
128+
129+
/// <summary>
130+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
131+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
132+
/// any release. You should only use it directly in your code with extreme caution and knowing that
133+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
134+
/// </summary>
135+
protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
136+
{
137+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
138+
_queryCompilationContext.InAggregateFunction = true;
139+
var result = base.TranslateAverage(source, selector, resultType);
140+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
141+
return result;
142+
}
143+
144+
/// <summary>
145+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
146+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
147+
/// any release. You should only use it directly in your code with extreme caution and knowing that
148+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
149+
/// </summary>
150+
protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
151+
{
152+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
153+
_queryCompilationContext.InAggregateFunction = true;
154+
var result = base.TranslateSum(source, selector, resultType);
155+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
156+
return result;
157+
}
158+
159+
/// <summary>
160+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
161+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
162+
/// any release. You should only use it directly in your code with extreme caution and knowing that
163+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
164+
/// </summary>
165+
protected override ShapedQueryExpression? TranslateCount(ShapedQueryExpression source, LambdaExpression? predicate)
166+
{
167+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
168+
_queryCompilationContext.InAggregateFunction = true;
169+
var result = base.TranslateCount(source, predicate);
170+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
171+
return result;
172+
}
173+
174+
/// <summary>
175+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
176+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
177+
/// any release. You should only use it directly in your code with extreme caution and knowing that
178+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
179+
/// </summary>
180+
protected override ShapedQueryExpression? TranslateLongCount(ShapedQueryExpression source, LambdaExpression? predicate)
181+
{
182+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
183+
_queryCompilationContext.InAggregateFunction = true;
184+
var result = base.TranslateLongCount(source, predicate);
185+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
186+
return result;
187+
}
188+
189+
/// <summary>
190+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
191+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
192+
/// any release. You should only use it directly in your code with extreme caution and knowing that
193+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
194+
/// </summary>
195+
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
196+
{
197+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
198+
_queryCompilationContext.InAggregateFunction = true;
199+
var result = base.TranslateMax(source, selector, resultType);
200+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
201+
return result;
202+
}
203+
204+
/// <summary>
205+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
206+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
207+
/// any release. You should only use it directly in your code with extreme caution and knowing that
208+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
209+
/// </summary>
210+
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
211+
{
212+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
213+
_queryCompilationContext.InAggregateFunction = true;
214+
var result = base.TranslateMin(source, selector, resultType);
215+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
216+
return result;
217+
}
218+
219+
#endregion Aggregate functions
220+
124221
/// <summary>
125222
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
126223
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
@@ -315,6 +412,47 @@ static IEnumerable<INavigation> GetAllNavigationsInHierarchy(IEntityType entityT
315412
.SelectMany(t => t.GetDeclaredNavigations());
316413
}
317414

415+
/// <summary>
416+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
417+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
418+
/// any release. You should only use it directly in your code with extreme caution and knowing that
419+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
420+
/// </summary>
421+
protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item)
422+
{
423+
var translatedSource = base.TranslateContains(source, item);
424+
425+
// SQL Server does not support subqueries inside aggregate functions (e.g. COUNT(SELECT * FROM OPENJSON(@p)...)).
426+
// As a result, we track whether we're within an aggregate function; if we are, and we see the regular Contains translation
427+
// (which uses IN with an OPENJSON subquery - incompatible), we transform it to the old-style IN+constants translation (as if a
428+
// low SQL Server compatibility level were defined)
429+
if (_queryCompilationContext.InAggregateFunction
430+
&& translatedSource is not null
431+
&& TryGetProjection(translatedSource, out var projection)
432+
&& projection is InExpression
433+
{
434+
Item: var translatedItem,
435+
Subquery:
436+
{
437+
Tables: [SqlServerOpenJsonExpression { Arguments: [SqlParameterExpression parameter] } openJsonExpression],
438+
GroupBy: [],
439+
Having: null,
440+
IsDistinct: false,
441+
Limit: null,
442+
Offset: null,
443+
Orderings: [],
444+
Projection: [{ Expression: ColumnExpression { Name: "value", Table: var projectionColumnTable } }]
445+
}
446+
}
447+
&& projectionColumnTable == openJsonExpression)
448+
{
449+
var newInExpression = _sqlExpressionFactory.In(translatedItem, parameter);
450+
return source.UpdateQueryExpression(_sqlExpressionFactory.Select(newInExpression));
451+
}
452+
453+
return translatedSource;
454+
}
455+
318456
/// <summary>
319457
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
320458
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
@@ -504,6 +642,29 @@ protected override bool IsValidSelectExpressionForExecuteUpdate(
504642
return false;
505643
}
506644

645+
private bool TryGetProjection(ShapedQueryExpression shapedQueryExpression, [NotNullWhen(true)] out SqlExpression? projection)
646+
{
647+
var shaperExpression = shapedQueryExpression.ShaperExpression;
648+
// No need to check ConvertChecked since this is convert node which we may have added during projection
649+
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
650+
&& unaryExpression.Operand.Type.IsNullableType()
651+
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
652+
{
653+
shaperExpression = unaryExpression.Operand;
654+
}
655+
656+
if (shapedQueryExpression.QueryExpression is SelectExpression selectExpression
657+
&& shaperExpression is ProjectionBindingExpression projectionBindingExpression
658+
&& selectExpression.GetProjection(projectionBindingExpression) is SqlExpression sqlExpression)
659+
{
660+
projection = sqlExpression;
661+
return true;
662+
}
663+
664+
projection = null;
665+
return false;
666+
}
667+
507668
private sealed class TemporalAnnotationApplyingExpressionVisitor : ExpressionVisitor
508669
{
509670
private readonly Func<TableExpression, TableExpressionBase> _annotationApplyingFunc;

src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitorFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,5 @@ public SqlServerQueryableMethodTranslatingExpressionVisitorFactory(
4949
/// </summary>
5050
public virtual QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext)
5151
=> new SqlServerQueryableMethodTranslatingExpressionVisitor(
52-
Dependencies, RelationalDependencies, queryCompilationContext, _sqlServerSingletonOptions);
52+
Dependencies, RelationalDependencies, (SqlServerQueryCompilationContext)queryCompilationContext, _sqlServerSingletonOptions);
5353
}

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
1717
/// </summary>
1818
public class SqlServerSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExpressionVisitor
1919
{
20-
private readonly QueryCompilationContext _queryCompilationContext;
20+
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
2121
private readonly ISqlExpressionFactory _sqlExpressionFactory;
2222

2323
private static readonly HashSet<string> DateTimeDataTypes
@@ -73,7 +73,7 @@ private static readonly MethodInfo StringContainsMethodInfo
7373
/// </summary>
7474
public SqlServerSqlTranslatingExpressionVisitor(
7575
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
76-
QueryCompilationContext queryCompilationContext,
76+
SqlServerQueryCompilationContext queryCompilationContext,
7777
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
7878
: base(dependencies, queryCompilationContext, queryableMethodTranslatingExpressionVisitor)
7979
{
@@ -432,6 +432,26 @@ private static string EscapeLikePattern(string pattern)
432432
return builder.ToString();
433433
}
434434

435+
/// <summary>
436+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
437+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
438+
/// any release. You should only use it directly in your code with extreme caution and knowing that
439+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
440+
/// </summary>
441+
protected override bool TryTranslateAggregateMethodCall(
442+
MethodCallExpression methodCallExpression,
443+
[NotNullWhen(true)] out SqlExpression? translation)
444+
{
445+
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
446+
_queryCompilationContext.InAggregateFunction = true;
447+
448+
var result = base.TryTranslateAggregateMethodCall(methodCallExpression, out translation);
449+
450+
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
451+
452+
return result;
453+
}
454+
435455
private Expression TranslateByteArrayElementAccess(Expression array, Expression index, Type resultType)
436456
{
437457
var visitedArray = Visit(array);

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitorFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ public virtual RelationalSqlTranslatingExpressionVisitor Create(
3939
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
4040
=> new SqlServerSqlTranslatingExpressionVisitor(
4141
Dependencies,
42-
queryCompilationContext,
42+
(SqlServerQueryCompilationContext)queryCompilationContext,
4343
queryableMethodTranslatingExpressionVisitor);
4444
}

test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1923,4 +1923,89 @@ public virtual Task Not_Any_false(bool async)
19231923
=> AssertQuery(
19241924
async,
19251925
ss => ss.Set<Customer>().Where(c => !c.Orders.Any(o => false)).Select(c => c.CustomerID));
1926+
1927+
[ConditionalTheory] // #32374
1928+
[MemberData(nameof(IsAsyncData))]
1929+
public virtual Task Contains_inside_aggregate_function_with_GroupBy(bool async)
1930+
{
1931+
var cities = new[] { "London", "Berlin" };
1932+
1933+
return AssertQuery(
1934+
async,
1935+
ss => ss.Set<Customer>()
1936+
.GroupBy(c => c.Country)
1937+
.Select(g => g.Count(c => cities.Contains(c.City))));
1938+
}
1939+
1940+
[ConditionalTheory] // #32374
1941+
[MemberData(nameof(IsAsyncData))]
1942+
public virtual Task Contains_inside_Average_without_GroupBy(bool async)
1943+
{
1944+
var cities = new[] { "London", "Berlin" };
1945+
1946+
return AssertAverage(
1947+
async,
1948+
ss => ss.Set<Customer>(),
1949+
selector: c => cities.Contains(c.City) ? 1 : 0);
1950+
}
1951+
1952+
[ConditionalTheory] // #32374
1953+
[MemberData(nameof(IsAsyncData))]
1954+
public virtual Task Contains_inside_Sum_without_GroupBy(bool async)
1955+
{
1956+
var cities = new[] { "London", "Berlin" };
1957+
1958+
return AssertSum(
1959+
async,
1960+
ss => ss.Set<Customer>(),
1961+
selector: c => cities.Contains(c.City) ? 1 : 0);
1962+
}
1963+
1964+
[ConditionalTheory] // #32374
1965+
[MemberData(nameof(IsAsyncData))]
1966+
public virtual Task Contains_inside_Count_without_GroupBy(bool async)
1967+
{
1968+
var cities = new[] { "London", "Berlin" };
1969+
1970+
return AssertCount(
1971+
async,
1972+
ss => ss.Set<Customer>(),
1973+
predicate: c => cities.Contains(c.City));
1974+
}
1975+
1976+
[ConditionalTheory] // #32374
1977+
[MemberData(nameof(IsAsyncData))]
1978+
public virtual Task Contains_inside_LongCount_without_GroupBy(bool async)
1979+
{
1980+
var cities = new[] { "London", "Berlin" };
1981+
1982+
return AssertLongCount(
1983+
async,
1984+
ss => ss.Set<Customer>(),
1985+
predicate: c => cities.Contains(c.City));
1986+
}
1987+
1988+
[ConditionalTheory] // #32374
1989+
[MemberData(nameof(IsAsyncData))]
1990+
public virtual Task Contains_inside_Max_without_GroupBy(bool async)
1991+
{
1992+
var cities = new[] { "London", "Berlin" };
1993+
1994+
return AssertMax(
1995+
async,
1996+
ss => ss.Set<Customer>(),
1997+
selector: c => cities.Contains(c.City) ? 1 : 0);
1998+
}
1999+
2000+
[ConditionalTheory] // #32374
2001+
[MemberData(nameof(IsAsyncData))]
2002+
public virtual Task Contains_inside_Min_without_GroupBy(bool async)
2003+
{
2004+
var cities = new[] { "London", "Berlin" };
2005+
2006+
return AssertMin(
2007+
async,
2008+
ss => ss.Set<Customer>(),
2009+
selector: c => cities.Contains(c.City) ? 1 : 0);
2010+
}
19262011
}

0 commit comments

Comments
 (0)