Ayush Jain*, Andrew Szot*, Joseph J. Lim at USC CLVR lab
[Paper website]
The structure of the repository:
- analysis: Scripts used for analysis figures and experiments.
- envs: The four environments, one per subfolder.
- method: Implementation of the method and all baselines.
- rlf: Reinforcement Learning Framework. General RL / PPO training code.
- scripts: Miscellaneous scripts, including the script for generating the train / test action set splits.
- main.py: Entry point for running policy.
- embedder.py: Entry point for training embedder.
Log directories:
- data/trained_model/ENV-NAME_PREFIX/: Trained models.
- data/vids/ENV-NAME/: Evaluation videos.
- data/logs/ENV-NAME/PREFIX/: Tensorboard summary.
- Python 3.7
- MuJoCo 2.0
All Python package requirements are listed in requirements.txt. If you are using conda, you can set up an environment with the following commands (Python 3.7.3):
conda create -n [your_name] python=3.7
source activate [your_name]
pip install -r requirements.txt
The experiment flow is the same for every environment and always follows these steps:
- Generate train and test action splits: python gen_action_sets.py --env-name $ENV_NAME
- Generate action datasets for the environment: python embedder.py --env-name $PLAY_ENV_NAME --save-dataset
- Train action embedder model: python embedder.py --env-name $PLAY_ENV_NAME --save-emb-model-file $EMB_FILE_NAME --train-embeddings
- Generate embedding files: python main.py --env-name $ENV_NAME --play-env-name $PLAY_ENV_NAME --load-emb-model-file $EMB_MODEL_NAME --save-embeddings-file $EMB_FILE_NAME --prefix main
- Train policy with saved embeddings: python main.py --env-name $ENV_NAME --load-embeddings-file $EMB_FILE_NAME
Note:
(1) $EMB_MODEL_NAME must be $EMB_FILE_NAME-htvae-500.m if your model is trained for at least 500 epochs (specified by --emb-epochs).
(2) Use --n-trajectories 64 and --emb-epochs 500 for faster data generation and embedder training.
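The five steps above can be sketched as a small helper that assembles the commands in order. This is only an illustration, not part of the repository: the function name build_pipeline_commands and the model_epochs parameter are hypothetical, and the model file name is assumed to follow the $EMB_FILE_NAME-htvae-N.m pattern from Note (1). Check the file actually written to disk before relying on it.

```python
import shlex


def build_pipeline_commands(env_name, play_env_name, emb_file_name, model_epochs):
    """Return the five pipeline commands, in order, as argument lists.

    model_epochs is the number that appears in the saved model file name.
    """
    # Assumption: the embedder checkpoint is named <emb_file_name>-htvae-<N>.m,
    # as in Note (1) of the README.
    emb_model_name = f"{emb_file_name}-htvae-{model_epochs}.m"
    return [
        # 1. Train / test action splits
        ["python", "gen_action_sets.py", "--env-name", env_name],
        # 2. Action datasets from the play environment
        ["python", "embedder.py", "--env-name", play_env_name, "--save-dataset"],
        # 3. Action embedder model
        ["python", "embedder.py", "--env-name", play_env_name,
         "--save-emb-model-file", emb_file_name, "--train-embeddings"],
        # 4. Embedding files
        ["python", "main.py", "--env-name", env_name,
         "--play-env-name", play_env_name,
         "--load-emb-model-file", emb_model_name,
         "--save-embeddings-file", emb_file_name, "--prefix", "main"],
        # 5. Policy training with the saved embeddings
        ["python", "main.py", "--env-name", env_name,
         "--load-embeddings-file", emb_file_name],
    ]


if __name__ == "__main__":
    commands = build_pipeline_commands(
        "CreateLevelPush-v0", "StateCreateGameN1PlayNew-v0", "create_st", 5000
    )
    for cmd in commands:
        # Print each command in a form that can be pasted into a shell.
        print(" ".join(shlex.quote(arg) for arg in cmd))
```

Returning argument lists rather than strings means each command can be passed directly to subprocess.run without shell quoting concerns.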
Below are the example commands used for each environment and method approach.
$ENV_NAME = 'CreateLevelPush-v0' or 'CreateLevelNavigate-v0' or 'CreateLevelObstacle-v0'.
$PLAY_ENV_NAME = 'StateCreateGameN1PlayNew-v0' (state-based) or 'CreateGamePlay-v0' (video-based).
$EMB_FILE_NAME = 'create_st' (state-based) or 'create_im' (video-based)
(1) Train policy directly with:
python main.py --env-name CreateLevelPush-v0 --prefix main
python main.py --env-name CreateLevelNavigate-v0 --prefix main
python main.py --env-name CreateLevelObstacle-v0 --prefix main
OR
(2) For full procedure, follow these commands:
- Generate Splits: python gen_action_sets.py --env-name CreateLevelPush-v0
- Generate Data: python embedder.py --env-name StateCreateGameN1PlayNew-v0 --save-dataset
- Train Action Embedder: python embedder.py --env-name StateCreateGameN1PlayNew-v0 --save-emb-model-file create_st --train-embeddings
- Generate embedding files: python main.py --env-name CreateLevelPush-v0 --play-env-name StateCreateGameN1PlayNew-v0 --load-emb-model-file create_st-htvae-5000.m --save-embeddings-file create_st --prefix main
- Train policy with saved embeddings: python main.py --env-name CreateLevelPush-v0 --load-embeddings-file create_st --prefix main
The recommender system environment requires no data generation or embedding learning.
$ENV_NAME = 'RecoEnv-v0'
(1) Train policy directly with:
python main.py --env-name RecoEnv-v0 --prefix main
OR
(2) For full procedure, follow these commands:
- Generate Splits: python gen_action_sets.py --env-name RecoEnv-v0
- Policy: python main.py --env-name RecoEnv-v0 --prefix main
$ENV_NAME = 'StackEnv-v0'
$PLAY_ENV_NAME = 'BlockPlayImg-v0'
$EMB_FILE_NAME = 'stack_im'
(1) Train policy directly with:
python main.py --env-name StackEnv-v0 --prefix main
OR
(2) For full procedure, follow these commands:
- Generate Splits: python gen_action_sets.py --env-name StackEnv-v0
- Generate Data: python embedder.py --env-name BlockPlayImg-v0 --save-dataset
- Train Action Embedder: python embedder.py --env-name BlockPlayImg-v0 --save-emb-model-file stack_im --train-embeddings
- Generate embedding files: python main.py --env-name StackEnv-v0 --play-env-name BlockPlayImg-v0 --load-emb-model-file stack_im-htvae-5000.m --save-embeddings-file stack_im --prefix main
- Train policy with saved embeddings: python main.py --env-name StackEnv-v0 --load-embeddings-file stack_im --prefix main
$ENV_NAME = 'MiniGrid-LavaCrossingS9N1-v0'
$PLAY_ENV_NAME = 'MiniGrid-Empty-Random-80x80-v0'
$EMB_FILE_NAME = 'gw_onehot_new'
(1) Train policy directly with:
python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --prefix main
OR
(2) For full procedure, follow these commands:
- Generate Splits: python gen_action_sets.py --env-name MiniGrid-LavaCrossingS9N1-v0
- Generate Data: python embedder.py --env-name MiniGrid-Empty-Random-80x80-v0 --save-dataset
- Train Action Embedder: python embedder.py --env-name MiniGrid-Empty-Random-80x80-v0 --save-emb-model-file gw_onehot_new --train-embeddings
- Generate embedding files: python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --play-env-name MiniGrid-Empty-Random-80x80-v0 --load-emb-model-file gw_onehot_new-htvae-5000.m --save-embeddings-file gw_onehot_new --prefix main
- Train policy with saved embeddings: python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --load-embeddings-file gw_onehot_new --prefix main
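The per-environment variables listed above can be collected into one lookup table, which makes it easier to see which environments need the embedder steps. This is a sketch for illustration only: the dictionary keys, the ENVIRONMENTS name, and the needs_embedder helper are hypothetical, while the values are copied from the listings above.

```python
# Values copied from the per-environment listings in this README.
# None marks a variable the README does not define for that environment.
ENVIRONMENTS = {
    "CREATE (state-based)": {
        "env_names": ["CreateLevelPush-v0", "CreateLevelNavigate-v0",
                      "CreateLevelObstacle-v0"],
        "play_env_name": "StateCreateGameN1PlayNew-v0",
        "emb_file_name": "create_st",
    },
    "CREATE (video-based)": {
        "env_names": ["CreateLevelPush-v0", "CreateLevelNavigate-v0",
                      "CreateLevelObstacle-v0"],
        "play_env_name": "CreateGamePlay-v0",
        "emb_file_name": "create_im",
    },
    "Recommender": {
        "env_names": ["RecoEnv-v0"],
        "play_env_name": None,  # no data generation or embedding learning
        "emb_file_name": None,
    },
    "Shape Stacking": {
        "env_names": ["StackEnv-v0"],
        "play_env_name": "BlockPlayImg-v0",
        "emb_file_name": "stack_im",
    },
    "Grid World": {
        "env_names": ["MiniGrid-LavaCrossingS9N1-v0"],
        "play_env_name": "MiniGrid-Empty-Random-80x80-v0",
        "emb_file_name": "gw_onehot_new",
    },
}


def needs_embedder(label):
    """Return True if the environment has a play environment to embed from."""
    return ENVIRONMENTS[label]["play_env_name"] is not None
```

For example, needs_embedder("Recommender") is False, matching the shorter two-step procedure given for RecoEnv-v0.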
To run the baselines for any environment, add the following to the main command:
Baselines
- Nearest-Neighbor (NN): --nearest-neighbor --fixed-action-set --action-random-sample False --prefix NN
- Distance-based Policy Architecture (Dist): --distance-based --prefix dist
- Non-hierarchical embeddings (VAE): --load-embeddings-file $FILE --prefix vae, where $FILE, the file storing these embeddings, depends on the environment:
  - CREATE: create_fc_st_vae
  - Shape Stacking: stack_vae
  - Grid World: gw_onehot_vae
Ablations
- Fixed Action Space (FX): --fixed-action-set --action-random-sample False --prefix FX
- Random-Sampling without clustering (RS): --sample-clusters False --prefix RS
- No-entropy (NE): --entropy-coef 0. --prefix NE
Other embedding data formats
- CREATE: Video-based embeddings: --load-embeddings-file create_fc_im --o-dim 128 --z-dim 128 --prefix im
- Grid World: (x,y) coordinate state-based embeddings: --load-embeddings-file gw_st --prefix st
Ground-truth embeddings
- CREATE and Grid World: --gt-embs --prefix GT
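The flag sets above can be kept in a table and appended to a base command programmatically. The sketch below is an illustration, not repository code: VARIANT_FLAGS and with_variant are hypothetical names, and it assumes that a variant's --prefix should replace the base command's --prefix main rather than be passed alongside it. The variants that load a specific embeddings file (VAE, video-based, coordinate-based) are left out for brevity.

```python
# Flags copied from the baseline, ablation, and ground-truth listings above.
VARIANT_FLAGS = {
    "NN": ["--nearest-neighbor", "--fixed-action-set",
           "--action-random-sample", "False", "--prefix", "NN"],
    "dist": ["--distance-based", "--prefix", "dist"],
    "FX": ["--fixed-action-set", "--action-random-sample", "False",
           "--prefix", "FX"],
    "RS": ["--sample-clusters", "False", "--prefix", "RS"],
    "NE": ["--entropy-coef", "0.", "--prefix", "NE"],
    "GT": ["--gt-embs", "--prefix", "GT"],
}


def with_variant(base_cmd, variant):
    """Return a copy of base_cmd with the variant's flags appended.

    Assumption: any existing "--prefix VALUE" pair in base_cmd is dropped so
    that only the variant's prefix remains.
    """
    cmd = []
    skip_next = False
    for arg in base_cmd:
        if skip_next:
            skip_next = False  # this is the old prefix value; drop it
            continue
        if arg == "--prefix":
            skip_next = True  # drop the flag itself and then its value
            continue
        cmd.append(arg)
    return cmd + VARIANT_FLAGS[variant]


if __name__ == "__main__":
    base = ["python", "main.py", "--env-name", "StackEnv-v0", "--prefix", "main"]
    print(" ".join(with_variant(base, "NE")))
```

Giving each variant its own prefix keeps its trained models and logs in separate directories under data/.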
To reproduce the analysis, run each of the three analysis scripts:
- analysis/analysis_dist.py
- analysis/analysis_emb.py
- analysis/analysis_ratio.py
- The PPO code is based on the PyTorch implementation of PPO by Ilya Kostrikov
- The Grid world environment is from https://github.com/maximecb/gym-minigrid
- The recommender systems environment is from https://github.com/criteo-research/reco-gym