diff --git a/scripts/rsl_rl/play.py b/scripts/rsl_rl/play.py index e3813484..dde86d4d 100644 --- a/scripts/rsl_rl/play.py +++ b/scripts/rsl_rl/play.py @@ -95,10 +95,10 @@ def main(): # export policy to onnx/jit export_model_dir = os.path.join(os.path.dirname(resume_path), "exported") export_policy_as_jit( - ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt" + ppo_runner.alg.policy, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt" ) export_policy_as_onnx( - ppo_runner.alg.actor_critic, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx" + ppo_runner.alg.policy, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx" ) # reset environment