Skip to content

Commit 0cb446f

Browse files
committed
Add some projectors
1 parent 55d64f1 commit 0cb446f

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/rulesets/Base/fastmath_able.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ let
1111
## sin
1212
function rrule(::typeof(sin), x::CommutativeMulNumber)
1313
sinx, cosx = sincos(x)
14-
sin_pullback(Δy) = (NoTangent(), cosx' * Δy)
14+
project_x = ProjectTo(x)
15+
sin_pullback(Δy) = (NoTangent(), project_x(cosx' * Δy))
1516
return (sinx, sin_pullback)
1617
end
1718

@@ -23,7 +24,8 @@ let
2324
## cos
2425
function rrule(::typeof(cos), x::CommutativeMulNumber)
2526
sinx, cosx = sincos(x)
26-
cos_pullback(Δy) = (NoTangent(), -sinx' * Δy)
27+
project_x = ProjectTo(x)
28+
cos_pullback(Δy) = (NoTangent(), -project_x(sinx' * Δy))
2729
return (cosx, cos_pullback)
2830
end
2931

@@ -61,7 +63,7 @@ let
6163
project_x = ProjectTo(x)
6264
function inv_pullback(ΔΩ)
6365
Ω′ = conj(Ω)
64-
return NoTangent(), project_x(Ω′ * -ΔΩ * Ω′)
66+
return NoTangent(), -project_x(Ω′ * ΔΩ * Ω′)
6567
end
6668
return Ω, inv_pullback
6769
end

0 commit comments

Comments
 (0)