Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ end

# For AbstractVector{Bool}, counting sort is always best.
# This is an implementation of counting sort specialized for Bools.
# Accepts unused buffer to avoid method ambiguity.
# Accepts unused scratch space to avoid method ambiguity.
function sort!(v::AbstractVector{Bool}, lo::Integer, hi::Integer, ::AdaptiveSortAlg, o::Ordering,
t::Union{AbstractVector{Bool}, Nothing}=nothing)
first = lt(o, false, true) ? false : lt(o, true, false) ? true : return v
Expand Down Expand Up @@ -856,15 +856,15 @@ function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, ::AdaptiveSortAlg
end

len = lenm1 + 1
if t !== nothing && checkbounds(Bool, t, lo:hi) # Fully preallocated and aligned buffer
if t !== nothing && checkbounds(Bool, t, lo:hi) # Fully preallocated and aligned scratch space
u2 = radix_sort!(u, lo, hi, bits, reinterpret(U, t))
uint_unmap!(v, u2, lo, hi, o, u_min)
elseif t !== nothing && (applicable(resize!, t, len) || length(t) >= len) # Viable buffer
elseif t !== nothing && (applicable(resize!, t, len) || length(t) >= len) # Viable scratch space
length(t) >= len || resize!(t, len)
t1 = axes(t, 1) isa OneTo ? t : view(t, firstindex(t):lastindex(t))
u2 = radix_sort!(view(u, lo:hi), 1, len, bits, reinterpret(U, t1))
uint_unmap!(view(v, lo:hi), u2, 1, len, o, u_min)
else # No viable buffer
else # No viable scratch space
u2 = radix_sort!(u, lo, hi, bits, similar(u))
uint_unmap!(v, u2, lo, hi, o, u_min)
end
Expand Down Expand Up @@ -930,8 +930,8 @@ function sort!(v::AbstractVector{T};
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
buffer::Union{AbstractVector{T}, Nothing}=nothing) where T
sort!(v, alg, ord(lt,by,rev,order), buffer)
scratch::Union{AbstractVector{T}, Nothing}=nothing) where T
sort!(v, alg, ord(lt,by,rev,order), scratch)
end

# sort! for vectors of few unique integers
Expand Down Expand Up @@ -1070,7 +1070,7 @@ function partialsortperm!(ix::AbstractVector{<:Integer}, v::AbstractVector,
order::Ordering=Forward,
initialized::Bool=false)
if axes(ix,1) != axes(v,1)
throw(ArgumentError("The index vector is used as a buffer and must have the " *
throw(ArgumentError("The index vector is used as scratch space and must have the " *
"same length/indices as the source vector, $(axes(ix,1)) != $(axes(v,1))"))
end
if !initialized
Expand Down Expand Up @@ -1137,7 +1137,7 @@ function sortperm(A::AbstractArray;
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
buffer::Union{AbstractVector{<:Integer}, Nothing}=nothing,
scratch::Union{AbstractVector{<:Integer}, Nothing}=nothing,
dims...) #to optionally specify dims argument
ordr = ord(lt,by,rev,order)
if ordr === Forward && isa(A,Vector) && eltype(A)<:Integer
Expand All @@ -1152,7 +1152,7 @@ function sortperm(A::AbstractArray;
end
end
ix = copymutable(LinearIndices(A))
sort!(ix; alg, order = Perm(ordr, vec(A)), buffer, dims...)
sort!(ix; alg, order = Perm(ordr, vec(A)), scratch, dims...)
end


Expand Down Expand Up @@ -1198,15 +1198,15 @@ function sortperm!(ix::AbstractArray{T}, A::AbstractArray;
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
initialized::Bool=false,
buffer::Union{AbstractVector{T}, Nothing}=nothing,
scratch::Union{AbstractVector{T}, Nothing}=nothing,
dims...) where T <: Integer #to optionally specify dims argument
(typeof(A) <: AbstractVector) == (:dims in keys(dims)) && throw(ArgumentError("Dims argument incorrect for type $(typeof(A))"))
axes(ix) == axes(A) || throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))"))

if !initialized
ix .= LinearIndices(A)
end
sort!(ix; alg, order = Perm(ord(lt, by, rev, order), vec(A)), buffer, dims...)
sort!(ix; alg, order = Perm(ord(lt, by, rev, order), vec(A)), scratch, dims...)
end

# sortperm for vectors of few unique integers
Expand Down Expand Up @@ -1271,19 +1271,19 @@ function sort(A::AbstractArray{T};
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
buffer::Union{AbstractVector{T}, Nothing}=similar(A, size(A, dims))) where T
scratch::Union{AbstractVector{T}, Nothing}=similar(A, size(A, dims))) where T
dim = dims
order = ord(lt,by,rev,order)
n = length(axes(A, dim))
if dim != 1
pdims = (dim, setdiff(1:ndims(A), dim)...) # put the selected dimension first
Ap = permutedims(A, pdims)
Av = vec(Ap)
sort_chunks!(Av, n, alg, order, buffer)
sort_chunks!(Av, n, alg, order, scratch)
permutedims(Ap, invperm(pdims))
else
Av = A[:]
sort_chunks!(Av, n, alg, order, buffer)
sort_chunks!(Av, n, alg, order, scratch)
reshape(Av, axes(A))
end
end
Expand Down Expand Up @@ -1332,21 +1332,21 @@ function sort!(A::AbstractArray{T};
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
buffer::Union{AbstractVector{T}, Nothing}=similar(A, size(A, dims))) where T
_sort!(A, Val(dims), alg, ord(lt, by, rev, order), buffer)
scratch::Union{AbstractVector{T}, Nothing}=similar(A, size(A, dims))) where T
_sort!(A, Val(dims), alg, ord(lt, by, rev, order), scratch)
end
function _sort!(A::AbstractArray{T}, ::Val{K},
alg::Algorithm,
order::Ordering,
buffer::Union{AbstractVector{T}, Nothing}) where {K,T}
scratch::Union{AbstractVector{T}, Nothing}) where {K,T}
nd = ndims(A)

1 <= K <= nd || throw(ArgumentError("dimension out of range"))

remdims = ntuple(i -> i == K ? 1 : axes(A, i), nd)
for idx in CartesianIndices(remdims)
Av = view(A, ntuple(i -> i == K ? Colon() : idx[i], nd)...)
sort!(Av, alg, order, buffer)
sort!(Av, alg, order, scratch)
end
A
end
Expand Down
16 changes: 8 additions & 8 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -740,19 +740,19 @@ end
end

# This testset is at the end of the file because it is slow
@testset "sort(x; buffer)" begin
@testset "sort(x; scratch)" begin
for n in [1,10,100,1000]
v = rand(n)
buffer = [0.0]
@test sort(v) == sort(v; buffer)
@test sort!(copy(v)) == sort!(copy(v); buffer)
@test sortperm(v) == sortperm(v; buffer=[4])
@test sortperm!(Vector{Int}(undef, n), v) == sortperm!(Vector{Int}(undef, n), v; buffer=[4])
scratch = [0.0]
@test sort(v) == sort(v; scratch)
@test sort!(copy(v)) == sort!(copy(v); scratch)
@test sortperm(v) == sortperm(v; scratch=[4])
@test sortperm!(Vector{Int}(undef, n), v) == sortperm!(Vector{Int}(undef, n), v; scratch=[4])

n > 100 && continue
M = rand(n, n)
@test sort(M; dims=2) == sort(M; dims=2, buffer)
@test sort!(copy(M); dims=1) == sort!(copy(M); dims=1, buffer)
@test sort(M; dims=2) == sort(M; dims=2, scratch)
@test sort!(copy(M); dims=1) == sort!(copy(M); dims=1, scratch)
end
end

Expand Down