From e0ad7c622efd08eac459d9c67f95e7df9044d81d Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 1 Sep 2021 06:01:20 +0200 Subject: [PATCH 1/3] rrule for fill! --- src/rulesets/Base/array.jl | 12 ++++++++++++ test/rulesets/Base/array.jl | 5 +++++ 2 files changed, 17 insertions(+) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index eb05753d0..28e4fb959 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -342,3 +342,15 @@ function rrule(::typeof(fill), x::Any, dims...) fill_pullback(Ȳ) = (NoTangent(), project(sum(Ȳ)), nots...) return fill(x, dims...), fill_pullback end + + +##### +##### `fill!` +##### + +function rrule(::typeof(fill!), A, x) + project = x isa Union{Number, AbstractArray{<:Number}} ? ProjectTo(x) : identity + fill!_pullback(Ȳ) = (NoTangent(), ZeroTangent(), project(sum(Ȳ))) + return fill!(A, x), fill!_pullback +end + diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index f601a0bb8..2cac0e503 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -198,3 +198,8 @@ end test_rrule(fill, 55 + 0.5im, 5) test_rrule(fill, 3.3, (3, 3, 3)) end + +@testset "fill!" begin + test_rrule(fill!, rand(2), 1) + test_rrule(fill!, rand(2) + im*rand(2), 1) +end From 7c5ef0d22e0fec9ba0dbe3c3dc8199bb7ceff2e4 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 1 Sep 2021 06:03:03 +0200 Subject: [PATCH 2/3] cleanup --- src/rulesets/Base/array.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 28e4fb959..607a6d39a 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -343,7 +343,6 @@ function rrule(::typeof(fill), x::Any, dims...) return fill(x, dims...), fill_pullback end - ##### ##### `fill!` ##### From 20c65bbb975dcedea1001a6dece360fe336e150f Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 1 Sep 2021 06:04:28 +0200 Subject: [PATCH 3/3] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ee1a2a877..3c1473b07 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.11.3" +version = "1.12.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"