Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
5d09fcb
Add Built In Feature Extraction
glvov-bdai Oct 8, 2024
f7ef078
Update environments.rst
glvov-bdai Oct 8, 2024
12dadd1
Update environments.rst
glvov-bdai Oct 8, 2024
d907803
Update environments.rst
glvov-bdai Oct 8, 2024
bf2238a
convert obs to class
glvov-bdai Oct 9, 2024
11abbcf
Merge branch 'isaac-sim:main' into feature/preprocess_observation_upd…
glvov-bdai Oct 9, 2024
502b882
Update observations.py
garylvov Oct 10, 2024
ab42141
add vision transformer
garylvov Oct 11, 2024
df2b1ff
simplify MLP
garylvov Oct 11, 2024
a490961
change model
garylvov Oct 11, 2024
2343e35
Update pyproject.toml
glvov-bdai Oct 11, 2024
f92957e
Merge branch 'isaac-sim:main' into feature/preprocess_observation_upd…
glvov-bdai Oct 11, 2024
72a9fbf
update
glvov-bdai Oct 12, 2024
9447165
Merge branch 'isaac-sim:main' into feature/preprocess_observation_upd…
glvov-bdai Oct 12, 2024
5336ad6
Update feature_extractor.py
glvov-bdai Oct 12, 2024
3744dde
update env names
glvov-bdai Oct 12, 2024
953716b
Merge branch 'feature/preprocess_observation_updated' of https://gith…
glvov-bdai Oct 12, 2024
abb8743
consistent import ordering
glvov-bdai Oct 12, 2024
1e81c87
Update observations.py
glvov-bdai Oct 15, 2024
c4cbb95
Update cartpole_camera_env_cfg.py
glvov-bdai Oct 15, 2024
a06033e
formatting
glvov-bdai Oct 21, 2024
4c43aca
format but actually good this time
glvov-bdai Oct 21, 2024
4f61f88
fix to inherit from obs group
glvov-bdai Oct 21, 2024
ecb65ce
Merge branch 'isaac-sim:main' into feature/preprocess_observation_upd…
glvov-bdai Oct 22, 2024
27451f9
Merge branch 'main' into feature/preprocess_observation_updated
Oct 24, 2024
8829417
Merge branch 'main' into feature/preprocess_observation_updated
Oct 24, 2024
01bda48
Merge branch 'main' into feature/preprocess_observation_updated
glvov-bdai Oct 24, 2024
e9de97f
Merge branch 'main' into feature/preprocess_observation_updated
glvov-bdai Oct 28, 2024
cf79302
formatting
glvov-bdai Oct 28, 2024
ca775fa
Add changelog about cartpole renaming
glvov-bdai Oct 28, 2024
05ac736
Update source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/obser…
glvov-bdai Oct 28, 2024
c1f412f
address James' comments
glvov-bdai Oct 28, 2024
4bb7993
Merge branch 'main' into feature/preprocess_observation_updated
glvov-bdai Oct 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/overview/environments.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty
| | | and perceptive inputs |
| | |cartpole-depth-link| | |
| | | |
| | |cartpole-resnet-link| | |
| | | |
| | |cartpole-rgb-direct-link| | |
| | | |
| | |cartpole-depth-direct-link|| |
Expand All @@ -71,6 +73,7 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty
.. |cartpole-link| replace:: `Isaac-Cartpole-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_env_cfg.py>`__
.. |cartpole-rgb-link| replace:: `Isaac-Cartpole-RGB-Camera-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-depth-link| replace:: `Isaac-Cartpole-Depth-Camera-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
.. |cartpole-resnet-link| replace:: `Isaac-Cartpole-ResNet18-Camera-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__

.. |humanoid-direct-link| replace:: `Isaac-Humanoid-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/humanoid/humanoid_env.py>`__
.. |ant-direct-link| replace:: `Isaac-Ant-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/ant/ant_env.py>`__
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ known_third_party = [
"warp",
"carb",
"Semantics",
"torchvision"
]
# Imports from this repository
known_first_party = "omni.isaac.lab"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@
import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.assets import Articulation, RigidObject
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.managers.manager_base import ManagerTermBase
from omni.isaac.lab.managers.manager_term_cfg import ObservationTermCfg
from omni.isaac.lab.sensors import Camera, RayCaster, RayCasterCamera, TiledCamera

if TYPE_CHECKING:
from omni.isaac.lab.envs import ManagerBasedEnv, ManagerBasedRLEnv

from torchvision import models

"""
Root state.
"""
Expand Down Expand Up @@ -231,6 +235,75 @@ def image(
return images.clone()


class image_features(ManagerTermBase):
"""Extracted image features with a frozen encoder from images of a specific datatype from the camera sensor.

Calls :meth:`image` to get the images, then performs inference. On initialization,
for a model zoo different from the default, define model_zoo_cfg: A dictionary with string keys and callable values.
Should include "model", (mapped to a callable with no arguments to return the model), "preprocess" (mapped to
a callable which consumes the images and returns the preprocessed images),
and "inference" (mapped to a callable that provided the model, and the preproccessed images, returns the features.)
"""

def __init__(
self,
cfg: ObservationTermCfg,
env: ManagerBasedEnv,
model_zoo_cfg: dict | None = None,
initialize_all: bool = False,
):
super().__init__(cfg, env)
if model_zoo_cfg is None:
self.model_zoo_cfg = {
"ResNet18": {
"model": lambda: models.resnet18(pretrained=True).eval().to("cuda:0"),
"preprocess": lambda img: (
img.permute(0, 3, 1, 2) # Convert [batch, height, width, 3] -> [batch, 3, height, width]
# Normalize in the format expected by pytorch; https://pytorch.org/hub/pytorch_vision_resnet/
- torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
"inference": lambda model, images: model(images),
},
}
self.reset_model(initialize_all=initialize_all)

# The following is named reset_model instead of reset as otherwise, it's called at the end of every episode
def reset_model(self, model_name: str | None = None, initialize_all: bool = False):
if model_name is None:
print("[WARNING]: No model name supplied, emptying entire model zoo.")
self.model_zoo = {}
elif model_name is not None:
self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]()
if initialize_all:
for model_name, model_callables in self.model_zoo_cfg.items():
self.model_zoo[model_name] = model_callables["model"]()

def __call__(
self,
env: ManagerBasedEnv,
sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"),
data_type: str = "rgb",
convert_perspective_to_orthogonal: bool = False,
model_name: str = "ResNet18",
) -> torch.Tensor:
if model_name not in self.model_zoo:
print(f"[INFO]: Adding {model_name} to the model zoo")
self.model_zoo[model_name] = self.model_zoo_cfg[model_name]["model"]()

images = image(
env=env,
sensor_cfg=sensor_cfg,
data_type=data_type,
convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
normalize=True, # want this for training stability
)

proc_images = self.model_zoo_cfg[model_name]["preprocess"](images)
features = self.model_zoo_cfg[model_name]["inference"](self.model_zoo[model_name], proc_images)

return features


"""
Actions.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import gymnasium as gym

from . import agents
from .cartpole_camera_env_cfg import CartpoleDepthCameraEnvCfg, CartpoleRGBCameraEnvCfg
from .cartpole_camera_env_cfg import CartpoleDepthCameraEnvCfg, CartpoleResNet18CameraEnv, CartpoleRGBCameraEnvCfg
from .cartpole_env_cfg import CartpoleEnvCfg

##
Expand Down Expand Up @@ -49,3 +49,13 @@
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_camera_ppo_cfg.yaml",
},
)

gym.register(
id="Isaac-Cartpole-ResNet18-Camera-v0",
entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv",
disable_env_checker=True,
kwargs={
"env_cfg_entry_point": CartpoleResNet18CameraEnv,
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_feature_ppo_cfg.yaml",
},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
params:
seed: 42

# environment wrapper clipping
env:
# added to the wrapper
clip_observations: 5.0
# can make custom wrapper?
clip_actions: 1.0

algo:
name: a2c_continuous

model:
name: continuous_a2c_logstd

# doesn't have this fine grained control but made it close
network:
name: actor_critic
separate: False
space:
continuous:
mu_activation: None
sigma_activation: None

mu_init:
name: default
sigma_init:
name: const_initializer
val: 0
fixed_sigma: True
mlp:
units: [32, 64, 64]
activation: elu
d2rl: False

initializer:
name: default
regularizer:
name: None

load_checkpoint: False # flag which sets whether to load the checkpoint
load_path: '' # path to the checkpoint to load

config:
name: cartpole_features
env_name: rlgpu
device: 'cuda:0'
device_name: 'cuda:0'
multi_gpu: False
ppo: True
mixed_precision: False
normalize_input: True
normalize_value: True
value_bootstraop: True
num_actors: -1 # configured from the script (based on num_envs)
reward_shaper:
scale_value: 1.0
normalize_advantage: True
gamma: 0.99
tau : 0.95
learning_rate: 3e-4
lr_schedule: adaptive
kl_threshold: 0.008
score_to_win: 20000
max_epochs: 5000
save_best_after: 50
save_frequency: 25
grad_norm: 1.0
entropy_coef: 0.0
truncate_grads: True
e_clip: 0.2
horizon_length: 16
minibatch_size: 2048
mini_epochs: 8
critic_coef: 4
clip_value: True
seq_length: 4
bounds_loss_coef: 0.0001
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ class DepthCameraPolicyCfg(RGBObservationsCfg.RGBCameraPolicyCfg):
policy: ObsGroup = DepthCameraPolicyCfg()


@configclass
class ResNet18ObservationCfg:
@configclass
class FeaturesCameraPolicyCfg(RGBObservationsCfg.RGBCameraPolicyCfg):
image = ObsTerm(
func=mdp.image_features,
params={"sensor_cfg": SceneEntityCfg("tiled_camera"), "data_type": "rgb", "model_name": "ResNet18"},
)

policy: ObsGroup = FeaturesCameraPolicyCfg()


##
# Environment configuration
##
Expand All @@ -107,3 +119,8 @@ class CartpoleDepthCameraEnvCfg(CartpoleEnvCfg):

scene: CartpoleSceneCfg = CartpoleDepthCameraSceneCfg(num_envs=1024, env_spacing=20)
observations: DepthObservationsCfg = DepthObservationsCfg()


@configclass
class CartpoleResNet18CameraEnv(CartpoleRGBCameraEnvCfg):
observations: ResNet18ObservationCfg = ResNet18ObservationCfg()