-
Notifications
You must be signed in to change notification settings - Fork 4.4k
[Renaming] SetActionMask -> SetDiscreteActionMask + added the virtual method CollectDiscreteActionMasks #3525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
2ec1146
7a892a1
394e8e7
dd651ac
8b17be5
f5b8111
c9b5d4d
c3b0445
9c5b5f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,7 @@ namespace MLAgents | |
| /// left side of the board). This class represents the set of masked actions and provides | ||
| /// the utilities for setting and retrieving them. | ||
| /// </summary> | ||
| public class ActionMasker | ||
| public class DiscreteActionMasker | ||
| { | ||
| /// When using discrete control, these are the starting indices of the actions | ||
| /// when all the branches are concatenated with each other. | ||
|
|
@@ -21,47 +21,11 @@ public class ActionMasker | |
|
|
||
| readonly BrainParameters m_BrainParameters; | ||
|
|
||
| internal ActionMasker(BrainParameters brainParameters) | ||
| internal DiscreteActionMasker(BrainParameters brainParameters) | ||
| { | ||
| m_BrainParameters = brainParameters; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Sets an action mask for discrete control agents. When used, the agent will not be | ||
| /// able to perform the actions passed as argument at the next decision. | ||
| /// The actionIndices correspond to the actions the agent will be unable to perform | ||
| /// on the branch 0. | ||
| /// </summary> | ||
| /// <param name="actionIndices">The indices of the masked actions on branch 0.</param> | ||
| public void SetActionMask(IEnumerable<int> actionIndices) | ||
| { | ||
| SetActionMask(0, actionIndices); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Sets an action mask for discrete control agents. When used, the agent will not be | ||
| /// able to perform the action passed as argument at the next decision for the specified | ||
| /// action branch. The actionIndex correspond to the action the agent will be unable | ||
| /// to perform. | ||
| /// </summary> | ||
| /// <param name="branch">The branch for which the actions will be masked.</param> | ||
| /// <param name="actionIndex">The index of the masked action.</param> | ||
| public void SetActionMask(int branch, int actionIndex) | ||
| { | ||
| SetActionMask(branch, new[] { actionIndex }); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Sets an action mask for discrete control agents. When used, the agent will not be | ||
| /// able to perform the action passed as argument at the next decision. The actionIndex | ||
| /// correspond to the action the agent will be unable to perform on the branch 0. | ||
| /// </summary> | ||
| /// <param name="actionIndex">The index of the masked action on branch 0</param> | ||
| public void SetActionMask(int actionIndex) | ||
| { | ||
| SetActionMask(0, new[] { actionIndex }); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Modifies an action mask for discrete control agents. When used, the agent will not be | ||
| /// able to perform the actions passed as argument at the next decision for the specified | ||
|
|
@@ -70,7 +34,7 @@ public void SetActionMask(int actionIndex) | |
| /// </summary> | ||
| /// <param name="branch">The branch for which the actions will be masked</param> | ||
| /// <param name="actionIndices">The indices of the masked actions</param> | ||
| public void SetActionMask(int branch, IEnumerable<int> actionIndices) | ||
| public void SetDiscreteActionMask(int branch, IEnumerable<int> actionIndices) | ||
| { | ||
| // If the branch does not exist, raise an error | ||
| if (branch >= m_BrainParameters.vectorActionSize.Length) | ||
|
|
@@ -110,7 +74,7 @@ public void SetActionMask(int branch, IEnumerable<int> actionIndices) | |
| /// </summary> | ||
| /// <returns>A mask for the agent. A boolean array of length equal to the total number of | ||
| /// actions.</returns> | ||
| internal bool[] GetMask() | ||
| internal bool[] GetDiscreteActionMask() | ||
|
||
| { | ||
| if (m_CurrentMask != null) | ||
| { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,7 +60,7 @@ internal struct AgentAction | |
| /// environment. Observations are determined by the cameras attached | ||
| /// to the agent in addition to the vector observations implemented by the | ||
| /// user in <see cref="Agent.CollectObservations(VectorSensor)"/> or | ||
| /// <see cref="Agent.CollectObservations(VectorSensor, ActionMasker)"/>. | ||
| /// <see cref="Agent.CollectObservations(VectorSensor, DiscreteActionMasker)"/>. | ||
|
||
| /// On the other hand, actions are determined by decisions produced by a Policy. | ||
| /// Currently, this class is expected to be extended to implement the desired agent behavior. | ||
| /// </summary> | ||
|
|
@@ -173,7 +173,7 @@ internal struct AgentParameters | |
| bool m_Initialized; | ||
|
|
||
| /// Keeps track of the actions that are masked at each step. | ||
| ActionMasker m_ActionMasker; | ||
| DiscreteActionMasker m_ActionMasker; | ||
|
|
||
| /// <summary> | ||
| /// Set of DemonstrationWriters that the Agent will write its step information to. | ||
|
|
@@ -408,7 +408,7 @@ public void RequestAction() | |
| void ResetData() | ||
| { | ||
| var param = m_PolicyFactory.brainParameters; | ||
| m_ActionMasker = new ActionMasker(param); | ||
| m_ActionMasker = new DiscreteActionMasker(param); | ||
| // If we haven't initialized vectorActions, initialize to 0. This should only | ||
| // happen during the creation of the Agent. In subsequent episodes, vectorAction | ||
| // should stay the previous action before the Done(), so that it is properly recorded. | ||
|
|
@@ -523,9 +523,16 @@ void SendInfoToBrain() | |
| UpdateSensors(); | ||
| using (TimerStack.Instance.Scoped("CollectObservations")) | ||
| { | ||
| CollectObservations(collectObservationsSensor, m_ActionMasker); | ||
| CollectObservations(collectObservationsSensor); | ||
| } | ||
| m_Info.actionMasks = m_ActionMasker.GetMask(); | ||
| using (TimerStack.Instance.Scoped("CollectDiscreteActionMasks")) | ||
| { | ||
| if (m_PolicyFactory.brainParameters.vectorActionSpaceType == SpaceType.Discrete) | ||
| { | ||
| CollectDiscreteActionMasks(m_ActionMasker); | ||
| } | ||
| } | ||
| m_Info.actionMasks = m_ActionMasker.GetDiscreteActionMask(); | ||
|
|
||
| m_Info.reward = m_Reward; | ||
| m_Info.done = false; | ||
|
|
@@ -586,51 +593,22 @@ public virtual void CollectObservations(VectorSensor sensor) | |
| } | ||
|
|
||
| /// <summary> | ||
| /// Collects the vector observations of the agent alongside the masked actions. | ||
| /// The agent observation describes the current environment from the | ||
| /// perspective of the agent. | ||
| /// Collects the masks for discrete actions. | ||
| /// When using discrete actions, the agent will not perform the masked action. | ||
| /// </summary> | ||
| /// <param name="sensor"> | ||
| /// The vector observations for the agent. | ||
| /// </param> | ||
| /// <param name="actionMasker"> | ||
| /// The masked actions for the agent. | ||
| /// The action masker for the agent. | ||
| /// </param> | ||
| /// <remarks> | ||
| /// An agents observation is any environment information that helps | ||
| /// the Agent achieve its goal. For example, for a fighting Agent, its | ||
| /// observation could include distances to friends or enemies, or the | ||
| /// current level of ammunition at its disposal. | ||
| /// Recall that an Agent may attach vector or visual observations. | ||
| /// Vector observations are added by calling the provided helper methods | ||
| /// on the VectorSensor input: | ||
| /// - <see cref="VectorSensor.AddObservation(int)"/> | ||
| /// - <see cref="VectorSensor.AddObservation(float)"/> | ||
| /// - <see cref="VectorSensor.AddObservation(Vector3)"/> | ||
| /// - <see cref="VectorSensor.AddObservation(Vector2)"/> | ||
| /// - <see cref="VectorSensor.AddObservation(Quaternion)"/> | ||
| /// - <see cref="VectorSensor.AddObservation(bool)"/> | ||
| /// - <see cref="VectorSensor.AddObservation(IEnumerable{float})"/> | ||
| /// - <see cref="VectorSensor.AddOneHotObservation(int, int)"/> | ||
| /// Depending on your environment, any combination of these helpers can | ||
| /// be used. They just need to be used in the exact same order each time | ||
| /// this method is called and the resulting size of the vector observation | ||
| /// needs to match the vectorObservationSize attribute of the linked Brain. | ||
| /// Visual observations are implicitly added from the cameras attached to | ||
| /// the Agent. | ||
| /// When using Discrete Control, you can prevent the Agent from using a certain | ||
| /// action by masking it. You can call the following method on the DiscreteActionMasker | ||
|
||
| /// input : | ||
| /// - <see cref="ActionMasker.SetActionMask(int)"/> | ||
| /// - <see cref="ActionMasker.SetActionMask(int, int)"/> | ||
| /// - <see cref="ActionMasker.SetActionMask(int, IEnumerable{int})"/> | ||
| /// - <see cref="ActionMasker.SetActionMask(IEnumerable{int})"/> | ||
| /// The branch input is the index of the action, actionIndices are the indices of the | ||
| /// invalid options for that action. | ||
| /// - <see cref="DiscreteActionMasker.SetDiscreteActionMask(int, IEnumerable{int})"/> | ||
| /// The first argument is the branch of the action, the second is an Enumerable | ||
| /// of indices corresponding to the invalid options for that action. | ||
| /// </remarks> | ||
| public virtual void CollectObservations(VectorSensor sensor, ActionMasker actionMasker) | ||
| public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker) | ||
| { | ||
| CollectObservations(sensor); | ||
| } | ||
|
|
||
| /// <summary> | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't the file name need to change too?