-
-
Notifications
You must be signed in to change notification settings - Fork 160
Closed
JuliaDiff/ChainRules.jl
#758
Description
DiffEqFlux.jl/src/DiffEqFlux.jl
Lines 24 to 40 in e32422d
| # ForwardDiff integration | |
| # Zygote adjoint for constructing a `ForwardDiff.Dual` with tag `T` from a value | |
| # and a one-element partials tuple. The pullback returns `ḋ.partials[1]` for `x` | |
| # and `(ḋ.value,)` for `ẋ`. | |
| ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where T | |
|     @assert length(ẋ) == 1 | |
|     dual_pullback(ḋ) = (ḋ.partials[1], (ḋ.value,)) | |
|     return ForwardDiff.Dual{T}(x, ẋ), dual_pullback | |
| end | |
| # Zygote adjoint for reading the `partials` field of a `ForwardDiff.Dual`. | |
| # The pullback wraps the first entry of the cotangent `ṗ` in a `Dual{T}` whose | |
| # second constructor argument is 0. | |
| ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where T | |
|     partials_pullback(ṗ) = (ForwardDiff.Dual{T}(ṗ[1], 0),) | |
|     return d.partials, partials_pullback | |
| end | |
| # Zygote adjoint for reading the `value` field of a `ForwardDiff.Dual`. | |
| # The pullback builds a `Dual{T}` with 0 as its first constructor argument and | |
| # the cotangent `ẋ` as its second. | |
| ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where T | |
|     value_pullback(ẋ) = (ForwardDiff.Dual{T}(0, ẋ),) | |
|     return d.value, value_pullback | |
| end | |
| # Zygote adjoint for reading the sub-diagonal `dl` of a `Tridiagonal`. | |
| # The gradient w.r.t. `A` is a `Tridiagonal` carrying the cotangent `y` in its | |
| # `dl` band and zeros (of `y`'s element type) in the other two bands. | |
| # The pullback returns a tuple: one entry for `A`, `nothing` for the `Val` argument. | |
| ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:dl}) | |
|     function dl_pullback(y) | |
|         Ā = Tridiagonal(y, zeros(eltype(y), length(A.d)), zeros(eltype(y), length(A.du))) | |
|         return (Ā, nothing) | |
|     end | |
|     return A.dl, dl_pullback | |
| end | |
| # Zygote adjoint for reading the main diagonal `d` of a `Tridiagonal`. | |
| # The gradient w.r.t. `A` is a `Tridiagonal` carrying the cotangent `y` in its | |
| # `d` band and zeros (of `y`'s element type) in the other two bands. | |
| # The pullback returns a tuple: one entry for `A`, `nothing` for the `Val` argument. | |
| ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:d}) | |
|     function d_pullback(y) | |
|         Ā = Tridiagonal(zeros(eltype(y), length(A.dl)), y, zeros(eltype(y), length(A.du))) | |
|         return (Ā, nothing) | |
|     end | |
|     return A.d, d_pullback | |
| end | |
| # Zygote adjoint for reading the super-diagonal `du` of a `Tridiagonal`. | |
| # The primal result is `A.du`. The gradient w.r.t. `A` is a `Tridiagonal` carrying | |
| # the cotangent `y` in its `du` band and zeros (of `y`'s element type) elsewhere. | |
| # The pullback returns a tuple: one entry for `A`, `nothing` for the `Val` argument. | |
| ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:du}) | |
|     function du_pullback(y) | |
|         Ā = Tridiagonal(zeros(eltype(y), length(A.dl)), zeros(eltype(y), length(A.d)), y) | |
|         return (Ā, nothing) | |
|     end | |
|     return A.du, du_pullback | |
| end | |
| # Zygote adjoint for the three-band `Tridiagonal` constructor. The pullback reads | |
| # the sub-, main and super-diagonal out of the cotangent `p̄` and returns them in | |
| # argument order `(dl, d, du)`. | |
| ZygoteRules.@adjoint function Tridiagonal(dl, d, du) | |
|     function tridiagonal_pullback(p̄) | |
|         lower = diag(p̄[2:end, 1:end-1]) | |
|         middle = diag(p̄) | |
|         upper = diag(p̄[1:end-1, 2:end]) | |
|         return (lower, middle, upper) | |
|     end | |
|     return Tridiagonal(dl, d, du), tridiagonal_pullback | |
| end |
Related to SciML/SciMLSensitivity.jl#582
Metadata
Metadata
Assignees
Labels
No labels