diff --git a/torchdyn/core/problems.py b/torchdyn/core/problems.py index e4b1bee..a3e66aa 100644 --- a/torchdyn/core/problems.py +++ b/torchdyn/core/problems.py @@ -66,40 +66,25 @@ def __init__(self, vector_field:Union[Callable, nn.Module], solver:Union[str, nn self.vf.register_parameter('dummy_parameter', dummy_parameter) self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()]) - # instantiates an underlying autograd.Function that overrides the backward pass with the intended version - # sensitivity algorithm - if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature - self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, solver, atol, rtol, interpolator, - solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss, - problem_type='standard').apply - elif self.sensalg == 'interpolated_adjoint': - self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, solver, atol, rtol, interpolator, - solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss, - problem_type='standard').apply - - - def _prep_odeint(self): + def _autograd_func(self): "create autograd functions for backward pass" self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()]) if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature - self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator, + return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator, self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, problem_type='standard').apply elif self.sensalg == 'interpolated_adjoint': - self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator, + return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator, self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, - problem_type='standard').apply - + problem_type='standard').apply def odeint(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}): "Returns Tuple(`t_eval`, `solution`)" - self._prep_odeint() if self.sensalg == 'autograd': return odeint(self.vf, x, t_span, self.solver, self.atol, self.rtol, interpolator=self.interpolator, save_at=save_at, args=args) - else: - return self.autograd_function(self.vf_params, x, t_span, save_at, args) + return self._autograd_func()(self.vf_params, x, t_span, save_at, args) def forward(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}): "For safety redirects to intended method `odeint`" @@ -128,39 +113,29 @@ def __init__(self, vector_field:Callable, solver:str, sensitivity:str='autograd' self.parallel_solver = solver self.fine_steps, self.maxiter = fine_steps, maxiter + def _autograd_func(self): + "create autograd functions for backward pass" + self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()]) if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature - self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, solver, 0, 0, None, - solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss, - 'multiple_shooting', fine_steps, maxiter).apply + return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None, + self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, + 'multiple_shooting', self.fine_steps, self.maxiter).apply elif self.sensalg == 'interpolated_adjoint': - self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, solver, 0, 0, None, - solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss, - 'multiple_shooting', fine_steps, maxiter).apply - + return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None, + self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, + 'multiple_shooting', self.fine_steps, self.maxiter).apply + def odeint(self, x:Tensor, t_span:Tensor, B0:Tensor=None): "Returns Tuple(`t_eval`, `solution`)" - self._prep_odeint() if self.sensalg == 'autograd': return odeint_mshooting(self.vf, x, t_span, self.parallel_solver, B0, self.fine_steps, self.maxiter) else: - return self.autograd_function(self.vf_params, x, t_span, B0) + return self._autograd_func()(self.vf_params, x, t_span, B0) def forward(self, x:Tensor, t_span:Tensor, B0:Tensor=None): "For safety redirects to intended method `odeint`" return self.odeint(x, t_span, B0) - def _prep_odeint(self): - "create autograd functions for backward pass" - self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()]) - if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature - self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None, - self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, - 'multiple_shooting', self.fine_steps, self.maxiter).apply - elif self.sensalg == 'interpolated_adjoint': - self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None, - self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss, - 'multiple_shooting', self.fine_steps, self.maxiter).apply - class SDEProblem(nn.Module): def __init__(self): diff --git a/torchdyn/numerics/sensitivity.py b/torchdyn/numerics/sensitivity.py index 54de164..503b21e 100644 --- a/torchdyn/numerics/sensitivity.py +++ b/torchdyn/numerics/sensitivity.py @@ -32,8 +32,7 @@ def generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator def _gather_odefunc_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4): "Prepares definition of autograd.Function for adjoint sensitivity analysis of the above `ODEProblem`" - global _ODEProblemFuncAdjoint - class _ODEProblemFuncAdjoint(Function): + class _ODEProblemFunc(Function): @staticmethod def forward(ctx, vf_params, x, t_span, B=None, save_at=()): t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B, @@ -98,15 +97,14 @@ def adjoint_dynamics(t, A): λ_tspan = torch.stack([dLdt[0], dLdt[-1]]) return (μ, λ, λ_tspan, None, None, None) - return _ODEProblemFuncAdjoint + return _ODEProblemFunc #TODO: introduce `t_span` grad as above def _gather_odefunc_interp_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4): "Prepares definition of autograd.Function for interpolated adjoint sensitivity analysis of the above `ODEProblem`" - global _ODEProblemFuncInterpAdjoint - class _ODEProblemFuncInterpAdjoint(Function): + class _ODEProblemFunc(Function): @staticmethod def forward(ctx, vf_params, x, t_span, B=None, save_at=()): t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B, @@ -160,4 +158,4 @@ def adjoint_dynamics(t, A): λ, μ = λ.reshape(λT.shape), μ.reshape(μT.shape) return (μ, λ, None, None, None) - return _ODEProblemFuncInterpAdjoint \ No newline at end of file + return _ODEProblemFunc \ No newline at end of file