diff --git a/src/Ardalis.Specification/Evaluators/OrderEvaluator.cs b/src/Ardalis.Specification/Evaluators/OrderEvaluator.cs index 94d39683..07fb097a 100644 --- a/src/Ardalis.Specification/Evaluators/OrderEvaluator.cs +++ b/src/Ardalis.Specification/Evaluators/OrderEvaluator.cs @@ -9,39 +9,49 @@ private OrderEvaluator() { } public IQueryable GetQuery(IQueryable query, ISpecification specification) where T : class { - if (specification.OrderExpressions != null) + if (specification is Specification spec) { - if (specification.OrderExpressions.Count(x => x.OrderType == OrderTypeEnum.OrderBy - || x.OrderType == OrderTypeEnum.OrderByDescending) > 1) + if (spec.OneOrManyOrderExpressions.IsEmpty) return query; + if (spec.OneOrManyOrderExpressions.SingleOrDefault is { } orderExpression) { - throw new DuplicateOrderChainException(); + return orderExpression.OrderType switch + { + OrderTypeEnum.OrderBy => query.OrderBy(orderExpression.KeySelector), + OrderTypeEnum.OrderByDescending => query.OrderByDescending(orderExpression.KeySelector), + _ => query + }; } + } - IOrderedQueryable? orderedQuery = null; - foreach (var orderExpression in specification.OrderExpressions) + IOrderedQueryable? orderedQuery = null; + var chainCount = 0; + foreach (var orderExpression in specification.OrderExpressions) + { + if (orderExpression.OrderType == OrderTypeEnum.OrderBy) { - if (orderExpression.OrderType == OrderTypeEnum.OrderBy) - { - orderedQuery = query.OrderBy(orderExpression.KeySelector); - } - else if (orderExpression.OrderType == OrderTypeEnum.OrderByDescending) - { - orderedQuery = query.OrderByDescending(orderExpression.KeySelector); - } - else if (orderExpression.OrderType == OrderTypeEnum.ThenBy) - { - orderedQuery = orderedQuery!.ThenBy(orderExpression.KeySelector); - } - else if (orderExpression.OrderType == OrderTypeEnum.ThenByDescending) - { - orderedQuery = orderedQuery!.ThenByDescending(orderExpression.KeySelector); - } + chainCount++; + if (chainCount == 2) throw new DuplicateOrderChainException(); + orderedQuery = query.OrderBy(orderExpression.KeySelector); } - - if (orderedQuery != null) + else if (orderExpression.OrderType == OrderTypeEnum.OrderByDescending) { - query = orderedQuery; + chainCount++; + if (chainCount == 2) throw new DuplicateOrderChainException(); + orderedQuery = query.OrderByDescending(orderExpression.KeySelector); } + else if (orderExpression.OrderType == OrderTypeEnum.ThenBy) + { + orderedQuery = orderedQuery!.ThenBy(orderExpression.KeySelector); + } + else if (orderExpression.OrderType == OrderTypeEnum.ThenByDescending) + { + orderedQuery = orderedQuery!.ThenByDescending(orderExpression.KeySelector); + } + } + + if (orderedQuery is not null) + { + query = orderedQuery; } return query; @@ -49,39 +59,49 @@ public IQueryable GetQuery(IQueryable query, ISpecification specific public IEnumerable Evaluate(IEnumerable query, ISpecification specification) { - if (specification.OrderExpressions != null) + if (specification is Specification spec) { - if (specification.OrderExpressions.Count(x => x.OrderType == OrderTypeEnum.OrderBy - || x.OrderType == OrderTypeEnum.OrderByDescending) > 1) + if (spec.OneOrManyOrderExpressions.IsEmpty) return query; + if (spec.OneOrManyOrderExpressions.SingleOrDefault is { } orderExpression) { - throw new DuplicateOrderChainException(); + return orderExpression.OrderType switch + { + OrderTypeEnum.OrderBy => query.OrderBy(orderExpression.KeySelectorFunc), + OrderTypeEnum.OrderByDescending => query.OrderByDescending(orderExpression.KeySelectorFunc), + _ => query + }; } + } - IOrderedEnumerable? orderedQuery = null; - foreach (var orderExpression in specification.OrderExpressions) + IOrderedEnumerable? orderedQuery = null; + var chainCount = 0; + foreach (var orderExpression in specification.OrderExpressions) + { + if (orderExpression.OrderType == OrderTypeEnum.OrderBy) { - if (orderExpression.OrderType == OrderTypeEnum.OrderBy) - { - orderedQuery = query.OrderBy(orderExpression.KeySelectorFunc); - } - else if (orderExpression.OrderType == OrderTypeEnum.OrderByDescending) - { - orderedQuery = query.OrderByDescending(orderExpression.KeySelectorFunc); - } - else if (orderExpression.OrderType == OrderTypeEnum.ThenBy) - { - orderedQuery = orderedQuery!.ThenBy(orderExpression.KeySelectorFunc); - } - else if (orderExpression.OrderType == OrderTypeEnum.ThenByDescending) - { - orderedQuery = orderedQuery!.ThenByDescending(orderExpression.KeySelectorFunc); - } + chainCount++; + if (chainCount == 2) throw new DuplicateOrderChainException(); + orderedQuery = query.OrderBy(orderExpression.KeySelectorFunc); } - - if (orderedQuery != null) + else if (orderExpression.OrderType == OrderTypeEnum.OrderByDescending) { - query = orderedQuery; + chainCount++; + if (chainCount == 2) throw new DuplicateOrderChainException(); + orderedQuery = query.OrderByDescending(orderExpression.KeySelectorFunc); } + else if (orderExpression.OrderType == OrderTypeEnum.ThenBy) + { + orderedQuery = orderedQuery!.ThenBy(orderExpression.KeySelectorFunc); + } + else if (orderExpression.OrderType == OrderTypeEnum.ThenByDescending) + { + orderedQuery = orderedQuery!.ThenByDescending(orderExpression.KeySelectorFunc); + } + } + + if (orderedQuery is not null) + { + query = orderedQuery; } return query; diff --git a/src/Ardalis.Specification/Specification.cs b/src/Ardalis.Specification/Specification.cs index 798edc7f..abdde5da 100644 --- a/src/Ardalis.Specification/Specification.cs +++ b/src/Ardalis.Specification/Specification.cs @@ -27,7 +27,6 @@ public class Specification : Specification, ISpecification : ISpecification { private const int DEFAULT_CAPACITY_SEARCH = 2; - private const int DEFAULT_CAPACITY_ORDER = 2; private const int DEFAULT_CAPACITY_INCLUDE = 2; private const int DEFAULT_CAPACITY_INCLUDESTRING = 1; @@ -43,7 +42,7 @@ public class Specification : ISpecification // This will be reconsidered for version 10 where we may store the whole state as a single array of structs. private OneOrMany> _whereExpressions = new(); private List>? _searchExpressions; - private List>? _orderExpressions; + private OneOrMany> _orderExpressions = new(); private List? _includeExpressions; private List? _includeStrings; private Dictionary? _items; @@ -94,7 +93,7 @@ public class Specification : ISpecification // Specs are not intended to be thread-safe, so we don't need to worry about thread-safety here. internal void Add(WhereExpressionInfo whereExpression) => _whereExpressions.Add(whereExpression); - internal void Add(OrderExpressionInfo orderExpression) => (_orderExpressions ??= new(DEFAULT_CAPACITY_ORDER)).Add(orderExpression); + internal void Add(OrderExpressionInfo orderExpression) => _orderExpressions.Add(orderExpression); internal void Add(IncludeExpressionInfo includeExpression) => (_includeExpressions ??= new(DEFAULT_CAPACITY_INCLUDE)).Add(includeExpression); internal void Add(string includeString) => (_includeStrings ??= new(DEFAULT_CAPACITY_INCLUDESTRING)).Add(includeString); internal void Add(SearchExpressionInfo searchExpression) @@ -130,7 +129,7 @@ internal void Add(SearchExpressionInfo searchExpression) public IEnumerable> SearchCriterias => _searchExpressions ?? Enumerable.Empty>(); /// - public IEnumerable> OrderExpressions => _orderExpressions ?? Enumerable.Empty>(); + public IEnumerable> OrderExpressions => _orderExpressions.Values; /// public IEnumerable IncludeExpressions => _includeExpressions ?? Enumerable.Empty(); @@ -142,6 +141,7 @@ internal void Add(SearchExpressionInfo searchExpression) public IEnumerable QueryTags => _queryTags.Values; internal OneOrMany> OneOrManyWhereExpressions => _whereExpressions; + internal OneOrMany> OneOrManyOrderExpressions => _orderExpressions; internal OneOrMany OneOrManyQueryTags => _queryTags; /// @@ -189,9 +189,9 @@ void ISpecification.CopyTo(Specification otherSpec) otherSpec._includeStrings = _includeStrings.ToList(); } - if (_orderExpressions is not null) + if (!_orderExpressions.IsEmpty) { - otherSpec._orderExpressions = _orderExpressions.ToList(); + otherSpec._orderExpressions = _orderExpressions.Clone(); } if (_searchExpressions is not null) diff --git a/tests/Ardalis.Specification.Tests/Evaluators/OrderEvaluatorTests.cs b/tests/Ardalis.Specification.Tests/Evaluators/OrderEvaluatorTests.cs index 554af795..7f2b8bb7 100644 --- a/tests/Ardalis.Specification.Tests/Evaluators/OrderEvaluatorTests.cs +++ b/tests/Ardalis.Specification.Tests/Evaluators/OrderEvaluatorTests.cs @@ -7,7 +7,7 @@ public class OrderEvaluatorTests public record Customer(int Id, string? Name = null); [Fact] - public void ThrowsDuplicateOrderChainException_GivenMultipleOrderChains() + public void ThrowsDuplicateOrderChainException_GivenMultipleOrderByChains() { List input = [new(3), new(1), new(2), new(5), new(4)]; List expected = [new(1), new(2), new(3), new(4), new(5)]; @@ -24,6 +24,24 @@ public void ThrowsDuplicateOrderChainException_GivenMultipleOrderChains() sut2.Should().Throw(); } + [Fact] + public void ThrowsDuplicateOrderChainException_GivenMultipleOrderByDescendingChains() + { + List input = [new(3), new(1), new(2), new(5), new(4)]; + List expected = [new(1), new(2), new(3), new(4), new(5)]; + + var spec = new Specification(); + spec.Query + .OrderByDescending(x => x.Id) + .OrderByDescending(x => x.Name); + + var sut1 = new Action(() => _evaluator.Evaluate(input, spec)); + var sut2 = new Action(() => _evaluator.GetQuery(input.AsQueryable(), spec)); + + sut1.Should().Throw(); + sut2.Should().Throw(); + } + [Fact] public void OrdersItemsAscending_GivenOrderBy() { @@ -50,6 +68,19 @@ public void OrdersItemsDescending_GivenOrderByDescending() Assert(spec, input, expected); } + [Fact] + public void DoesNothing_GivenInvalidRootChain() + { + List input = [new(3), new(1), new(2), new(5), new(4)]; + List expected = [new(3), new(1), new(2), new(5), new(4)]; + + var spec = new Specification(); + var expr = new OrderExpressionInfo(x => x.Id, OrderTypeEnum.ThenBy); + spec.Add(expr); + + Assert(spec, input, expected); + } + [Fact] public void OrdersItems_GivenOrderByThenBy() {