diff --git a/flow_matching/solver/ode_solver.py b/flow_matching/solver/ode_solver.py index 2ee8a17..c924c05 100644 --- a/flow_matching/solver/ode_solver.py +++ b/flow_matching/solver/ode_solver.py @@ -181,7 +181,7 @@ def dynamics_func(t, states): sol, log_det = odeint( dynamics_func, y_init, - time_grid, + time_grid.to(x_1.device), method=method, options=ode_opts, atol=atol,