Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 16 additions & 41 deletions torchdyn/core/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,40 +66,25 @@ def __init__(self, vector_field:Union[Callable, nn.Module], solver:Union[str, nn
self.vf.register_parameter('dummy_parameter', dummy_parameter)
self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])

# instantiates an underlying autograd.Function that overrides the backward pass with the intended version
# sensitivity algorithm
if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, solver, atol, rtol, interpolator,
solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
problem_type='standard').apply
elif self.sensalg == 'interpolated_adjoint':
self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, solver, atol, rtol, interpolator,
solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
problem_type='standard').apply


def _prep_odeint(self):
def _autograd_func(self):
"create autograd functions for backward pass"
self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
problem_type='standard').apply
elif self.sensalg == 'interpolated_adjoint':
self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
problem_type='standard').apply

problem_type='standard').apply

def odeint(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}):
"Returns Tuple(`t_eval`, `solution`)"
self._prep_odeint()
if self.sensalg == 'autograd':
return odeint(self.vf, x, t_span, self.solver, self.atol, self.rtol, interpolator=self.interpolator,
save_at=save_at, args=args)

else:
return self.autograd_function(self.vf_params, x, t_span, save_at, args)
return self._autograd_func()(self.vf_params, x, t_span, save_at, args)

def forward(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}):
"For safety redirects to intended method `odeint`"
Expand Down Expand Up @@ -128,39 +113,29 @@ def __init__(self, vector_field:Callable, solver:str, sensitivity:str='autograd'
self.parallel_solver = solver
self.fine_steps, self.maxiter = fine_steps, maxiter

def _autograd_func(self):
"create autograd functions for backward pass"
self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, solver, 0, 0, None,
solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
'multiple_shooting', fine_steps, maxiter).apply
return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
'multiple_shooting', self.fine_steps, self.maxiter).apply
elif self.sensalg == 'interpolated_adjoint':
self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, solver, 0, 0, None,
solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
'multiple_shooting', fine_steps, maxiter).apply

return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
'multiple_shooting', self.fine_steps, self.maxiter).apply
def odeint(self, x:Tensor, t_span:Tensor, B0:Tensor=None):
"Returns Tuple(`t_eval`, `solution`)"
self._prep_odeint()
if self.sensalg == 'autograd':
return odeint_mshooting(self.vf, x, t_span, self.parallel_solver, B0, self.fine_steps, self.maxiter)
else:
return self.autograd_function(self.vf_params, x, t_span, B0)
return self._autograd_func()(self.vf_params, x, t_span, B0)

def forward(self, x:Tensor, t_span:Tensor, B0:Tensor=None):
"For safety redirects to intended method `odeint`"
return self.odeint(x, t_span, B0)

def _prep_odeint(self):
"create autograd functions for backward pass"
self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
'multiple_shooting', self.fine_steps, self.maxiter).apply
elif self.sensalg == 'interpolated_adjoint':
self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
'multiple_shooting', self.fine_steps, self.maxiter).apply


class SDEProblem(nn.Module):
def __init__(self):
Expand Down
10 changes: 4 additions & 6 deletions torchdyn/numerics/sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ def generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator
def _gather_odefunc_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint,
atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4):
"Prepares definition of autograd.Function for adjoint sensitivity analysis of the above `ODEProblem`"
global _ODEProblemFuncAdjoint
class _ODEProblemFuncAdjoint(Function):
class _ODEProblemFunc(Function):
@staticmethod
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
Expand Down Expand Up @@ -98,15 +97,14 @@ def adjoint_dynamics(t, A):
λ_tspan = torch.stack([dLdt[0], dLdt[-1]])
return (μ, λ, λ_tspan, None, None, None)

return _ODEProblemFuncAdjoint
return _ODEProblemFunc


#TODO: introduce `t_span` grad as above
def _gather_odefunc_interp_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint,
atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4):
"Prepares definition of autograd.Function for interpolated adjoint sensitivity analysis of the above `ODEProblem`"
global _ODEProblemFuncInterpAdjoint
class _ODEProblemFuncInterpAdjoint(Function):
class _ODEProblemFunc(Function):
@staticmethod
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
Expand Down Expand Up @@ -160,4 +158,4 @@ def adjoint_dynamics(t, A):
λ, μ = λ.reshape(λT.shape), μ.reshape(μT.shape)
return (μ, λ, None, None, None)

return _ODEProblemFuncInterpAdjoint
return _ODEProblemFunc