scripts/reinforcement_learning/rsl_rl/train.py (2 changes: 1 addition & 1 deletion)

@@ -124,7 +124,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
         env = gym.wrappers.RecordVideo(env, **video_kwargs)

     # wrap around environment for rsl-rl
-    env = RslRlVecEnvWrapper(env)
+    env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)

     # create runner from rsl-rl
     runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device)
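
For orientation, a minimal stand-in sketch of the plumbing this hunk adds: the training script forwards whatever clip_actions the runner configuration carries, and the default of None keeps the previous (unclipped) behavior, so existing configs train exactly as before. The config class below is a plain-Python stub, not the IsaacLab RslRlOnPolicyRunnerCfg.

from dataclasses import dataclass


@dataclass
class AgentCfgStub:
    """Stand-in for RslRlOnPolicyRunnerCfg; only the field relevant here."""

    clip_actions: float | None = None


def wrapper_kwargs(agent_cfg: AgentCfgStub) -> dict:
    # mirrors: env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)
    return {"clip_actions": agent_cfg.clip_actions}


print(wrapper_kwargs(AgentCfgStub()))                  # {'clip_actions': None} -> no clipping
print(wrapper_kwargs(AgentCfgStub(clip_actions=1.0)))  # {'clip_actions': 1.0}
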
source/isaaclab_rl/isaaclab_rl/rsl_rl/rl_cfg.py (3 changes: 3 additions & 0 deletions)

@@ -98,6 +98,9 @@ class RslRlOnPolicyRunnerCfg:
     algorithm: RslRlPpoAlgorithmCfg = MISSING
     """The algorithm configuration."""

+    clip_actions: float | None = None
+    """The clipping value for actions. If ``None``, then no clipping is done."""
+
     ##
     # Checkpointing parameters
     ##
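
A sketch of how a task's agent configuration might opt into clipping once this field exists. The subclass name is hypothetical and the import paths are assumptions based on the file locations in this diff; a real config would also fill in the other required runner, policy, and algorithm fields as usual.

from isaaclab.utils import configclass  # assumed import path
from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg  # assumed import path


@configclass
class MyTaskPPORunnerCfg(RslRlOnPolicyRunnerCfg):
    """Hypothetical runner config that clamps policy actions to [-1, 1]."""

    clip_actions: float | None = 1.0  # default is None, i.e. no clipping
    # ... remaining runner/policy/algorithm fields as in any other agent config
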
source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py (17 changes: 14 additions & 3 deletions)

@@ -30,15 +30,15 @@ class RslRlVecEnvWrapper(VecEnv):
     https://github.com/leggedrobotics/rsl_rl/blob/master/rsl_rl/env/vec_env.py
     """

-    def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv):
+    def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, clip_actions: float | None = None):
         """Initializes the wrapper.

         Note:
             The wrapper calls :meth:`reset` at the start since the RSL-RL runner does not call reset.

         Args:
             env: The environment to wrap around.
-
+            clip_actions: The clipping value for actions. If ``None``, then no clipping is done.
         Raises:
             ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
         """
@@ -50,10 +50,14 @@ def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv):
             )
         # initialize the wrapper
         self.env = env
+        self.clip_actions = clip_actions
+
         # store information required by wrapper
         self.num_envs = self.unwrapped.num_envs
         self.device = self.unwrapped.device
         self.max_episode_length = self.unwrapped.max_episode_length
+
+        # obtain dimensions of the environment
         if hasattr(self.unwrapped, "action_manager"):
             self.num_actions = self.unwrapped.action_manager.total_action_dim
         else:
@@ -72,6 +76,7 @@ def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv):
             self.num_privileged_obs = gym.spaces.flatdim(self.unwrapped.single_observation_space["critic"])
         else:
             self.num_privileged_obs = 0
+
         # reset at the start since the RSL-RL runner does not call reset
         self.env.reset()
@@ -105,7 +110,10 @@ def observation_space(self) -> gym.Space:
     @property
     def action_space(self) -> gym.Space:
         """Returns the :attr:`Env` :attr:`action_space`."""
-        return self.env.action_space
+        if self.clip_actions is None:
+            return self.env.action_space
+        else:
+            return gym.spaces.Box(low=-self.clip_actions, high=self.clip_actions, shape=(self.num_actions,))

     @classmethod
     def class_name(cls) -> str:
@@ -160,6 +168,9 @@ def reset(self) -> tuple[torch.Tensor, dict]:  # noqa: D102
         return obs_dict["policy"], {"observations": obs_dict}

     def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
+        # clip actions
+        if self.clip_actions is not None:
+            actions = torch.clamp(actions, -self.clip_actions, self.clip_actions)
         # record step information
         obs_dict, rew, terminated, truncated, extras = self.env.step(actions)
         # compute dones for compatibility with RSL-RL
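
To see the runtime effect of the wrapper changes in isolation, here is a small self-contained check (only torch and gymnasium needed) of the two behaviors the diff adds: symmetric clamping of actions before env.step, and the Box action space reported when clipping is enabled. The bound of 1.0 and the action dimension of 12 are arbitrary example values.

import gymnasium as gym
import torch

clip_actions = 1.0  # same semantics as the wrapper argument; None would disable clipping
num_actions = 12

# what step() now does before forwarding actions to the environment
actions = torch.tensor([[-3.0, 0.5, 2.0]])
clipped = torch.clamp(actions, -clip_actions, clip_actions)
print(clipped)  # tensor([[-1.0000,  0.5000,  1.0000]])

# what the action_space property now reports when clipping is enabled
space = gym.spaces.Box(low=-clip_actions, high=clip_actions, shape=(num_actions,))
print(space)  # Box(-1.0, 1.0, (12,), float32)
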