@@ -140,88 +140,122 @@ function Base.join{T,N,D,Ax}(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
140140
141141end # join
142142
143- function greatest_common_axis (As:: AxisArray... )
144- length (As) == 1 && return ndims (first (As))
143+ function _flatten_array_axes (array_name, array_axes... )
144+ ((array_name, (idx isa Tuple ? idx : (idx,)). .. ) for idx in product ((Ax. val for Ax in array_axes). .. ))
145+ end
145146
146- for (i, zip_axes) in enumerate (zip (axes .(As)... ))
147- if ! all (ax -> ax == zip_axes[1 ], zip_axes[2 : end ])
148- return i - 1
149- end
147+ function _flatten_axes (array_names, array_axes)
148+ collect (Iterators. flatten (map (array_names, array_axes) do tup_name, tup_array_axes
149+ _flatten_array_axes (tup_name, tup_array_axes... )
150+ end ))
151+ end
152+
153+ function _splitall {N} (:: Type{Val{N}} , As... )
154+ tuple ((Base. IteratorsMD. split (A, Val{N}) for A in As). .. )
155+ end
156+
157+ function _reshapeall {N} (:: Type{Val{N}} , As... )
158+ tuple ((reshape (A, Val{N}) for A in As). .. )
159+ end
160+
161+ function _check_common_axes (common_axis_tuple)
162+ if ! all (axisname (first (common_axis_tuple)) .=== axisname .(common_axis_tuple[2 : end ]))
163+ throw (ArgumentError (" Leading common axes must have the same name in each array" ))
150164 end
151165
152- return minimum ( map (ndims, As))
166+ return nothing
153167end
154168
155- function flatten_array_axes (array_name, array_axes)
156- map (zip (repeated (array_name), product (map (Ax-> Ax. val, array_axes)... ))) do tup
157- tup_name, tup_idx = tup
158- return (tup_name, tup_idx... )
169+ function _flat_axis_eltype (LType, trailing_axes)
170+ eltypes = map (trailing_axes) do array_trailing_axes
171+ Tuple{LType, eltype .(array_trailing_axes)... }
159172 end
173+
174+ return typejoin (eltypes... )
160175end
161176
162- function flatten_axes (array_names, array_axes )
163- collect ( chain ( map (flatten_array_axes, array_names, array_axes) ... ) )
177+ function flatten {N, NA} ( :: Type{Val{N}} , As :: Vararg{AxisArray, NA} )
178+ flatten (Val{N}, ntuple (identity, Val{NA}), As ... )
164179end
165180
166181"""
167182 flatten(As::AxisArray...) -> AxisArray
168- flatten(last_dim::Integer, As::AxisArray...) -> AxisArray
183+ flatten(last_dim::Type{Val{N}}, As::AxisArray...) -> AxisArray
184+ flatten(last_dim::Type{Val{N}}, labels::Tuple, As::AxisArray...) -> AxisArray
169185
170- Concatenates AxisArrays with equal leading axes into a single AxisArray.
186+ Concatenates AxisArrays with N equal leading axes into a single AxisArray.
171187All additional axes in any of the arrays are flattened into a single additional
172188CategoricalVector{Tuple} axis.
173189
174190### Arguments
175191
176- * `last_dim::Integer `: (optional) the greatest common dimension to share between all input
177- arrays. The remaining axes are flattened. If this argument is not
178- provided, the greatest common axis found among the input arrays is
179- used. All preceeding axes must also be common to each input array, at
180- the same dimension. Values from 0 up to one more than the minimum
181- number of dimensions across all input arrays are allowed.
192+ * `::Type{Val{N}} `: the greatest common dimension to share between all input
193+ arrays. The remaining axes are flattened. All N axes must be common
194+ to each input array, at the same dimension. Values from 0 up to the
195+ minimum number of dimensions across all input arrays are allowed.
196+ * `labels::Tuple`: (optional) a label for each AxisArray in As which is used in the flat
197+ axis
182198* `As::AxisArray...`: AxisArrays to be flattened together.
183199"""
184- function flatten (As:: AxisArray... ; kwargs... )
185- gca = greatest_common_axis (As... )
186-
187- return _flatten (gca, As... ; kwargs... )
188- end
189-
190- function flatten (last_dim:: Integer , As:: AxisArray... ; kwargs... )
191- last_dim >= 0 || throw (ArgumentError (" last_dim must be at least 0" ))
192-
193- if last_dim > minimum (map (ndims, As))
194- throw (ArgumentError (
195- " There must be at least $last_dim (last_dim) axes in each argument"
196- ))
200+ @generated function flatten {N, AN, LType} (:: Type{Val{N}} , labels:: NTuple{AN, LType} , As:: Vararg{AxisArray, AN} )
201+ if N < 0
202+ throw (ArgumentError (" flatten dimension N must be at least 0" ))
197203 end
198204
199- if last_dim > greatest_common_axis (As ... )
205+ if N > minimum ( ndims .(As) )
200206 throw (ArgumentError (
201- " The first $last_dim axes don't all match across all arguments"
207+ """
208+ flatten dimension N must not be greater than the maximum number of dimensions
209+ across all input arrays
210+ """
202211 ))
203212 end
204213
205- return _flatten (last_dim, As ... ; kwargs ... )
206- end
214+ flat_dim = Val{N + 1 }
215+ flat_dim_int = Int (N) + 1
207216
208- function _flatten (
209- last_dim:: Integer ,
210- As:: AxisArray... ;
211- array_names= 1 : length (As),
212- axis_name= nothing ,
213- )
214- common_axes = axes (As[1 ])[1 : last_dim]
215-
216- if axis_name === nothing
217- axis_name = _defaultdimname (last_dim + 1 )
218- elseif ! isa (axis_name, Symbol)
219- throw (ArgumentError (" axis_name must be a Symbol" ))
220- end
217+ common_axes, trailing_axes = zip (_splitall (Val{N}, axisparams .(As)... )... )
218+
219+ foreach (_check_common_axes, zip (common_axes... ))
220+
221+ new_common_axes = first (common_axes)
222+ flat_axis_eltype = _flat_axis_eltype (LType, trailing_axes)
223+ flat_axis_type = CategoricalVector{flat_axis_eltype, Vector{flat_axis_eltype}}
224+
225+ new_axes_type = Tuple{new_common_axes... , Axis{:flat , flat_axis_type}}
226+ new_eltype = Base. promote_eltype (As... )
221227
222- new_data = cat (last_dim + 1 , ( view (A . data, repeated (:, last_dim + 1 ) ... ) for A in As) . .. )
223- new_axis = flatten_axes (array_names, map (A -> axes (A)[last_dim + 1 : end ], As))
228+ quote
229+ common_axes, trailing_axes = zip ( _splitall (Val{N}, axes .( As)... ) ... )
224230
225- # TODO : Consider creating a SortedVector axis when all flattened axes are Dimensional
226- return AxisArray (new_data, common_axes... , CategoricalVector (new_axis))
231+ for common_axis_tuple in zip (common_axes... )
232+ if ! isempty (common_axis_tuple)
233+ for common_axis in common_axis_tuple[2 : end ]
234+ if ! all (axisvalues (common_axis) .== axisvalues (common_axis_tuple[1 ]))
235+ throw (ArgumentError (
236+ """
237+ Leading common axes must be identical across
238+ all input arrays"""
239+ ))
240+ end
241+ end
242+ end
243+ end
244+
245+ array_data = cat ($ flat_dim, _reshapeall ($ flat_dim, As... )... )
246+
247+ axis_array_type = AxisArray{
248+ $ new_eltype,
249+ $ flat_dim_int,
250+ Array{$ new_eltype, $ flat_dim_int},
251+ $ new_axes_type
252+ }
253+
254+ new_axes = (
255+ first (common_axes)... ,
256+ Axis {:flat, $flat_axis_type} ($ flat_axis_type (_flatten_axes (labels, trailing_axes))),
257+ )
258+
259+ return axis_array_type (array_data, new_axes)
260+ end
227261end
0 commit comments