18 changes: 15 additions & 3 deletions ml-agents/mlagents/trainers/ppo/policy.py
@@ -1,5 +1,7 @@
import logging
import numpy as np
from typing import Any, Dict
import tensorflow as tf

from mlagents.envs.timers import timed
from mlagents.trainers import BrainInfo, ActionInfo
@@ -190,15 +192,24 @@ def update(self, mini_batch, num_sequences):
run_out = self._execute_model(feed_dict, self.update_dict)
return run_out

def get_value_estimates(self, brain_info, idx):
def get_value_estimates(
self, brain_info: BrainInfo, idx: int, done: bool
) -> Dict[str, float]:
"""
Generates value estimates for bootstrapping.
:param brain_info: BrainInfo to be used for bootstrapping.
:param idx: Index in BrainInfo of agent.
:param done: Whether or not this is the last element of the episode, in which case we want the value estimate to be 0.
:return: The value estimate dictionary with key being the name of the reward signal and the value the
corresponding value estimate.
"""
feed_dict = {self.model.batch_size: 1, self.model.sequence_length: 1}
if done:
return {k: 0.0 for k in self.model.value_heads.keys()}

feed_dict: Dict[tf.Tensor, Any] = {
self.model.batch_size: 1,
self.model.sequence_length: 1,
}
for i in range(len(brain_info.visual_observations)):
feed_dict[self.model.visual_in[i]] = [
brain_info.visual_observations[i][idx]
@@ -214,7 +225,8 @@ def get_value_estimates(self, brain_info, idx):
idx
].reshape([-1, len(self.model.act_size)])
value_estimates = self.sess.run(self.model.value_heads, feed_dict)
return value_estimates

return {k: float(v) for k, v in value_estimates.items()}

def get_action(self, brain_info: BrainInfo) -> ActionInfo:
"""
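A brief usage sketch of the new signature. This assumes a PPOPolicy `policy` and a BrainInfo `brain_info` constructed as in the test added below; the reward-signal names in the comment are illustrative, not taken from this diff.

```python
# Usage sketch: `policy` and `brain_info` are assumed to be set up as in
# test_ppo_get_value_estimates below.
value_next = policy.get_value_estimates(brain_info, idx=0, done=False)
# e.g. {"extrinsic": 0.42, "curiosity": 0.07} -- one float per reward-signal value head

terminal = policy.get_value_estimates(brain_info, idx=0, done=True)
assert all(v == 0.0 for v in terminal.values())  # every head is zeroed at a true terminal state
```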
11 changes: 7 additions & 4 deletions ml-agents/mlagents/trainers/ppo/trainer.py
@@ -346,9 +346,12 @@ def process_experiences(
else:
bootstrapping_info = info
idx = l
value_next = self.policy.get_value_estimates(bootstrapping_info, idx)
if info.local_done[l] and not info.max_reached[l]:
value_next["extrinsic"] = 0.0
value_next = self.policy.get_value_estimates(
bootstrapping_info,
idx,
info.local_done[l] and not info.max_reached[l],
)

Contributor: Is it possible to determine this just from bootstrapping_info and idx within get_value_estimates()? I couldn't quite convince myself when I was looking at it.

Contributor (Author): Hmm, I'm not sure - it seems like bootstrapping_info becomes the previous info if the not-condition is met. It seems like we'll run into an issue if both of those conditions are met, since bootstrapping_info will be something different and we can't figure out whether info.local_done is True.

Contributor: OK, that's more or less what I thought. Sounds good as it is!
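A tiny self-contained illustration of the point raised in this thread, using a hypothetical FakeBrainInfo stand-in (not the real BrainInfo class): when an episode is cut off by max_step, the value is bootstrapped from the previous step's BrainInfo, which no longer reflects the current step's local_done, so the flag has to be computed by the caller.

```python
from collections import namedtuple

# Hypothetical stand-in for BrainInfo, just enough fields for the illustration.
FakeBrainInfo = namedtuple("FakeBrainInfo", ["local_done", "max_reached"])

l = 0
previous_info = FakeBrainInfo(local_done=[False], max_reached=[False])
current_info = FakeBrainInfo(local_done=[True], max_reached=[True])  # cut off by max_step

# Mirrors the selection described above: bootstrap from the previous BrainInfo
# when the episode was interrupted by max_step.
bootstrapping_info = previous_info if current_info.max_reached[l] else current_info

# Seen from inside get_value_estimates(), the episode does not look finished,
# so `done` cannot be recovered from bootstrapping_info alone...
print(bootstrapping_info.local_done[l])  # False

# ...and the trainer therefore computes it from the current info and passes it in.
done = current_info.local_done[l] and not current_info.max_reached[l]
print(done)  # False here: a max_step cutoff should still be bootstrapped
```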

tmp_advantages = []
tmp_returns = []
for name in self.policy.reward_signals:
@@ -507,7 +510,7 @@ def get_gae(rewards, value_estimates, value_next=0.0, gamma=0.99, lambd=0.95):
:param lambd: GAE weighing factor.
:return: list of advantage estimates for time-steps t to T.
"""
value_estimates = np.asarray(value_estimates.tolist() + [value_next])
value_estimates = np.append(value_estimates, value_next)
delta_t = rewards + gamma * value_estimates[1:] - value_estimates[:-1]
advantage = discount_rewards(r=delta_t, gamma=gamma * lambd)
return advantage
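To see how the bootstrap value produced by get_value_estimates() enters this function, here is a small self-contained sketch of the GAE computation. The discount_rewards helper is re-implemented for the example and is an assumption about the module's helper, not code quoted from it.

```python
import numpy as np

def discount_rewards(r, gamma=0.99, value_next=0.0):
    # Assumed behavior of the helper: discounted cumulative sum of r, seeded with value_next.
    discounted = np.zeros_like(r)
    running = value_next
    for t in reversed(range(len(r))):
        running = r[t] + gamma * running
        discounted[t] = running
    return discounted

def get_gae(rewards, value_estimates, value_next=0.0, gamma=0.99, lambd=0.95):
    # Same shape as the diff above: append the bootstrap value, form the TD errors,
    # then discount them with gamma * lambd.
    value_estimates = np.append(value_estimates, value_next)
    delta_t = rewards + gamma * value_estimates[1:] - value_estimates[:-1]
    return discount_rewards(r=delta_t, gamma=gamma * lambd)

rewards = np.array([0.1, 0.1, 1.0])
values = np.array([0.5, 0.6, 0.7])

# Mid-episode or max_step cutoff: bootstrap with the next state's value estimate.
print(get_gae(rewards, values, value_next=0.4))
# True terminal step: get_value_estimates(..., done=True) returns 0.0, so no bootstrap.
print(get_gae(rewards, values, value_next=0.0))
```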
31 changes: 31 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_ppo.py
@@ -66,6 +66,37 @@ def test_ppo_policy_evaluate(mock_communicator, mock_launcher, dummy_config):
env.close()


@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_ppo_get_value_estimates(mock_communicator, mock_launcher, dummy_config):
tf.reset_default_graph()
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]

trainer_parameters = dummy_config
model_path = env.brain_names[0]
trainer_parameters["model_path"] = model_path
trainer_parameters["keep_checkpoints"] = 3
policy = PPOPolicy(
0, env.brains[env.brain_names[0]], trainer_parameters, False, False
)
run_out = policy.get_value_estimates(brain_info, 0, done=False)
for key, val in run_out.items():
assert type(key) is str
assert type(val) is float

run_out = policy.get_value_estimates(brain_info, 0, done=True)
for key, val in run_out.items():
assert type(key) is str
assert val == 0.0

env.close()


@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_ppo_model_cc_vector(mock_communicator, mock_launcher):