@@ -281,6 +281,55 @@ Base.similar{S}(A::AxisArray, ::Type{S}, ax1::Axis, axs::Axis...) = similar(A, S
281281 end
282282end
283283
284+ # These methods allow us to preserve the AxisArray under reductions
285+ # Note that we only extend the following two methods, and then have it
286+ # dispatch to package-local `reduced_indices` and `reduced_indices0`
287+ # methods. This avoids a whole slew of ambiguities.
288+ Base. reduced_indices (A:: AxisArray , region) = reduced_indices (axes (A), region)
289+ Base. reduced_indices0 (A:: AxisArray , region) = reduced_indices0 (axes (A), region)
290+
291+ reduced_indices {N} (axs:: Tuple{Vararg{Axis,N}} , :: Tuple{} ) = axs
292+ reduced_indices0 {N} (axs:: Tuple{Vararg{Axis,N}} , :: Tuple{} ) = axs
293+ reduced_indices {N} (axs:: Tuple{Vararg{Axis,N}} , region:: Integer ) =
294+ reduced_indices (axs, (region,))
295+ reduced_indices0 {N} (axs:: Tuple{Vararg{Axis,N}} , region:: Integer ) =
296+ reduced_indices0 (axs, (region,))
297+
298+ reduced_indices {N} (axs:: Tuple{Vararg{Axis,N}} , region:: Dims ) =
299+ ntuple (d-> d∈ region ? reduced_axis (axs[d]) : axs[d], Val{N})
300+ reduced_indices0 {N} (axs:: Tuple{Vararg{Axis,N}} , region:: Dims ) =
301+ ntuple (d-> d∈ region ? reduced_axis0 (axs[d]) : axs[d], Val{N})
302+
303+ @inline reduced_indices {Ax<:Axis} (axs:: Tuple{Vararg{Axis}} , region:: Type{Ax} ) =
304+ _reduced_indices (reduced_axis, (), region, axs... )
305+ @inline reduced_indices (axs:: Tuple{Vararg{Axis}} , region:: Axis ) =
306+ _reduced_indices (reduced_axis, (), region, axs... )
307+ @inline reduced_indices0 {Ax<:Axis} (axs:: Tuple{Vararg{Axis}} , region:: Type{Ax} ) =
308+ _reduced_indices (reduced_axis0, (), region, axs... )
309+ @inline reduced_indices0 (axs:: Tuple{Vararg{Axis}} , region:: Axis ) =
310+ _reduced_indices (reduced_axis0, (), region, axs... )
311+
312+ reduced_indices (axs:: Tuple{Vararg{Axis}} , region:: Tuple{Vararg{DataType}} ) =
313+ reduced_indices (reduced_indices (axs, region[1 ]), tail (region))
314+ reduced_indices (axs:: Tuple{Vararg{Axis}} , region:: Tuple{Vararg{Axis}} ) =
315+ reduced_indices (reduced_indices (axs, region[1 ]), tail (region))
316+ reduced_indices0 (axs:: Tuple{Vararg{Axis}} , region:: Tuple{Vararg{DataType}} ) =
317+ reduced_indices0 (reduced_indices0 (axs, region[1 ]), tail (region))
318+ reduced_indices0 (axs:: Tuple{Vararg{Axis}} , region:: Tuple{Vararg{Axis}} ) =
319+ reduced_indices0 (reduced_indices0 (axs, region[1 ]), tail (region))
320+
321+ @inline _reduced_indices {name} (f, out, chosen:: Type{Axis{name}} , ax:: Axis{name} , axs... ) =
322+ _reduced_indices (f, (out... , f (ax)), chosen, axs... )
323+ @inline _reduced_indices {name} (f, out, chosen:: Axis{name} , ax:: Axis{name} , axs... ) =
324+ _reduced_indices (f, (out... , f (ax)), chosen, axs... )
325+ @inline _reduced_indices (f, out, chosen, ax:: Axis , axs... ) =
326+ _reduced_indices (f, (out... , ax), chosen, axs... )
327+ _reduced_indices (f, out, chosen) = out
328+
329+ reduced_axis (ax) = ax (oftype (ax. val, Base. OneTo (1 )))
330+ reduced_axis0 (ax) = ax (oftype (ax. val, length (ax. val) == 0 ? Base. OneTo (0 ) : Base. OneTo (1 )))
331+
332+
284333function Base. permutedims (A:: AxisArray , perm)
285334 p = permutation (perm, axisnames (A))
286335 AxisArray (permutedims (A. data, p), axes (A)[[p... ]])
0 commit comments