Skip to content
17 changes: 17 additions & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,20 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
end
return y, logdet_pullback
end

#####
##### Tridiagonal
#####

function rrule(::Type{Tridiagonal}, dl, d, du)
y = Tridiagonal(dl, d, du)
@views function ∇Tridiagonal(∂y)
return (
NoTangent(),
diag(∂y[2:end, 1:(end - 1)]),
diag(∂y),
diag(∂y[1:(end - 1), 2:end]),
)
end
return y, ∇Tridiagonal
end
5 changes: 5 additions & 0 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,9 @@
end
end
end

@testset "Tridiagonal" begin
test_rrule(Tridiagonal, [1.0, 4.0], [2.0, 3.0, 4.0], [5.0, 3.0])
@test pb(10 * res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30])
end
end