-
Notifications
You must be signed in to change notification settings - Fork 27
Open
Description
ref JuliaDiff/ChainRulesTestUtils.jl#258
FiniteDifferences.jl/src/to_vec.jl
Lines 36 to 57 in 5c2979e
| # Fallback method for `to_vec`. Won't always do what you wanted, but should be fine a decent | |
| # chunk of the time. | |
| function to_vec(x::T) where {T} | |
| Base.isstructtype(T) || throw(error("Expected a struct type")) | |
| isempty(fieldnames(T)) && return (Bool[], _ -> x) # Singleton types | |
| val_vecs_and_backs = map(name -> to_vec(getfield(x, name)), fieldnames(T)) | |
| vals = first.(val_vecs_and_backs) | |
| backs = last.(val_vecs_and_backs) | |
| v, vals_from_vec = to_vec(vals) | |
| function structtype_from_vec(v::Vector{<:Real}) | |
| val_vecs = vals_from_vec(v) | |
| values = map((b, v) -> b(v), backs, val_vecs) | |
| try | |
| T(values...) | |
| catch MethodError | |
| return _force_construct(T, values...) | |
| end | |
| end | |
| return v, structtype_from_vec | |
| end |
Metadata
Metadata
Assignees
Labels
No labels