Return deterministic actions #5597
```diff
@@ -389,6 +389,7 @@ def test_exportable_settings(use_defaults):
       init_entcoef: 0.5
       reward_signal_steps_per_update: 10.0
     network_settings:
+      deterministic: true
       normalize: false
       hidden_units: 256
       num_layers: 3
```
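The new YAML key corresponds to a field on the trainer's `NetworkSettings` options class. A minimal sketch of where it would live, assuming the attrs-style settings classes ml-agents uses (only a subset of fields shown; defaults are illustrative):

```python
import attr

@attr.s(auto_attribs=True)
class NetworkSettings:
    normalize: bool = False
    hidden_units: int = 128
    num_layers: int = 2
    # New in this PR: when true, the policy resolves actions deterministically
    # instead of sampling from its action distributions.
    deterministic: bool = False
```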
```diff
@@ -541,6 +542,7 @@ def test_default_settings():
     test1_settings = run_options.behaviors["test1"]
     assert test1_settings.max_steps == 2
     assert test1_settings.network_settings.hidden_units == 2000
+    assert not test1_settings.network_settings.deterministic
     assert test1_settings.network_settings.num_layers == 1000
     # Change the overridden fields back, and check if the rest are equal.
     test1_settings.max_steps = 1
```
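For orientation, a hedged sketch of how a user would switch the flag on in a trainer configuration file (the behavior name and surrounding values are illustrative assumptions, not taken from this PR):

```yaml
behaviors:
  MyBehavior:                 # hypothetical behavior name
    trainer_type: ppo
    network_settings:
      deterministic: true     # resolve actions deterministically
      normalize: false
      hidden_units: 256
      num_layers: 3
```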
```diff
@@ -11,10 +11,10 @@
 from mlagents_envs.base_env import ActionSpec


-def create_action_model(inp_size, act_size):
+def create_action_model(inp_size, act_size, deterministic=False):
     mask = torch.ones([1, act_size * 2])
     action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
-    action_model = ActionModel(inp_size, action_spec)
+    action_model = ActionModel(inp_size, action_spec, deterministic=deterministic)
     return action_model, mask
```
```diff
@@ -43,6 +43,31 @@ def test_sample_action():
     assert _disc.shape == (1, 1)


+def test_deterministic_sample_action():
+    inp_size = 4
+    act_size = 2
+    action_model, masks = create_action_model(inp_size, act_size, deterministic=True)
+    sample_inp = torch.ones((1, inp_size))
+    dists = action_model._get_dists(sample_inp, masks=masks)
+    agent_action1 = action_model._sample_action(dists)
+    agent_action2 = action_model._sample_action(dists)
+    agent_action3 = action_model._sample_action(dists)
+    assert torch.equal(agent_action1.continuous_tensor, agent_action2.continuous_tensor)
+    assert torch.equal(agent_action1.continuous_tensor, agent_action3.continuous_tensor)
+    action_model, masks = create_action_model(inp_size, act_size, deterministic=False)
+    sample_inp = torch.ones((1, inp_size))
+    dists = action_model._get_dists(sample_inp, masks=masks)
+    agent_action1 = action_model._sample_action(dists)
+    agent_action2 = action_model._sample_action(dists)
+    agent_action3 = action_model._sample_action(dists)
+    assert not torch.equal(
+        agent_action1.continuous_tensor, agent_action2.continuous_tensor
+    )
+    assert not torch.equal(
+        agent_action1.continuous_tensor, agent_action3.continuous_tensor
+    )


 def test_get_probs_and_entropy():
     inp_size = 4
     act_size = 2
```

Review comment (on `test_deterministic_sample_action`): some tests on discrete actions would be great!
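Picking up the reviewer's suggestion, a hedged sketch of what a companion check for discrete actions could look like in the same test module. It reuses `create_action_model` from the diff above and assumes `AgentAction` exposes the per-branch tensors as `discrete_list`; only the deterministic half is asserted, since with small branch sizes two stochastic draws can legitimately coincide:

```python
def test_deterministic_sample_discrete_action():
    inp_size = 4
    act_size = 2
    # With deterministic=True, repeated samples from the same dists must agree.
    action_model, masks = create_action_model(inp_size, act_size, deterministic=True)
    sample_inp = torch.ones((1, inp_size))
    dists = action_model._get_dists(sample_inp, masks=masks)
    actions = [action_model._sample_action(dists) for _ in range(3)]
    # Compare each discrete branch across the three samples.
    for branch in zip(*(a.discrete_list for a in actions)):
        assert torch.equal(branch[0], branch[1])
        assert torch.equal(branch[0], branch[2])
```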
```diff
@@ -32,6 +32,7 @@ def __init__(
         action_spec: ActionSpec,
         conditional_sigma: bool = False,
         tanh_squash: bool = False,
+        deterministic: bool = False,
     ):
         """
         A torch module that represents the action space of a policy. The ActionModel may contain
```

Review comment (on the new `deterministic` parameter): please update the docstring
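For the docstring update the reviewer asks for, a hedged sketch of possible wording (illustrative, not taken from this PR):

```python
class ActionModel:  # fragment for illustration only
    def __init__(self, deterministic: bool = False) -> None:
        """
        :param deterministic: Whether `_sample_action` returns the
            distribution's deterministic sample (mean for continuous actions,
            argmax for discrete) instead of a stochastic draw. Defaults to False.
        """
        self._deterministic = deterministic
```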
```diff
@@ -66,22 +67,31 @@ def __init__(
         # During training, clipping is done in TorchPolicy, but we need to clip before ONNX
         # export as well.
         self._clip_action_on_export = not tanh_squash
+        self._deterministic = deterministic

     def _sample_action(self, dists: DistInstances) -> AgentAction:
         """
         Samples actions from a DistInstances tuple
         :params dists: The DistInstances tuple
         :return: An AgentAction corresponding to the actions sampled from the DistInstances
         """
         continuous_action: Optional[torch.Tensor] = None
         discrete_action: Optional[List[torch.Tensor]] = None
         # This checks None because mypy complains otherwise
         if dists.continuous is not None:
-            continuous_action = dists.continuous.sample()
+            if self._deterministic:
+                continuous_action = dists.continuous.deterministic_sample()
+            else:
+                continuous_action = dists.continuous.sample()
         if dists.discrete is not None:
             discrete_action = []
-            for discrete_dist in dists.discrete:
-                discrete_action.append(discrete_dist.sample())
+            if self._deterministic:
+                for discrete_dist in dists.discrete:
+                    discrete_action.append(discrete_dist.deterministic_sample())
+            else:
+                for discrete_dist in dists.discrete:
+                    discrete_action.append(discrete_dist.sample())
         return AgentAction(continuous_action, discrete_action)

     def _get_dists(self, inputs: torch.Tensor, masks: torch.Tensor) -> DistInstances:
```
Review comment (likely on the `if self._deterministic:` checks): nit: can we use `== True` just for readability
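Both branches depend on `deterministic_sample()`, which this PR adds to the distribution instance classes elsewhere (not shown in this excerpt). A minimal sketch of the idea, assuming a Gaussian continuous distribution and categorical discrete branches; the class bodies here are illustrative, not ml-agents' actual implementation:

```python
import torch

class GaussianDistInstance:
    """Illustrative continuous action distribution."""
    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        self.mean = mean
        self.std = std

    def sample(self) -> torch.Tensor:
        # Stochastic draw from N(mean, std).
        return self.mean + torch.randn_like(self.mean) * self.std

    def deterministic_sample(self) -> torch.Tensor:
        # The mode of a Gaussian is its mean.
        return self.mean

class CategoricalDistInstance:
    """Illustrative discrete action branch."""
    def __init__(self, logits: torch.Tensor):
        self.logits = logits

    def sample(self) -> torch.Tensor:
        # Stochastic draw of one action index per row.
        return torch.multinomial(torch.softmax(self.logits, dim=-1), 1)

    def deterministic_sample(self) -> torch.Tensor:
        # The most probable action index.
        return torch.argmax(self.logits, dim=-1, keepdim=True)
```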