Skip to content

Commit ec3fad4

Browse files
bors[bot]oxinabox
andauthored
Merge #780
780: Move a bunch of no_grad to ChainRules r=oxinabox a=oxinabox this is the partner to JuliaDiff/ChainRules.jl#252 It will fail til that is merged and tagged What is left is: - Types (because JuliaDiff/ChainRulesCore.jl#213) (e.g. `Colon`, `OneTo` `Channel`) - Things to which the derivative is `Zero()` not `DoesNotExist()` (e.g. `one`, `ones`, `zero`, `zeros`) - Things that felt too magic: e.g. `Base.eval` Should I bump patch version and tag a release? Co-authored-by: Lyndon White <[email protected]> Co-authored-by: Lyndon White <[email protected]>
2 parents 4ea7ad7 + a2026e7 commit ec3fad4

File tree

9 files changed

+29
-23
lines changed

9 files changed

+29
-23
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Zygote"
22
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
3-
version = "0.5.5"
3+
version = "0.5.6"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -27,7 +27,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2727
[compat]
2828
AbstractFFTs = "0.5"
2929
ArrayLayouts = "0.1, 0.2, 0.3, 0.4"
30-
ChainRules = "0.7.0"
30+
ChainRules = "0.7.16"
3131
DiffRules = "1.0"
3232
FillArrays = "0.8, 0.9"
3333
ForwardDiff = "0"

src/compiler/chainrules.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us
4040
"""
4141
@inline wrap_chainrules_output(x) = unthunk(x) # For now we are just not going to deal with thunks
4242
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
43+
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
44+
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
4345
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
4446
for T_outer in (:Tuple, :NamedTuple)
4547
# we create separate methods rather than using a `Union` + an `if` so that we avoid a

src/lib/array.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,7 @@ using Distributed: pmap
88
@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
99
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)
1010

11-
@nograd size, length, eachindex, Base.OneTo, axes, Colon(), findfirst, findlast, findall, ones, zeros, one, zero, any, all
12-
@nograd randn, randexp, randn!, randexp!
13-
@static if VERSION > v"1.3"
14-
@nograd Random.default_rng
15-
end
16-
17-
@adjoint Base.rand(rng::AbstractRNG, ::Type{T}, dims...) where {T<:Number} =
18-
rand(rng, T, dims...), _ -> nothing
11+
@nograd ones, zeros, Base.OneTo, Colon(), one, zero
1912

2013
@adjoint Base.vect(xs...) = Base.vect(xs...), Δ ->...,)
2114

src/lib/base.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
@nograd readline, Base.gc_num, Base.time_ns, Base.print, Base.println, Base.show,
2-
Core.show, Core.print, Core.println, string, repr, Threads.nthreads, Threads.threadid
3-
41
# Gradient of AD stacks
52

63
grad_mut(::AbstractVector) = []
@@ -47,11 +44,9 @@ end
4744
end
4845
end
4946

50-
@nograd haskey
51-
5247
# Channels
5348

54-
@nograd Channel, schedule
49+
@nograd Channel
5550

5651
grad_mut(ch::Channel) = Channel(ch.sz_max)
5752

src/lib/broadcast.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ using Base.Broadcast
1616
using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
1717
using NNlib
1818

19-
@nograd Broadcast.combine_styles, Broadcast.result_style
20-
2119
# There's a saying that debugging code is about twice as hard as writing it in
2220
# the first place. So if you're as clever as you can be when writing code, how
2321
# will you ever debug it?

src/lib/lib.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@ function accum(x::RefValue, y::RefValue)
2626
end
2727

2828
# Core functions
29-
30-
@nograd Core.apply_type, Core.typeof, nfields, fieldtype, Core.TypeVar, Core.UnionAll,
31-
(==), (===), (<=), (>=), (<), (>), isempty, supertype, Base.typename,
32-
eps, Meta.parse, Base.eval, sleep, isassigned
29+
@nograd eps, Base.eval, Core.TypeVar, Core.UnionAll
3330

3431
@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
3532

src/lib/number.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
@nograd floor, ceil, trunc, round, hash, div
2+
@nograd floor, ceil, trunc, round, div
33

44
@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
55
Base.literal_pow(^,x,Val(p)),

test/chainrules.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,22 @@ using Zygote, Test, ChainRules
124124
@test mimo_pullback_hitcount[] == 1
125125
end
126126

127+
@testset "all AbstractZero partials" begin
128+
# while ChainRules always has a partial for every input, Zygote combined them all
129+
# to a single `nothing` if they are all zero-like.
130+
131+
not_diff_eg(x, i) = [10, 20][i]
132+
function ChainRules.rrule(::typeof(not_diff_eg), x, i)
133+
function not_diff_eg_pullback(Δ)
134+
return ChainRules.NO_FIELDS, ChainRules.Zero(), ChainRules.DoesNotExist()
135+
end
136+
return not_diff_eg(x, i), not_diff_eg_pullback
137+
end
138+
139+
_, pb = Zygote.pullback(not_diff_eg, 10.4, 2)
140+
@test pb(1.2) === nothing
141+
end
142+
127143
@testset "nested AD hitting identity(::Tuple) pullback" begin
128144
# This is is a particularly fiddly case.
129145
# Its kind of a simplified version of `sin'''(0.5)` but different in some places.

test/gradcheck.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,9 +1538,14 @@ end
15381538
end
15391539

15401540
@testset "@nograd" begin
1541+
@test gradient(x->eachindex([10,20,30])[1], 11) == (nothing,)
1542+
1543+
#These are defined in ChainRules, we test them here to check we are handling them right
15411544
@test gradient(x -> findfirst(ismissing, x), [1, missing]) == (nothing,)
15421545
@test gradient(x -> findlast(ismissing, x), [1, missing]) == (nothing,)
15431546
@test gradient(x -> findall(ismissing, x)[1], [1, missing]) == (nothing,)
1547+
1548+
15441549
@test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,)
15451550
@test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,)
15461551
@test gradient(1) do x

0 commit comments

Comments
 (0)