Skip to content

Commit 6e79796

Browse files
LilithHafner and Lilith Hafner authored
Fix-ups for sorting workspace/buffer (#45330) (#45570)
* Fix and test sort!(OffsetArray(rand(200), -10))
* Convert to 1-based indexing rather than generalize to arbitrary indexing
* avoid overhead of views where reasonable
* style
* handle edge cases better, making the workspace function unhelpful. Also minor style changes and fixups from #45596 and local review.
* move comments in tests for discoverability

Co-authored-by: Lilith Hafner <[email protected]>
1 parent fa2f304 commit 6e79796

File tree

2 files changed

+38
-28
lines changed

2 files changed

+38
-28
lines changed

base/sort.jl

Lines changed: 26 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ module Sort
55
import ..@__MODULE__, ..parentmodule
66
const Base = parentmodule(@__MODULE__)
77
using .Base.Order
8-
using .Base: copymutable, LinearIndices, length, (:), iterate, elsize,
8+
using .Base: copymutable, LinearIndices, length, (:), iterate, OneTo,
99
eachindex, axes, first, last, similar, zip, OrdinalRange, firstindex, lastindex,
1010
AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline,
1111
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
@@ -605,7 +605,10 @@ function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, a::MergeSortAlg,
605605
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)
606606

607607
m = midpoint(lo, hi)
608-
t = workspace(v, t0, m-lo+1)
608+
609+
t = t0 === nothing ? similar(v, m-lo+1) : t0
610+
length(t) < m-lo+1 && resize!(t, m-lo+1)
611+
Base.require_one_based_indexing(t)
609612

610613
sort!(v, lo, m, a, o, t)
611614
sort!(v, m+1, hi, a, o, t)
@@ -683,7 +686,7 @@ function radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsig
683686
t::AbstractVector{U}, chunk_size=radix_chunk_size_heuristic(lo, hi, bits)) where U <: Unsigned
684687
# bits is unsigned for performance reasons.
685688
mask = UInt(1) << chunk_size - 1
686-
counts = Vector{UInt}(undef, mask+2)
689+
counts = Vector{Int}(undef, mask+2)
687690

688691
@inbounds for shift in 0:chunk_size:bits-1
689692

@@ -732,6 +735,7 @@ end
732735

733736
# For AbstractVector{Bool}, counting sort is always best.
734737
# This is an implementation of counting sort specialized for Bools.
738+
# Accepts unused workspace to avoid method ambiguity.
735739
function sort!(v::AbstractVector{B}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering,
736740
t::Union{AbstractVector{B}, Nothing}=nothing) where {B <: Bool}
737741
first = lt(o, false, true) ? false : lt(o, true, false) ? true : return v
@@ -746,10 +750,6 @@ function sort!(v::AbstractVector{B}, lo::Integer, hi::Integer, a::AdaptiveSort,
746750
v
747751
end
748752

749-
workspace(v::AbstractVector, ::Nothing, len::Integer) = similar(v, len)
750-
function workspace(v::AbstractVector{T}, t::AbstractVector{T}, len::Integer) where T
751-
length(t) < len ? resize!(t, len) : t
752-
end
753753
maybe_unsigned(x::Integer) = x # this is necessary to avoid calling unsigned on BigInt
754754
maybe_unsigned(x::BitSigned) = unsigned(x)
755755
function _extrema(v::AbstractVector, lo::Integer, hi::Integer, o::Ordering)
@@ -856,8 +856,18 @@ function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, a::AdaptiveSort,
856856
u[i] -= u_min
857857
end
858858

859-
u2 = radix_sort!(u, lo, hi, bits, reinterpret(U, workspace(v, t, hi)))
860-
uint_unmap!(v, u2, lo, hi, o, u_min)
859+
if t !== nothing && checkbounds(Bool, t, lo:hi) # Fully preallocated and aligned workspace
860+
u2 = radix_sort!(u, lo, hi, bits, reinterpret(U, t))
861+
uint_unmap!(v, u2, lo, hi, o, u_min)
862+
elseif t !== nothing && (applicable(resize!, t) || length(t) >= hi-lo+1) # Viable workspace
863+
length(t) >= hi-lo+1 || resize!(t, hi-lo+1)
864+
t1 = axes(t, 1) isa OneTo ? t : view(t, firstindex(t):lastindex(t))
865+
u2 = radix_sort!(view(u, lo:hi), 1, hi-lo+1, bits, reinterpret(U, t1))
866+
uint_unmap!(view(v, lo:hi), u2, 1, hi-lo+1, o, u_min)
867+
else # No viable workspace
868+
u2 = radix_sort!(u, lo, hi, bits, similar(u))
869+
uint_unmap!(v, u2, lo, hi, o, u_min)
870+
end
861871
end
862872

863873
## generic sorting methods ##
@@ -1113,7 +1123,7 @@ function sortperm(v::AbstractVector;
11131123
by=identity,
11141124
rev::Union{Bool,Nothing}=nothing,
11151125
order::Ordering=Forward,
1116-
workspace::Union{AbstractVector, Nothing}=nothing)
1126+
workspace::Union{AbstractVector{<:Integer}, Nothing}=nothing)
11171127
ordr = ord(lt,by,rev,order)
11181128
if ordr === Forward && isa(v,Vector) && eltype(v)<:Integer
11191129
n = length(v)
@@ -1235,7 +1245,7 @@ function sort(A::AbstractArray{T};
12351245
by=identity,
12361246
rev::Union{Bool,Nothing}=nothing,
12371247
order::Ordering=Forward,
1238-
workspace::Union{AbstractVector{T}, Nothing}=similar(A, 0)) where T
1248+
workspace::Union{AbstractVector{T}, Nothing}=similar(A, size(A, dims))) where T
12391249
dim = dims
12401250
order = ord(lt,by,rev,order)
12411251
n = length(axes(A, dim))
@@ -1296,7 +1306,7 @@ function sort!(A::AbstractArray{T};
12961306
by=identity,
12971307
rev::Union{Bool,Nothing}=nothing,
12981308
order::Ordering=Forward,
1299-
workspace::Union{AbstractVector{T}, Nothing}=nothing) where T
1309+
workspace::Union{AbstractVector{T}, Nothing}=similar(A, size(A, dims))) where T
13001310
ordr = ord(lt, by, rev, order)
13011311
nd = ndims(A)
13021312
k = dims
@@ -1523,8 +1533,8 @@ issignleft(o::ForwardOrdering, x::Floats) = lt(o, x, zero(x))
15231533
issignleft(o::ReverseOrdering, x::Floats) = lt(o, x, -zero(x))
15241534
issignleft(o::Perm, i::Integer) = issignleft(o.order, o.data[i])
15251535

1526-
function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering,
1527-
t::Union{AbstractVector, Nothing}=nothing)
1536+
function fpsort!(v::AbstractVector{T}, a::Algorithm, o::Ordering,
1537+
t::Union{AbstractVector{T}, Nothing}=nothing) where T
15281538
# fpsort!'s optimizations speed up comparisons, of which there are O(nlogn).
15291539
# The overhead is O(n). For n < 10, it's not worth it.
15301540
length(v) < 10 && return sort!(v, firstindex(v), lastindex(v), SMALL_ALGORITHM, o, t)
@@ -1550,8 +1560,8 @@ function sort!(v::FPSortable, a::Algorithm, o::DirectOrdering,
15501560
t::Union{FPSortable, Nothing}=nothing)
15511561
fpsort!(v, a, o, t)
15521562
end
1553-
function sort!(v::AbstractVector{<:Union{Signed, Unsigned}}, a::Algorithm,
1554-
o::Perm{<:DirectOrdering,<:FPSortable}, t::Union{AbstractVector, Nothing}=nothing)
1563+
function sort!(v::AbstractVector{T}, a::Algorithm, o::Perm{<:DirectOrdering,<:FPSortable},
1564+
t::Union{AbstractVector{T}, Nothing}=nothing) where T <: Union{Signed, Unsigned}
15551565
fpsort!(v, a, o, t)
15561566
end
15571567

test/sorting.jl

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -513,6 +513,16 @@ end
513513
@test issorted(a)
514514
end
515515

516+
@testset "sort!(::OffsetVector)" begin
517+
for length in vcat(0:5, [10, 300, 500, 1000])
518+
for offset in [-100000, -10, -1, 0, 1, 17, 1729]
519+
x = OffsetVector(rand(length), offset)
520+
sort!(x)
521+
@test issorted(x)
522+
end
523+
end
524+
end
525+
516526
@testset "sort!(::OffsetMatrix; dims)" begin
517527
x = OffsetMatrix(rand(5,5), 5, -5)
518528
sort!(x; dims=1)
@@ -654,17 +664,6 @@ end
654664
end
655665
end
656666

657-
@testset "workspace()" begin
658-
for v in [[1, 2, 3], [0.0]]
659-
for t0 in vcat([nothing], [similar(v,i) for i in 1:5]), len in 0:5
660-
t = Base.Sort.workspace(v, t0, len)
661-
@test eltype(t) == eltype(v)
662-
@test length(t) >= len
663-
@test firstindex(t) == 1
664-
end
665-
end
666-
end
667-
668667
@testset "sort(x; workspace=w) " begin
669668
for n in [1,10,100,1000]
670669
v = rand(n)
@@ -681,7 +680,7 @@ end
681680
end
682681
end
683682

684-
683+
# This testset is at the end of the file because it is slow.
685684
@testset "searchsorted" begin
686685
numTypes = [ Int8, Int16, Int32, Int64, Int128,
687686
UInt8, UInt16, UInt32, UInt64, UInt128,
@@ -842,5 +841,6 @@ end
842841
end
843842
end
844843
end
844+
# The "searchsorted" testset is at the end of the file because it is slow.
845845

846846
end

0 commit comments

Comments
 (0)