
Remove @adjoint for adjoint/Adjoint in favor of ChainRules #1257

@mtfishman


Can we remove the @adjoint definitions for adjoint/Adjoint:

From Zygote.jl/src/lib/array.jl, lines 371 to 381 at commit 4777767:

@adjoint function Base.adjoint(x)
  back(Δ) = (Δ',)
  back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)
  return x', back
end

@adjoint function LinearAlgebra.Adjoint(x)
  back(Δ) = (LinearAlgebra.Adjoint(Δ),)
  back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)
  return LinearAlgebra.Adjoint(x), back
end

in favor of the rrules in ChainRules (found here)?
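The pullback in the `Base.adjoint` rule above has two methods: one for a plain array cotangent and one for a structural cotangent of the `Adjoint` wrapper (a NamedTuple with a `parent` field). The following is a minimal sketch that mirrors that logic as ordinary Julia functions, so it runs without Zygote. The names `adjoint_back` and `pullback_consistent` are hypothetical, and the consistency check assumes real-valued inputs.

```julia
using LinearAlgebra

# Pullback for y = x': a plain array cotangent Δ for y maps to Δ' for x.
adjoint_back(Δ) = (Δ',)

# If the cotangent arrives as a structural tangent of the Adjoint wrapper
# (a NamedTuple with a single `parent` field), the wrapper's parent is x
# itself, so the parent cotangent is passed through unchanged.
adjoint_back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)

# Sanity check for real inputs: x -> x' is linear, so its pullback must
# satisfy dot(Δ, x') == dot(pullback(Δ), x) for every x and Δ.
pullback_consistent(x::AbstractMatrix{<:Real}, Δ::AbstractMatrix{<:Real}) =
    dot(Δ, x') ≈ dot(adjoint_back(Δ)[1], x)
```

For example, `adjoint_back([5 6; 7 8])[1]` is `[5 7; 6 8]`, while `adjoint_back((parent = [5 6; 7 8],))[1]` returns the matrix as-is.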

I want to overload the rrule for adjoint for a custom type, but Zygote ignores my rrule and uses its own @adjoint rule for Base.adjoint instead. For example:

julia> using ChainRulesCore

julia> using ZygoteRules

julia> using Zygote

julia> struct MyType end

julia> Base.adjoint(x::MyType) = x

julia> f(x) = (MyType()'; x)
f (generic function with 1 method)

julia> f'(1.0)
1.0

julia> ChainRulesCore.rrule(::typeof(adjoint), x::MyType) = error("ChainRulesCore.rrule")

julia> f'(1.0)
1.0

julia> ZygoteRules.@adjoint Base.adjoint(x::MyType) = error("ZygoteRules.@adjoint")

julia> f'(1.0)
ERROR: ZygoteRules.@adjoint
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] adjoint
   @ ./REPL[10]:1 [inlined]
 [3] _pullback
   @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [4] _pullback
   @ ./REPL[6]:1 [inlined]
 [5] _pullback(ctx::Zygote.Context, f::typeof(f), args::Float64)
   @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [6] _pullback(f::Function, args::Float64)
   @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:34
 [7] pullback(f::Function, args::Float64)
   @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:40
 [8] (::Zygote.var"#62#63"{typeof(f)})(x::Float64)
   @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:82
 [9] top-level scope
   @ REPL[11]:1

The same is true for transpose.
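The rrule in the session above deliberately errors so it can serve as a probe. As a hedged illustration of the kind of non-erroring rule a package might want Zygote to respect, here is a sketch for a hypothetical wrapper type `MyWrapper` with a single `data` field (neither appears in the issue). According to the behavior reported above, Zygote v0.6.41 would not call this rule inside a differentiated function, because its generic @adjoint for Base.adjoint takes precedence.

```julia
using ChainRulesCore

# Hypothetical custom type (not from the issue): a wrapper around a matrix.
struct MyWrapper{T<:AbstractMatrix}
    data::T
end

# Adjoint of the wrapper = wrapper of the (materialized) adjoint of its data.
Base.adjoint(w::MyWrapper) = MyWrapper(collect(w.data'))

# Reverse rule for adjoint on MyWrapper, written against ChainRulesCore.
function ChainRulesCore.rrule(::typeof(adjoint), w::MyWrapper)
    y = adjoint(w)
    function adjoint_pullback(ybar)
        ybar = unthunk(ybar)
        # The cotangent of y is a structural Tangent with a `data` field;
        # adjoint is linear, so the cotangent for w.data is ybar.data'.
        wbar = Tangent{typeof(w)}(; data = collect(ybar.data'))
        # First entry: no tangent for the function `adjoint` itself.
        return NoTangent(), wbar
    end
    return y, adjoint_pullback
end
```

Calling `rrule(adjoint, MyWrapper([1 2; 3 4]))` directly returns the transposed wrapper together with its pullback; whether Zygote dispatches to such a rule is exactly what this issue asks about.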

Package information:

(@v1.7) pkg> st ChainRulesCore Zygote
      Status `~/.julia/environments/v1.7/Project.toml`
  [d360d2e6] ChainRulesCore v1.15.1
  [e88e6eb3] Zygote v0.6.41

julia> versioninfo()
Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) E-2176M  CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)
Environment:
  JULIA_EDITOR = vim


Labels: ChainRules (adjoint -> rrule, and further integration)
