Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,6 @@
from time import strftime
from copy import deepcopy

from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines import PPO2

import ray
from ray import tune
from ray.tune import run_experiments
from ray.tune.registry import register_env
try:
from ray.rllib.agents.agent import get_agent_class
except ImportError:
from ray.rllib.agents.registry import get_agent_class

from flow.core.util import ensure_dir
from flow.utils.registry import env_constructor
from flow.utils.rllib import FlowParamsEncoder, get_flow_params
Expand Down Expand Up @@ -94,6 +82,9 @@ def run_model_stablebaseline(flow_params,
stable_baselines.*
the trained model
"""
from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines import PPO2

if num_cpus == 1:
constructor = env_constructor(params=flow_params, version=0)()
# The algorithms require a vectorized environment to run
Expand Down Expand Up @@ -139,6 +130,13 @@ def setup_exps_rllib(flow_params,
dict
training configuration parameters
"""
from ray import tune
from ray.tune.registry import register_env
try:
from ray.rllib.agents.agent import get_agent_class
except ImportError:
from ray.rllib.agents.registry import get_agent_class

horizon = flow_params['env'].horizon

alg_run = "PPO"
Expand Down Expand Up @@ -181,6 +179,9 @@ def setup_exps_rllib(flow_params,

def train_rllib(submodule, flags):
"""Train policies using the PPO algorithm in RLlib."""
import ray
from ray.tune import run_experiments

flow_params = submodule.flow_params
n_cpus = submodule.N_CPUS
n_rollouts = submodule.N_ROLLOUTS
Expand Down Expand Up @@ -216,7 +217,7 @@ def train_h_baselines(flow_params, args, multiagent):
"""Train policies using SAC and TD3 with h-baselines."""
from hbaselines.algorithms import OffPolicyRLAlgorithm
from hbaselines.utils.train import parse_options, get_hyperparameters
from hbaselines.envs.mixed_autonomy.envs import FlowEnv
from hbaselines.envs.mixed_autonomy import FlowEnv

flow_params = deepcopy(flow_params)

Expand Down Expand Up @@ -317,6 +318,9 @@ def train_h_baselines(flow_params, args, multiagent):

def train_stable_baselines(submodule, flags):
"""Train policies using the PPO algorithm in stable-baselines."""
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

flow_params = submodule.flow_params
# Path to the saved files
exp_tag = flow_params['exp_tag']
Expand Down