Skip to content

Commit 037f3fb

Browse files
committed
Reductions preserve the AxisArray wrapper (fixes #55)
1 parent 65263db commit 037f3fb

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

src/core.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,55 @@ Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S
281281
end
282282
end
283283

284+
# These methods allow us to preserve the AxisArray under reductions
285+
# Note that we only extend the following two methods, and then have it
286+
# dispatch to package-local `reduced_indices` and `reduced_indices0`
287+
# methods. This avoids a whole slew of ambiguities.
288+
Base.reduced_indices(A::AxisArray, region) = reduced_indices(axes(A), region)
289+
Base.reduced_indices0(A::AxisArray, region) = reduced_indices0(axes(A), region)
290+
291+
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
292+
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, ::Tuple{}) = axs
293+
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
294+
reduced_indices(axs, (region,))
295+
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Integer) =
296+
reduced_indices0(axs, (region,))
297+
298+
reduced_indices{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
299+
ntuple(d->dregion ? reduced_axis(axs[d]) : axs[d], Val{N})
300+
reduced_indices0{N}(axs::Tuple{Vararg{Axis,N}}, region::Dims) =
301+
ntuple(d->dregion ? reduced_axis0(axs[d]) : axs[d], Val{N})
302+
303+
@inline reduced_indices{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
304+
_reduced_indices(reduced_axis, (), region, axs...)
305+
@inline reduced_indices(axs::Tuple{Vararg{Axis}}, region::Axis) =
306+
_reduced_indices(reduced_axis, (), region, axs...)
307+
@inline reduced_indices0{Ax<:Axis}(axs::Tuple{Vararg{Axis}}, region::Type{Ax}) =
308+
_reduced_indices(reduced_axis0, (), region, axs...)
309+
@inline reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Axis) =
310+
_reduced_indices(reduced_axis0, (), region, axs...)
311+
312+
reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{DataType}}) =
313+
reduced_indices(reduced_indices(axs, region[1]), tail(region))
314+
reduced_indices(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
315+
reduced_indices(reduced_indices(axs, region[1]), tail(region))
316+
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{DataType}}) =
317+
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))
318+
reduced_indices0(axs::Tuple{Vararg{Axis}}, region::Tuple{Vararg{Axis}}) =
319+
reduced_indices0(reduced_indices0(axs, region[1]), tail(region))
320+
321+
@inline _reduced_indices{name}(f, out, chosen::Type{Axis{name}}, ax::Axis{name}, axs...) =
322+
_reduced_indices(f, (out..., f(ax)), chosen, axs...)
323+
@inline _reduced_indices{name}(f, out, chosen::Axis{name}, ax::Axis{name}, axs...) =
324+
_reduced_indices(f, (out..., f(ax)), chosen, axs...)
325+
@inline _reduced_indices(f, out, chosen, ax::Axis, axs...) =
326+
_reduced_indices(f, (out..., ax), chosen, axs...)
327+
_reduced_indices(f, out, chosen) = out
328+
329+
reduced_axis(ax) = ax(oftype(ax.val, Base.OneTo(1)))
330+
reduced_axis0(ax) = ax(oftype(ax.val, length(ax.val) == 0 ? Base.OneTo(0) : Base.OneTo(1)))
331+
332+
284333
function Base.permutedims(A::AxisArray, perm)
285334
p = permutation(perm, axisnames(A))
286335
AxisArray(permutedims(A.data, p), axes(A)[[p...]])

test/core.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,47 @@ A[0] = 12
192192
A = AxisArray(OffsetArrays.OffsetArray(rand(4,5), -1:2, 5:9), :x, :y)
193193
@test indices(A) == (-1:2, 5:9)
194194
@test linearindices(A) == 1:20
195+
196+
# Reductions (issue #55)
197+
A = AxisArray(collect(reshape(1:15,3,5)), :y, :x)
198+
B = AxisArray(collect(reshape(1:15,3,5)), Axis{:y}(0.1:0.1:0.3), Axis{:x}(10:10:50))
199+
for C in (A, B)
200+
for op in (sum, minimum) # together, cover both reduced_indices and reduced_indices0
201+
axv = axisvalues(C)
202+
# C1 = @inferred(sum(C, 1))
203+
C1 = op(C, 1)
204+
@test typeof(C1) == typeof(C)
205+
@test axisnames(C1) == (:y,:x)
206+
@test axisvalues(C1) === (oftype(axv[1], Base.OneTo(1)), axv[2])
207+
C2 = op(C, 2)
208+
@test typeof(C2) == typeof(C)
209+
@test axisnames(C2) == (:y,:x)
210+
@test axisvalues(C2) === (axv[1], oftype(axv[2], Base.OneTo(1)))
211+
# C12 = @inferred(sum(C, (1,2)))
212+
C12 = op(C, (1,2))
213+
@test typeof(C12) == typeof(C)
214+
@test axisnames(C12) == (:y,:x)
215+
@test axisvalues(C12) === (oftype(axv[1], Base.OneTo(1)), oftype(axv[2], Base.OneTo(1)))
216+
if op == sum
217+
@test C1 == [6 15 24 33 42]
218+
@test C2 == reshape([35,40,45], 3, 1)
219+
@test C12 == reshape([120], 1, 1)
220+
else
221+
@test C1 == [1 4 7 10 13]
222+
@test C2 == reshape([1,2,3], 3, 1)
223+
@test C12 == reshape([1], 1, 1)
224+
end
225+
C1t = @inferred(op(C, Axis{:y}))
226+
@test C1t == C1
227+
C2t = @inferred(op(C, Axis{:x}))
228+
@test C2t == C2
229+
C12t = @inferred(op(C, (Axis{:y},Axis{:x})))
230+
@test C12t == C12
231+
C1t = @inferred(op(C, Axis{:y}()))
232+
@test C1t == C1
233+
C2t = @inferred(op(C, Axis{:x}()))
234+
@test C2t == C2
235+
C12t = @inferred(op(C, (Axis{:y}(),Axis{:x}())))
236+
@test C12t == C12
237+
end
238+
end

0 commit comments

Comments
 (0)