@@ -50,23 +50,58 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
5050 test_accumulation (Zero (), dx, ȳ, x̄_ad)
5151end
5252
53+ function _make_fdm_call (fdm, f, ȳ, xs, ignores)
54+ sig = Expr (:tuple )
55+ call = Expr (:call , f)
56+ newxs = Any[]
57+ arginds = Int[]
58+ i = 1
59+ for (x, ignore) in zip (xs, ignores)
60+ if ignore
61+ push! (call. args, x)
62+ else
63+ push! (call. args, Symbol (:x , i))
64+ push! (sig. args, Symbol (:x , i))
65+ push! (newxs, x)
66+ push! (arginds, i)
67+ end
68+ i += 1
69+ end
70+ fdexpr = :(j′vp ($ fdm, $ sig -> $ call, $ ȳ, $ (newxs... )))
71+ fd = eval (fdexpr)
72+ fd isa Tuple || (fd = (fd,))
73+ args = Any[nothing for _ in 1 : length (xs)]
74+ for (dx, ind) in zip (fd, arginds)
75+ args[ind] = dx
76+ end
77+ return (args... ,)
78+ end
79+
5380function rrule_test (f, ȳ, xx̄s:: Tuple{Any, Any} ...; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, kwargs... )
5481 # Check correctness of evaluation.
5582 xs, x̄s = collect (zip (xx̄s... ))
56- Ω, Δx_rules = ChainRules . rrule (f, xs... )
57- @test f (xs... ) == Ω
83+ y, rules = rrule (f, xs... )
84+ @test f (xs... ) == y
5885
5986 # Correctness testing via finite differencing.
60- Δxs_ad = map (Δx_rule-> Δx_rule (ȳ), Δx_rules)
61- Δxs_fd = j′vp (fdm, f, ȳ, xs... )
62- for (Δx_ad, Δx_fd) in zip (Δxs_ad, Δxs_fd)
63- @test isapprox (Δx_ad, Δx_fd; rtol= rtol, atol= atol, kwargs... )
87+ x̄s_ad = map (rules) do rule
88+ rule isa DNERule ? DNE () : rule (ȳ)
89+ end
90+ x̄s_fd = _make_fdm_call (fdm, f, ȳ, xs, x̄s .== nothing )
91+ for (x̄_ad, x̄_fd) in zip (x̄s_ad, x̄s_fd)
92+ if x̄_fd === nothing
93+ # The way we've structured the above, this tests that the rule is a DNERule
94+ @test x̄_ad isa DNE
95+ else
96+ @test isapprox (x̄_ad, x̄_fd; rtol= rtol, atol= atol, kwargs... )
97+ end
6498 end
6599
66100 # Assuming the above to be correct, check that other ChainRules mechanisms are correct.
67- for (x̄, Δx_rule, Δx_ad) in zip (x̄s, Δx_rules, Δxs_ad)
68- test_accumulation (x̄, Δx_rule, ȳ, Δx_ad)
69- test_accumulation (Zero (), Δx_rule, ȳ, Δx_ad)
101+ for (x̄, rule, x̄_ad) in zip (x̄s, rules, x̄s_ad)
102+ x̄ === nothing && continue
103+ test_accumulation (x̄, rule, ȳ, x̄_ad)
104+ test_accumulation (Zero (), rule, ȳ, x̄_ad)
70105 end
71106end
72107
0 commit comments