Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.4.11"
version = "7.5.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
24 changes: 17 additions & 7 deletions ext/ArrayInterfaceBandedMatricesExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
module ArrayInterfaceBandedMatricesExt


if isdefined(Base, :get_extension)
using ArrayInterface
using ArrayInterface: BandedMatrixIndex
using BandedMatrices
using LinearAlgebra
else
using ..ArrayInterface
using ..ArrayInterface: BandedMatrixIndex
using ..BandedMatrices
using ..LinearAlgebra
end

const TransOrAdjBandedMatrix = Union{
Adjoint{T, <:BandedMatrix{T}},
Transpose{T, <:BandedMatrix{T}},
} where {T}

const AllBandedMatrix = Union{
BandedMatrix{T},
TransOrAdjBandedMatrix{T},
} where {T}

Base.firstindex(i::BandedMatrixIndex) = 1
Base.lastindex(i::BandedMatrixIndex) = i.count
Expand Down Expand Up @@ -45,24 +55,24 @@ end

function BandedMatrixIndex(rowsize, colsize, lowerbandwidth, upperbandwidth, isrow)
upperbandwidth > -lowerbandwidth || throw(ErrorException("Invalid Bandwidths"))
bandinds = upperbandwidth:-1:-lowerbandwidth
bandinds = upperbandwidth:-1:(-lowerbandwidth)
bandsizes = [_bandsize(band, rowsize, colsize) for band in bandinds]
BandedMatrixIndex(sum(bandsizes), rowsize, colsize, bandinds, bandsizes, isrow)
end

function ArrayInterface.findstructralnz(x::BandedMatrices.BandedMatrix)
function ArrayInterface.findstructralnz(x::AllBandedMatrix)
l, u = BandedMatrices.bandwidths(x)
rowsize, colsize = Base.size(x)
rowind = BandedMatrixIndex(rowsize, colsize, l, u, true)
colind = BandedMatrixIndex(rowsize, colsize, l, u, false)
return (rowind, colind)
end

ArrayInterface.has_sparsestruct(::Type{<:BandedMatrices.BandedMatrix}) = true
ArrayInterface.isstructured(::Type{<:BandedMatrices.BandedMatrix}) = true
ArrayInterface.fast_matrix_colors(::Type{<:BandedMatrices.BandedMatrix}) = true
ArrayInterface.has_sparsestruct(::Type{<:AllBandedMatrix}) = true
ArrayInterface.isstructured(::Type{<:AllBandedMatrix}) = true
ArrayInterface.fast_matrix_colors(::Type{<:AllBandedMatrix}) = true

function ArrayInterface.matrix_colors(A::BandedMatrices.BandedMatrix)
function ArrayInterface.matrix_colors(A::AllBandedMatrix)
l, u = BandedMatrices.bandwidths(A)
width = u + l + 1
return ArrayInterface._cycle(1:width, Base.size(A, 2))
Expand Down
53 changes: 42 additions & 11 deletions test/bandedmatrices.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,50 @@

using ArrayInterface
using BandedMatrices
using Test

B=BandedMatrix(Ones(5,5), (-1,2))
B[band(1)].=[1,2,3,4]
B[band(2)].=[5,6,7]
function checkequal(idx1::ArrayInterface.BandedMatrixIndex,
idx2::ArrayInterface.BandedMatrixIndex)
return idx1.rowsize == idx2.rowsize && idx1.colsize == idx2.colsize &&
idx1.bandinds == idx2.bandinds && idx1.bandsizes == idx2.bandsizes &&
idx1.isrow == idx2.isrow && idx1.count == idx2.count
end

B = BandedMatrix(Ones(5, 5), (-1, 2))
B[band(1)] .= [1, 2, 3, 4]
B[band(2)] .= [5, 6, 7]
@test ArrayInterface.has_sparsestruct(B)
rowind,colind=ArrayInterface.findstructralnz(B)
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,1,2,3,4]
B=BandedMatrix(Ones(4,6), (-1,2))
B[band(1)].=[1,2,3,4]
B[band(2)].=[5,6,7,8]
rowind,colind=ArrayInterface.findstructralnz(B)
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,8,1,2,3,4]
rowind, colind = ArrayInterface.findstructralnz(B)
@test [B[rowind[i], colind[i]] for i in 1:length(rowind)] == [5, 6, 7, 1, 2, 3, 4]
B = BandedMatrix(Ones(4, 6), (-1, 2))
B[band(1)] .= [1, 2, 3, 4]
B[band(2)] .= [5, 6, 7, 8]
rowind, colind = ArrayInterface.findstructralnz(B)
@test [B[rowind[i], colind[i]] for i in 1:length(rowind)] == [5, 6, 7, 8, 1, 2, 3, 4]
@test ArrayInterface.isstructured(typeof(B))
@test ArrayInterface.fast_matrix_colors(typeof(B))

for op in (adjoint, transpose)
B = BandedMatrix(Ones(5, 5), (-1, 2))
B[band(1)] .= [1, 2, 3, 4]
B[band(2)] .= [5, 6, 7]
B′ = op(B)
@test ArrayInterface.has_sparsestruct(B′)
rowind′, colind′ = ArrayInterface.findstructralnz(B′)
rowind′′, colind′′ = ArrayInterface.findstructralnz(BandedMatrix(B′))
@test checkequal(rowind′, rowind′′)
@test checkequal(colind′, colind′′)

B = BandedMatrix(Ones(4, 6), (-1, 2))
B[band(1)] .= [1, 2, 3, 4]
B[band(2)] .= [5, 6, 7, 8]
B′ = op(B)
rowind′, colind′ = ArrayInterface.findstructralnz(B′)
rowind′′, colind′′ = ArrayInterface.findstructralnz(BandedMatrix(B′))
@test checkequal(rowind′, rowind′′)
@test checkequal(colind′, colind′′)

@test ArrayInterface.isstructured(typeof(B′))
@test ArrayInterface.fast_matrix_colors(typeof(B′))

@test ArrayInterface.matrix_colors(B′) == ArrayInterface.matrix_colors(BandedMatrix(B′))
end