Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
julia 0.5
IntervalSets
IterTools
RangeArrays
Compat 0.19
5 changes: 4 additions & 1 deletion src/AxisArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ __precompile__()
module AxisArrays

using Base: tail
import Base.Iterators: repeated
using RangeArrays, IntervalSets
using IterTools
using Compat

export AxisArray, Axis, axisnames, axisvalues, axisdim, axes, atindex, atvalue
export AxisArray, Axis, axisnames, axisvalues, axisdim, axes, atindex, atvalue, flatten

# From IntervalSets:
export ClosedInterval, ..
Expand All @@ -17,6 +19,7 @@ include("intervals.jl")
include("search.jl")
include("indexing.jl")
include("sortedvector.jl")
include("categoricalvector.jl")
include("combine.jl")

end
85 changes: 85 additions & 0 deletions src/categoricalvector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@

export CategoricalVector

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use a different name for this, as CategoricalVector is already used in CategoricalArrays, which replaced PooledDataArray in DataTables (and soon in DataFrames). Why not CategoricalAxis?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not an Axis, and it mirrors the SortedVector type.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, IIUC its only purpose is to treat it as a categorical axis, isn't it? Ideas about other possible names? It would be too bad to have conflicts when loading both AxisArrays and DataTables/DataFrames.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could choose not to export it?

There aren't conflicts when you use both packages unless you also use CategoricalVector.

The nomenclature used within AxisArrays is Categorical, which is how CategoricalVector came up.

How about DiscreteVector?


"""
A `CategoricalVector` is an `AbstractVector` which is treated as a categorical axis regardless

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing backticks around types in this docstring.

of the element type. Duplicate values are not allowed but are not filtered out.
A `CategoricalVector` axis can be indexed with a `ClosedInterval`, with a value, or with a
vector of values. Use of a `CategoricalVector{Tuple}` axis allows indexing similar to the
hierarchical index of the Python Pandas package or the R data.table package.
In general, indexing into a `CategoricalVector` will be much slower than the corresponding
`SortedVector` or another sorted axis type, as linear search is required.
### Constructors
```julia
CategoricalVector(x::AbstractVector)
```
### Arguments
* `x::AbstractVector` : the wrapped vector
### Examples
```julia
v = CategoricalVector(collect([1; 8; 10:15]))
A = AxisArray(reshape(1:16, 8, 2), v, [:a, :b])
A[Axis{:row}(1), :]
A[Axis{:row}(10), :]
A[Axis{:row}([1, 10]), :]
## Hierarchical index example with three key levels
data = reshape(1.:40., 20, 2)
v = collect(zip([:a, :b, :c][rand(1:3,20)], [:x,:y][rand(1:2,20)], [:x,:y][rand(1:2,20)]))
A = AxisArray(data, CategoricalVector(v), [:a, :b])
A[:b, :]
A[[:a,:c], :]
A[(:a,:x), :]
A[(:a,:x,:x), :]
```
"""
immutable CategoricalVector{T, A<:AbstractVector{T}} <: AbstractVector{T}
data::A  # the wrapped vector; `A` records its concrete type, `T` its element type
end

# Outer constructor: both type parameters are taken from the wrapped vector.
CategoricalVector(data::AbstractVector{T}) where {T} =
    CategoricalVector{T, typeof(data)}(data)

function Base.getindex(v::CategoricalVector, idx::Int)
    # Scalar lookup hands back the stored element itself.
    return v.data[idx]
end

function Base.getindex(v::CategoricalVector, idx::AbstractVector)
    # Vector lookup re-wraps the selection so the result is again a CategoricalVector.
    selected = v.data[idx]
    return CategoricalVector(selected)
end

# Shape queries are all answered by the wrapped vector.
function Base.length(v::CategoricalVector)
    return length(v.data)
end

function Base.size(v::CategoricalVector)
    return size(v.data)
end

function Base.size(v::CategoricalVector, i)
    return size(v.data, i)
end

function Base.indices(v::CategoricalVector)
    return indices(v.data)
end

function axistrait(::Type{CategoricalVector{T,A}}) where {T,A}
    # Categorical for every element type.
    return Categorical
end

function checkaxis(::CategoricalVector)
    # No validation is performed on the wrapped values.
    return nothing
end


## Add some special indexing for CategoricalVector{Tuple}s to achieve something like
## Pandas' hierarchical indexing

function axisindexes(ax::Axis{S,CategoricalVector{T,A}}, idx) where {T<:Tuple,S,A}
    # Any other index is wrapped in a one-element tuple and handled by the
    # tuple method.
    return axisindexes(ax, (idx,))
end

function axisindexes(ax::Axis{S,CategoricalVector{T,A}}, idx::Tuple) where {T<:Tuple,S,A}
    # Linear scan: keep every position whose stored tuple matches `idx`.
    is_match(position) = _tuple_matches(ax.val[position], idx)
    positions = first(indices(ax.val))
    return collect(filter(is_match, positions))
end

# Return `true` when `idx` is a prefix of `element`: `idx` is no longer than
# `element` and agrees with it position by position.
function _tuple_matches(element::Tuple, idx::Tuple)
    if length(idx) > length(element)
        return false
    end

    for position in 1:length(idx)
        if !(element[position] == idx[position])
            return false
        end
    end

    return true
end

function axisindexes(ax::Axis{S,CategoricalVector{T,A}}, idx::AbstractArray) where {T<:Tuple,S,A}
    # Look up each key separately, then concatenate the hits in key order.
    per_key = [axisindexes(ax, key) for key in idx]
    return vcat(per_key...)
end
120 changes: 120 additions & 0 deletions src/combine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,123 @@ function Base.join{T,N,D,Ax}(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
return result

end #join

# Lazily produce one flat-axis key per combination of the given axes' values,
# each key prefixed with `array_name`.
function _flatten_array_axes(array_name, array_axes...)
    axis_values = (ax.val for ax in array_axes)
    as_tuple(key) = key isa Tuple ? key : (key,)
    return ((array_name, as_tuple(key)...) for key in product(axis_values...))
end

# Collect the flat-axis keys of every array, in array order, into one vector.
function _flatten_axes(array_names, array_axes)
    per_array = map(array_names, array_axes) do name, axs
        _flatten_array_axes(name, axs...)
    end

    return collect(Iterators.flatten(per_array))
end

# Apply `Base.IteratorsMD.split(A, Val{N})` to every argument and return the
# results as a tuple.
function _splitall(::Type{Val{N}}, As...) where {N}
    return map(A -> Base.IteratorsMD.split(A, Val{N}), As)
end

# Apply `reshape(A, Val{N})` to every argument and return the results as a tuple.
function _reshapeall(::Type{Val{N}}, As...) where {N}
    return map(A -> reshape(A, Val{N}), As)
end

# Throw an `ArgumentError` unless every entry of `common_axis_tuple` has the same
# axis name as the first entry; return `nothing` otherwise.
function _check_common_axes(common_axis_tuple)
    reference_name = axisname(first(common_axis_tuple))
    other_names = axisname.(common_axis_tuple[2:end])

    all(reference_name .=== other_names) ||
        throw(ArgumentError("Leading common axes must have the same name in each array"))

    return nothing
end

# Build, for each array, the tuple type `Tuple{LType, <eltype of each trailing axis>...}`
# and return the `typejoin` of those tuple types.
function _flat_axis_eltype(LType, trailing_axes)
    key_types = map(axs -> Tuple{LType, eltype.(axs)...}, trailing_axes)
    return typejoin(key_types...)
end

# Convenience method: when no labels are given, the arrays are labelled by
# position (1, 2, …, NA).
function flatten(::Type{Val{N}}, As::Vararg{AxisArray, NA}) where {N, NA}
    default_labels = ntuple(identity, Val{NA})
    return flatten(Val{N}, default_labels, As...)
end

"""
flatten(As::AxisArray...) -> AxisArray
flatten(last_dim::Type{Val{N}}, As::AxisArray...) -> AxisArray
flatten(last_dim::Type{Val{N}}, labels::Tuple, As::AxisArray...) -> AxisArray

Concatenates `AxisArray`s with `N` equal leading axes into a single `AxisArray`.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backquotes around type names (here and elsewhere).

All additional axes in any of the arrays are flattened into a single additional
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

collapsed

`CategoricalVector{Tuple}` axis.

### Arguments

* `::Type{Val{N}}`: the greatest common dimension to share between all input
arrays. The remaining axes are flattened. All `N` axes must be common
to each input array, at the same dimension. Values from 0 up to the
minimum number of dimensions across all input arrays are allowed.
* `labels::Tuple`: (optional) a label for each `AxisArray` in `As` which is used in the flat
axis
* `As::AxisArray...`: `AxisArray`s to be flattened together.
"""
@generated function flatten{N, AN, LType}(::Type{Val{N}}, labels::NTuple{AN, LType}, As::Vararg{AxisArray, AN})
    # Generator phase: inside an `@generated` function `As` holds the argument
    # *types*, so everything up to the `quote` works on type information only.
    if N < 0
        throw(ArgumentError("flatten dimension N must be at least 0"))
    end

    # N is bounded by the array with the fewest dimensions, hence `minimum`.
    if N > minimum(ndims.(As))
        throw(ArgumentError(
            """
            flatten dimension N must not be greater than the minimum number of dimensions
            across all input arrays
            """
        ))
    end

    flat_dim = Val{N + 1}
    flat_dim_int = Int(N) + 1

    # Split each array's axis types into the N leading (common) ones and the rest.
    common_axes, trailing_axes = zip(_splitall(Val{N}, axisparams.(As)...)...)

    # Leading axes must carry the same name in every array (checked on types).
    foreach(_check_common_axes, zip(common_axes...))

    new_common_axes = first(common_axes)
    flat_axis_eltype = _flat_axis_eltype(LType, trailing_axes)
    flat_axis_type = CategoricalVector{flat_axis_eltype, Vector{flat_axis_eltype}}

    # NOTE(review): a reviewer asked whether this axis should be named `:collapsed`
    # rather than `:flat`.
    new_axes_type = Tuple{new_common_axes..., Axis{:flat, flat_axis_type}}
    new_eltype = Base.promote_eltype(As...)

    quote
        common_axes, trailing_axes = zip(_splitall(Val{N}, axes.(As)...)...)

        # Runtime check: axis *values* are not part of the types, so the leading
        # axes of every array are compared against those of the first array here.
        for common_axis_tuple in zip(common_axes...)
            if !isempty(common_axis_tuple)
                for common_axis in common_axis_tuple[2:end]
                    if !all(axisvalues(common_axis) .== axisvalues(common_axis_tuple[1]))
                        throw(ArgumentError(
                            """
                            Leading common axes must be identical across
                            all input arrays"""
                        ))
                    end
                end
            end
        end

        # Reshape every array to N + 1 dimensions and concatenate along the last one.
        array_data = cat($flat_dim, _reshapeall($flat_dim, As...)...)

        axis_array_type = AxisArray{
            $new_eltype,
            $flat_dim_int,
            Array{$new_eltype, $flat_dim_int},
            $new_axes_type
        }

        new_axes = (
            first(common_axes)...,
            Axis{:flat, $flat_axis_type}($flat_axis_type(_flatten_axes(labels, trailing_axes))),
        )

        return axis_array_type(array_data, new_axes)
    end
end
23 changes: 18 additions & 5 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,15 @@ end
function axes(A::AbstractArray)
    # Arrays reaching this method get whatever `default_axes` returns for them.
    return default_axes(A)
end

function axes(A::AbstractArray, dim::Int)
    # Same lookup, then select the entry for dimension `dim`.
    return default_axes(A)[dim]
end

"""
axisparams(::AxisArray) -> Vararg{::Type{Axis}}
axisparams(::Type{AxisArray}) -> Vararg{::Type{Axis}}

Returns the axis parameters for an AxisArray.
"""
axisparams(::AxisArray{T,N,D,Ax}) where {T,N,D,Ax} = tuple(Ax.parameters...)
# Same result when called with the array type instead of an instance.
axisparams(::Type{AxisArray{T,N,D,Ax}}) where {T,N,D,Ax} = tuple(Ax.parameters...)

### Axis traits ###
# Abstract supertype of the axis trait types.
@compat abstract type AxisTrait end
# Field-less trait type; a subtype of `AxisTrait`.
immutable Dimensional <: AxisTrait end
Expand All @@ -516,6 +525,7 @@ immutable Unsupported <: AxisTrait end

"""
axistrait(ax::Axis) -> Type{<:AxisTrait}
axistrait{T}(::Type{T}) -> Type{<:AxisTrait}

Returns the indexing type of an `Axis`, any subtype of `AxisTrait`.
The default is `Unsupported`, meaning there is no special indexing behaviour for this axis
Expand All @@ -528,13 +538,16 @@ User-defined axis types can be added along with custom indexing behaviors by def
methods of this function. Here is the example of adding a custom Dimensional axis:

```julia
AxisArrays.axistrait(v::MyCustomAxis) = AxisArrays.Dimensional
AxisArrays.axistrait(::Type{MyCustomAxis}) = AxisArrays.Dimensional
```
"""
axistrait(::Any) = Unsupported
axistrait(ax::Axis) = axistrait(ax.val)
axistrait{T<:Union{Number, Dates.AbstractTime}}(::AbstractVector{T}) = Dimensional
axistrait{T<:Union{Symbol, AbstractString}}(::AbstractVector{T}) = Categorical
axistrait(::T) where {T} = axistrait(T)  # an instance is classified by its type
# Fallback for types that have no more specific method.
axistrait(::Type{T}) where {T} = Unsupported
# An `Axis` type takes the trait of its value type.
axistrait(::Type{Axis{name, T}}) where {name, T} = axistrait(T)
# Vector types are classified by their element type.
axistrait(::Type{T}) where {T<:AbstractVector} = _axistrait_el(eltype(T))
_axistrait_el(::Type{T}) where {T<:Union{Number, Dates.AbstractTime}} = Dimensional
_axistrait_el(::Type{T}) where {T<:Union{Symbol, AbstractString}} = Categorical
_axistrait_el(::Type{T}) where {T} = Unsupported

function checkaxis(ax::Axis)
    # An `Axis` is checked through the values it wraps.
    return checkaxis(ax.val)
end

function checkaxis(ax)
    # Everything else is routed to the two-argument method selected by its trait.
    return checkaxis(axistrait(ax), ax)
end
Expand Down
17 changes: 16 additions & 1 deletion src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,17 @@ end
ex = Expr(:tuple)
n = 0
for i=1:length(I)
if axistrait(I[i]) <: Categorical && i <= length(Ax.parameters)
if I[i] <: Axis
push!(ex.args, :(axisindexes(A.axes[$i], I[$i].val)))
else
push!(ex.args, :(axisindexes(A.axes[$i], I[$i])))
end
n += 1

continue
end

if I[i] <: Idx
push!(ex.args, :(I[$i]))
n += 1
Expand All @@ -273,7 +284,11 @@ end
end
n += length(I[i])
elseif i <= length(Ax.parameters)
push!(ex.args, :(axisindexes(A.axes[$i], I[$i])))
if I[i] <: Axis
push!(ex.args, :(axisindexes(A.axes[$i], I[$i].val)))
else
push!(ex.args, :(axisindexes(A.axes[$i], I[$i])))
end
n += 1
else
push!(ex.args, :(error("dimension ", $i, " does not have an axis to index")))
Expand Down
2 changes: 1 addition & 1 deletion src/sortedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Base.size(v::SortedVector) = size(v.data)
Base.size(v::SortedVector, i) = size(v.data, i)
Base.indices(v::SortedVector) = indices(v.data)

axistrait(::SortedVector) = Dimensional
axistrait(::Type{SortedVector{T}}) where {T} = Dimensional

function checkaxis(::SortedVector)
    # No validation is performed here.
    return nothing
end


Expand Down
21 changes: 21 additions & 0 deletions test/categoricalvector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Test CategoricalVector with a hierarchical index (indexed using Tuples)
# Fix the RNG seed: the row range asserted below is specific to this draw.
srand(1234)
data = reshape(1.:40., 20, 2)
# Twenty random three-level keys, one per row of `data`.
v = collect(zip([:a, :b, :c][rand(1:3,20)], [:x,:y][rand(1:2,20)], [:x,:y][rand(1:2,20)]))
# Sort keys and rows together so that rows with equal keys are contiguous.
idx = sortperm(v)
A = AxisArray(data[idx,:], CategoricalVector(v[idx]), [:a, :b])
# A bare value is matched against the first element of each key tuple.
@test A[:b, :] == A[5:12, :]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume these reflect the actual random numbers produced due to srand(1234)? Would it be clearer to just hardcode v?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would; I was just following test/sortedvector.jl as much as possible. I can change it though :)

# A vector of keys selects the rows of every listed key.
@test A[[:a,:c], :] == A[[1:4;13:end], :]
# A shorter tuple is matched as a prefix of the key; a full-length tuple matches whole keys.
@test A[(:a,:y), :] == A[2:4, :]
@test A[(:c,:y,:y), :] == A[16:end, :]
@test AxisArrays.axistrait(axes(A)[1]) <: AxisArrays.Categorical

v = CategoricalVector(collect([1; 8; 10:15]))
# Check the trait of the freshly built vector itself: at this point `A` still
# refers to the array from the hierarchical test, which was already checked.
@test AxisArrays.axistrait(v) <: AxisArrays.Categorical
A = AxisArray(reshape(1:16, 8, 2), v, [:a, :b])
# Indexing with a CategoricalVector of values keeps the indexed dimension.
@test A[Axis{:row}(CategoricalVector([15]))] == AxisArray(reshape(A.data[8, :], 1, 2), CategoricalVector([15]), [:a, :b])
@test A[Axis{:row}(CategoricalVector([15])), 1] == AxisArray([A.data[8, 1]], CategoricalVector([15]))
@test AxisArrays.axistrait(axes(A)[1]) <: AxisArrays.Categorical

# TODO: maybe make this work? Would require removing or modifying Base.getindex(A::AxisArray, idxs::Idx...)
# @test A[CategoricalVector([15]), 1] == AxisArray([A.data[8, 1]], CategoricalVector([15]))
23 changes: 23 additions & 0 deletions test/combine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,26 @@ ABdata[3:6,3:6,:,2] = Bdata
@test join(A,B,method=:left) == AxisArray(ABdata[1:4, 1:4, :, :], A.axes...)
@test join(A,B,method=:right) == AxisArray(ABdata[3:6, 3:6, :, :], B.axes...)
@test join(A,B,method=:outer) == join(A,B)

# flatten
# A1 is two-dimensional; A2 carries the same two leading axes plus a length-1 :Z axis.
A1 = AxisArray(A1data, Axis{:X}(1:2), Axis{:Y}(1:2))
A2 = AxisArray(reshape(A2data, size(A2data)..., 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:Z}([:foo]))

# Without labels, each flat-axis key starts with the position (1, 2) of its array.
@test @inferred(flatten(Val{2}, A1, A2)) == AxisArray(cat(3, A1data, A2data), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:flat}(CategoricalVector([(1,), (2, :foo)])))
@test @inferred(flatten(Val{2}, A1)) == AxisArray(reshape(A1, 2, 2, 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:flat}(CategoricalVector([(1,)])))
@test @inferred(flatten(Val{2}, A1)) == AxisArray(reshape(A1.data, size(A1)..., 1), axes(A1)..., Axis{:flat}(CategoricalVector([(1,)])))

# With explicit labels, the label takes the place of the position in each key.
@test @inferred(flatten(Val{2}, (:A1, :A2), A1, A2)) == AxisArray(cat(3, A1data, A2data), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:flat}(CategoricalVector([(:A1,), (:A2, :foo)])))
@test @inferred(flatten(Val{2}, (:foo,), A1)) == AxisArray(reshape(A1, 2, 2, 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:flat}(CategoricalVector([(:foo,)])))
@test @inferred(flatten(Val{2}, (:a,), A1)) == AxisArray(reshape(A1.data, size(A1)..., 1), axes(A1)..., Axis{:flat}(CategoricalVector([(:a,)])))

# Val{0} flattens every axis; Val{1} keeps the first axis and flattens the second.
@test @inferred(flatten(Val{0}, A1)) == AxisArray(vec(A1data), Axis{:flat}(CategoricalVector(collect(IterTools.product((1,), axisvalues(A1)...)))))
@test @inferred(flatten(Val{1}, A1)) == AxisArray(A1data, Axis{:row}(1:2), Axis{:flat}(CategoricalVector(collect(IterTools.product((1,), axisvalues(A1)[2])))))

# N below 0 or above the number of dimensions must be rejected.
@test_throws ArgumentError flatten(Val{-1}, A1)
@test_throws ArgumentError flatten(Val{10}, A1)

# NOTE(review): presumably the Val{1} case throws because the leading axis names
# differ between A1 and its transpose — confirm.
A1ᵀ = transpose(A1)
@test_throws ArgumentError flatten(Val{-1}, A1, A1ᵀ)
@test_throws ArgumentError flatten(Val{1}, A1, A1ᵀ)
@test_throws ArgumentError flatten(Val{10}, A1, A1ᵀ)
10 changes: 8 additions & 2 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,13 @@ A = @inferred(AxisArray(vals, Axis{:Timestamp}(dt-Dates.Hour(2):Dates.Hour(1):dt
@test A[dt, :].data == vals[3, :]

# The trait must be the same whether it is queried with an axis, with the wrapped
# values, or with the type of either one.
@test AxisArrays.axistrait(A.axes[1]) == AxisArrays.Dimensional
@test AxisArrays.axistrait(typeof(A.axes[1])) == AxisArrays.Dimensional
@test AxisArrays.axistrait(A.axes[1].val) == AxisArrays.Dimensional
@test AxisArrays.axistrait(typeof(A.axes[1].val)) == AxisArrays.Dimensional
@test AxisArrays.axistrait(A.axes[2]) == AxisArrays.Categorical
@test AxisArrays.axistrait(typeof(A.axes[2])) == AxisArrays.Categorical
@test AxisArrays.axistrait(A.axes[2].val) == AxisArrays.Categorical
@test AxisArrays.axistrait(typeof(A.axes[2].val)) == AxisArrays.Categorical

@test_throws ArgumentError AxisArrays.checkaxis(Axis{:x}(10:-1:1))
@test_throws ArgumentError AxisArrays.checkaxis(10:-1:1)
Expand Down Expand Up @@ -236,8 +240,10 @@ map!(*, A2, A, A)
# Reductions (issue #55)
A = AxisArray(collect(reshape(1:15,3,5)), :y, :x)
B = @inferred(AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50)))
for C in (A, B)
for op in (sum, minimum) # together, cover both reduced_indices and reduced_indices0
arrays = (A, B)
functions = (sum, minimum)
for C in arrays
for op in functions # together, cover both reduced_indices and reduced_indices0
axv = axisvalues(C)
C1 = @inferred(op(C, 1))
@test typeof(C1) == typeof(C)
Expand Down
Loading