Skip to content

Commit 88f3270

Browse files
authored
support constructors with non_differentiable (#243)
* support constructors with non_differentiable * Test at=non_differentiable on constructors * Make demo ADs not try to work on constructors * also test negative
1 parent 6ed6806 commit 88f3270

File tree

4 files changed

+36
-7
lines changed

4 files changed

+36
-7
lines changed

src/rule_definition_tools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ macro non_differentiable(sig_expr)
294294
primal_name, orig_args = Iterators.peel(sig_expr.args)
295295

296296
constrained_args = _constrain_and_name.(orig_args, :Any)
297-
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]
297+
primal_sig_parts = [:(::Core.Typeof($primal_name)), constrained_args...]
298298

299299
unconstrained_args = _unconstrain.(constrained_args)
300300

test/demos/forwarddiffzero.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using Test
1111
# Define the AD
1212

1313
# Note that we never directly define Dual Number Arithmetic on Dual numbers
14-
# instead it is automatically defined from the `frules`
14+
# instead it is automatically defined from the `frules`
1515
struct Dual <: Real
1616
primal::Float64
1717
partial::Float64
@@ -30,7 +30,8 @@ Base.to_power_type(x::Dual) = x
3030
function define_dual_overload(sig)
3131
sig = Base.unwrap_unionall(sig) # Not really handling most UnionAlls
3232
opT, argTs = Iterators.peel(sig.parameters)
33-
fieldcount(opT) == 0 || return # not handling functors
33+
opT isa Type{<:Type} && return # not handling constructors
34+
fieldcount(opT) == 0 || return # not handling functors
3435
all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops.
3536

3637
N = length(sig.parameters) - 1 # skip the op
@@ -65,7 +66,7 @@ function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number)
6566
end
6667

6768
# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call
68-
refresh_rules();
69+
refresh_rules();
6970

7071
@testset "ForwardDiffZero" begin
7172
foo(x) = x + x

test/demos/reversediffzero.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ Base.to_power_type(x::Tracked) = x
5959
function define_tracked_overload(sig)
6060
sig = Base.unwrap_unionall(sig) # not really handling most UnionAll
6161
opT, argTs = Iterators.peel(sig.parameters)
62-
fieldcount(opT) == 0 || return # not handling functors
62+
opT isa Type{<:Type} && return # not handling constructors
63+
fieldcount(opT) == 0 || return # not handling functors
6364
all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops.
6465

6566
N = length(sig.parameters) - 1 # skip the op
@@ -116,7 +117,7 @@ function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number)
116117
end
117118

118119
# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call
119-
refresh_rules();
120+
refresh_rules();
120121

121122
@testset "ReversedDiffZero" begin
122123
foo(x) = x + x

test/rule_definition_tools.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ macro test_macro_throws(err_expr, expr)
2121
end
2222
end
2323

24+
# struct need to be defined outside of tests for julia 1.0 compat
25+
struct NonDiffExample
26+
x
27+
end
28+
29+
struct NonDiffCounterExample
30+
x
31+
end
2432

2533
@testset "rule_definition_tools.jl" begin
2634
@testset "@non_differentiable" begin
@@ -98,6 +106,25 @@ end
98106
end
99107
end
100108

109+
@testset "Constructors" begin
110+
@non_differentiable NonDiffExample(::Any)
111+
112+
@test isequal(
113+
frule((Zero(), 1.2), NonDiffExample, 2.0),
114+
(NonDiffExample(2.0), DoesNotExist())
115+
)
116+
117+
res, pullback = rrule(NonDiffExample, 2.0)
118+
@test res == NonDiffExample(2.0)
119+
@test pullback(1.2) == (NO_FIELDS, DoesNotExist())
120+
121+
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/213
122+
# problem was that `@nondiff Foo(x)` was also defining rules for other types.
123+
# make sure that isn't happenning
124+
@test frule((Zero(), 1.2), NonDiffCounterExample, 2.0) === nothing
125+
@test rrule(NonDiffCounterExample, 2.0) === nothing
126+
end
127+
101128
@testset "Not supported (Yet)" begin
102129
# Varargs are not supported
103130
@test_macro_throws ErrorException @non_differentiable vararg1(xs...)
@@ -115,7 +142,7 @@ end
115142
@testset "@scalar_rule with multiple output" begin
116143
simo(x) = (x, 2x)
117144
@scalar_rule(simo(x), 1f0, 2f0)
118-
145+
119146
y, simo_pb = rrule(simo, π)
120147

121148
@test simo_pb((10f0, 20f0)) == (NO_FIELDS, 50f0)

0 commit comments

Comments
 (0)