6666# #### map
6767# ####
6868
69- # `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`,
70- # will be useful for the gradient of `map` etc.
71-
72-
7369"""
74- unzip_map(f, args...)
70+ unzip_map(f, args...)
7571
7672For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
7773but performed using `StructArrays` for efficiency.
@@ -86,40 +82,36 @@ function unzip_map(f::F, args...) where {F}
8682end
8783
8884unzip_map (f:: F , args:: Tuple... ) where {F} = unzip (map (f, args... ))
85+ # unzip_map(f::F, args::NamedTuple...) where {F} = unzip(map(f, args...))
8986
9087unzip_map (f:: F , args:: AbstractGPUArray... ) where {F} = unzip (map (f, args... ))
9188
89+ """
90+ unzip_map_reversed(f, args...)
91+
92+ For a pure function `f` which returns a tuple, this is `== unzip(map(f, args...))`.
93+ But the order of evaluation is should be the reverse.
94+ Does NOT handle `zip`-like behaviour.
95+ """
9296function unzip_map_reversed (f:: F , args... ) where {F}
9397 T = Broadcast. combine_eltypes (f, args)
9498 if isconcretetype (T)
9599 T <: Tuple || throw (ArgumentError (""" unzip_map_reversed(f, args) only works on functions returning a tuple,
96100 but f = $(sprint (show, f)) returns type T = $T """ ))
97101 end
98102 len1 = length (first (args))
99- if all (a -> length (a)== len1, args)
100- rev_args = map (Iterators. reverse, args)
101- outs = StructArrays. components (StructArray (Iterators. map (f, rev_args... )))
102- else
103- len = minimum (length, args)
104- rev_args = map (a -> Iterators. reverse (@view a[begin : begin + len- 1 ]), args)
105- outs = StructArrays. components (StructArray (Iterators. map (f, rev_args... )))
106- end
107- return map (reverse!!, outs)
103+ all (a -> length (a)== len1, args) || error (" unzip_map_reversed does not handle zip-like behaviour." )
104+ return map (reverse!!, unzip_map (f, map (_safereverse, args)... ))
108105end
109106
107+ # This avoids MethodError: no method matching iterate(::Base.Iterators.Reverse{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}) on 1.6
108+ _safereverse (x) = VERSION > v " 1.7" ? Iterators. reverse (x) : reverse (x)
109+
110110function unzip_map_reversed (f:: F , args:: Tuple... ) where {F}
111- len = minimum (length, args)
112- rev_args = map (a -> reverse (a[1 : len]), args)
113- # vlen = Val(len)
114- # rev_args = map(args) do a
115- # reverse(ntuple(i -> a[i], vlen)) # does not infer better
116- # end
117- return map (reverse, unzip (map (f, rev_args... )))
111+ len1 = length (first (args))
112+ all (a -> length (a)== len1, args) || error (" unzip_map_reversed does not handle zip-like behaviour." )
113+ return map (reverse, unzip (map (f, map (reverse, args)... )))
118114end
119- # function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N}
120- # rev_args = map(reverse, args)
121- # return map(reverse, unzip(map(f, rev_args...)))
122- # end
123115
124116"""
125117 reverse!!(x)
@@ -135,10 +127,11 @@ function reverse!!(x::AbstractArray)
135127 end
136128end
137129reverse!! (x:: AbstractArray{<:AbstractZero} ) = x
130+ reverse!! (x) = reverse (x)
138131
139- frule ((_, xdot), :: typeof (reverse!!), x:: AbstractArray ) = reverse!! (x), reverse!! (xdot)
132+ frule ((_, xdot), :: typeof (reverse!!), x) = reverse!! (x), reverse!! (xdot)
140133
141- function rrule (:: typeof (reverse!!), x:: AbstractArray )
134+ function rrule (:: typeof (reverse!!), x)
142135 reverse!!_back (dy) = (NoTangent (), reverse (unthunk (dy)))
143136 return reverse!! (x), reverse!!_back
144137end
@@ -181,10 +174,16 @@ end
181174 Expr (:tuple , each... )
182175end
183176
184- unzip (xs:: AbstractArray{Tuple{T}} ) where {T} = (reinterpret (T, xs),) # best case, no copy
177+ function unzip (xs:: AbstractArray{Tuple{T}} ) where {T}
178+ if isbitstype (T)
179+ (reinterpret (T, xs),) # best case, no copy
180+ else
181+ (map (only, xs),)
182+ end
183+ end
185184
186185@generated function unzip (xs:: AbstractArray{Ts} ) where {Ts<: Tuple }
187- each = if count (! Base. issingletontype, Ts. parameters) < 2
186+ each = if count (! Base. issingletontype, Ts. parameters) < 2 && all (isbitstype, Ts . parameters)
188187 # good case, no copy of data, some trivial arrays
189188 [Base. issingletontype (T) ? :(similar (xs, $ T)) : :(reinterpret ($ T, xs)) for T in Ts. parameters]
190189 else
0 commit comments