Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
name = "BoundaryValueDiffEq"
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
version = "5.1.0"
version = "5.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -18,7 +19,6 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

Expand All @@ -32,6 +32,7 @@ BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"
ADTypes = "0.2"
Adapt = "3"
ArrayInterface = "7"
BandedMatrices = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.94.2"
ForwardDiff = "0.10"
Expand Down
9 changes: 4 additions & 5 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
module BoundaryValueDiffEq

using Adapt, LinearAlgebra, PreallocationTools, Reexport, Setfield, SparseArrays, SciMLBase,
Static, RecursiveArrayTools, ForwardDiff
using Adapt, BandedMatrices, ForwardDiff, LinearAlgebra, PreallocationTools,
RecursiveArrayTools, Reexport, Setfield, SparseArrays
@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase

import ADTypes: AbstractADType
import ArrayInterface: matrix_colors, parameterless_type, undefmatrix
import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scalar_indexing
import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: pickchunksize
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve
import RecursiveArrayTools: ArrayPartition
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
import SparseDiffTools: AbstractSparseADType
import TruncatedStacktraces: @truncate_stacktrace
import UnPack: @unpack
Expand Down
48 changes: 31 additions & 17 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
fᵢ₂_cache
defect
new_stages
resid_size
kwargs
end

Expand Down Expand Up @@ -64,8 +65,13 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X)

residual = if iip
vcat([__alloc_diffcache(bcresid_prototype)],
__alloc_diffcache.(copy.(@view(y₀[2:end]))))
if prob.problem_type isa TwoPointBVProblem
vcat([__alloc_diffcache(__vec(bcresid_prototype))],
__alloc_diffcache.(copy.(@view(y₀[2:end]))))
else
vcat([__alloc_diffcache(bcresid_prototype)],
__alloc_diffcache.(copy.(@view(y₀[2:end]))))
end
else
nothing
end
Expand All @@ -74,6 +80,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
new_stages = [similar(X, ifelse(adaptive, M, 0)) for _ in 1:n]

# Transform the functions to handle non-vector inputs
bcresid_prototype = __vec(bcresid_prototype)
f, bc = if X isa AbstractVector
prob.f, prob.f.bc
elseif iip
Expand All @@ -92,7 +99,6 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
end
(__vecbc_a!, __vecbc_b!)
end
bcresid_prototype = vec(bcresid_prototype)
vecf!, vecbc!
else
vecf(u, p, t) = vec(prob.f(reshape(u, size(X)), p, t))
Expand All @@ -103,14 +109,13 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
__vecbc_b(ub, p) = vec(prob.f.bc[2](reshape(ub, size(X)), p))
(__vecbc_a, __vecbc_b)
end
bcresid_prototype = vec(bcresid_prototype)
vecf, vecbc
end

return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob,
prob.problem_type, prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt,
k_discrete, k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, new_stages,
(; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs...))
resid₁_size, (; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs...))
end

"""
Expand Down Expand Up @@ -224,13 +229,21 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {
end

loss = if iip
function loss_internal!(resid::AbstractVector, u::AbstractVector, p = cache.p)
@views function loss_internal!(resid::AbstractVector,
u::AbstractVector,
p = cache.p)
y_ = recursive_unflatten!(cache.y, u)
resids = [get_tmp(r, u) for r in cache.residual]
eval_bc_residual!(resids[1], cache.problem_type, cache.bc, y_, p, cache.mesh)
resid_bc = if cache.problem_type isa TwoPointBVProblem
(resids[1][1:prod(cache.resid_size[1])],
resids[1][(prod(cache.resid_size[1]) + 1):end])
else
resids[1]
end
eval_bc_residual!(resid_bc, cache.problem_type, cache.bc, y_, p, cache.mesh)
Φ!(resids[2:end], cache, y_, u, p)
if cache.problem_type isa TwoPointBVProblem
recursive_flatten_twopoint!(resid, resids)
recursive_flatten_twopoint!(resid, resids, cache.resid_size)
else
recursive_flatten!(resid, resids)
end
Expand All @@ -242,7 +255,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {
resid_bc = eval_bc_residual(cache.problem_type, cache.bc, y_, p, cache.mesh)
resid_co = Φ(cache, y_, u, p)
if cache.problem_type isa TwoPointBVProblem
return vcat(resid_bc.x[1], mapreduce(vec, vcat, resid_co), resid_bc.x[2])
return vcat(resid_bc[1], mapreduce(vec, vcat, resid_co), resid_bc[2])
else
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
end
Expand All @@ -268,7 +281,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati

sd_collocation = if jac_alg.nonbc_diffmode isa AbstractSparseADType
PrecomputedJacobianColorvec(__generate_sparse_jacobian_prototype(cache,
cache.problem_type, y, cache.M, N))
cache.problem_type, y, y, cache.M, N))
else
NoSparsityDetection()
end
Expand Down Expand Up @@ -299,19 +312,20 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p)
end

function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, loss,
::TwoPointBVProblem) where {iip}
function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocation,
loss, ::TwoPointBVProblem) where {iip}
@unpack nlsolve, jac_alg = cache.alg
N = length(cache.mesh)

resid = ArrayPartition(cache.bcresid_prototype, similar(y, cache.M * (N - 1)))
resid = vcat(cache.bcresid_prototype[1:prod(cache.resid_size[1])],
similar(y, cache.M * (N - 1)),
cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])

# TODO: We can splitup the computation here as well similar to the Multiple Shooting
# TODO: code. That way for the BC part the actual jacobian computation is even cheaper
# TODO: Remember to not reorder if we end up using that implementation
sd = if jac_alg.diffmode isa AbstractSparseADType
PrecomputedJacobianColorvec(__generate_sparse_jacobian_prototype(cache,
cache.problem_type, resid.x[1], cache.M, N))
cache.problem_type, @view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]), cache.M,
N))
else
NoSparsityDetection()
end
Expand Down
Loading