Skip to content
53 changes: 21 additions & 32 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,46 +32,35 @@ public override void InitializeAgent()
{
}

public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
{
// There are no numeric observations to collect as this environment uses visual
// observations.

// Mask the necessary actions if selected by the user.
if (maskActions)
{
SetMask(actionMasker);
}
}

/// <summary>
/// Applies the mask for the agent's actions to disallow unnecessary actions.
/// </summary>
void SetMask(ActionMasker actionMasker)
{
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;
var positionZ = (int)transform.position.z;
var maxPosition = (int)Academy.Instance.FloatProperties.GetPropertyWithDefault("gridSize", 5f) - 1;
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;
var positionZ = (int)transform.position.z;
var maxPosition = (int)Academy.Instance.FloatProperties.GetPropertyWithDefault("gridSize", 5f) - 1;

if (positionX == 0)
{
actionMasker.SetActionMask(k_Left);
}
if (positionX == 0)
{
actionMasker.SetDiscreteActionMask(0, new int[]{ k_Left});
}

if (positionX == maxPosition)
{
actionMasker.SetActionMask(k_Right);
}
if (positionX == maxPosition)
{
actionMasker.SetDiscreteActionMask(0, new int[]{k_Right});
}

if (positionZ == 0)
{
actionMasker.SetActionMask(k_Down);
}
if (positionZ == 0)
{
actionMasker.SetDiscreteActionMask(0, new int[]{k_Down});
}

if (positionZ == maxPosition)
{
actionMasker.SetActionMask(k_Up);
if (positionZ == maxPosition)
{
actionMasker.SetDiscreteActionMask(0, new int[]{k_Up});
}
}
}

Expand Down
3 changes: 2 additions & 1 deletion com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]
### Major Changes
- Agent.CollectObservations now takes a VectorSensor argument. It was also overloaded to optionally take an ActionMasker argument. (#3352, #3389)
- `Agent.CollectObservations` now takes a `VectorSensor` argument. (#3352, #3389)
- Added `Agent.CollectDiscreteActionMasks` virtual method with a `DiscreteActionMasker` argument to specify which discrete actions are unavailable to the Agent. (#3525)
- Beta support for ONNX export was added. If the `tf2onnx` python package is installed, models will be saved to `.onnx` as well as `.nn` format.
Note that Barracuda 0.6.0 or later is required to import the `.onnx` files properly.
- Multi-GPU training and the `--multi-gpu` option have been removed temporarily. (#3345)
Expand Down
44 changes: 4 additions & 40 deletions com.unity.ml-agents/Runtime/ActionMasker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace MLAgents
/// left side of the board). This class represents the set of masked actions and provides
/// the utilities for setting and retrieving them.
/// </summary>
public class ActionMasker
public class DiscreteActionMasker
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't the file name need to change too?

{
/// When using discrete control, these are the starting indices of the actions
/// when all the branches are concatenated with each other.
Expand All @@ -21,47 +21,11 @@ public class ActionMasker

readonly BrainParameters m_BrainParameters;

internal ActionMasker(BrainParameters brainParameters)
internal DiscreteActionMasker(BrainParameters brainParameters)
{
m_BrainParameters = brainParameters;
}

/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the actions passed as argument at the next decision.
/// The actionIndices correspond to the actions the agent will be unable to perform
/// on the branch 0.
/// </summary>
/// <param name="actionIndices">The indices of the masked actions on branch 0.</param>
public void SetActionMask(IEnumerable<int> actionIndices)
{
SetActionMask(0, actionIndices);
}

/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision for the specified
/// action branch. The actionIndex correspond to the action the agent will be unable
/// to perform.
/// </summary>
/// <param name="branch">The branch for which the actions will be masked.</param>
/// <param name="actionIndex">The index of the masked action.</param>
public void SetActionMask(int branch, int actionIndex)
{
SetActionMask(branch, new[] { actionIndex });
}

/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. The actionIndex
/// correspond to the action the agent will be unable to perform on the branch 0.
/// </summary>
/// <param name="actionIndex">The index of the masked action on branch 0</param>
public void SetActionMask(int actionIndex)
{
SetActionMask(0, new[] { actionIndex });
}

/// <summary>
/// Modifies an action mask for discrete control agents. When used, the agent will not be
/// able to perform the actions passed as argument at the next decision for the specified
Expand All @@ -70,7 +34,7 @@ public void SetActionMask(int actionIndex)
/// </summary>
/// <param name="branch">The branch for which the actions will be masked</param>
/// <param name="actionIndices">The indices of the masked actions</param>
public void SetActionMask(int branch, IEnumerable<int> actionIndices)
public void SetDiscreteActionMask(int branch, IEnumerable<int> actionIndices)
{
// If the branch does not exist, raise an error
if (branch >= m_BrainParameters.vectorActionSize.Length)
Expand Down Expand Up @@ -110,7 +74,7 @@ public void SetActionMask(int branch, IEnumerable<int> actionIndices)
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
internal bool[] GetMask()
internal bool[] GetDiscreteActionMask()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm OK leaving this as GetMask(). The "DiscreteAction" part in the class name seems like enough

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You think? Why is it different from SetMask()?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just think it's redundant. DiscreteActionMasker.GetMask() and DiscreteActionMasker.GetDiscreteActionMask() convey the same amount of info. (same for Set...)

{
if (m_CurrentMask != null)
{
Expand Down
60 changes: 19 additions & 41 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ internal struct AgentAction
/// environment. Observations are determined by the cameras attached
/// to the agent in addition to the vector observations implemented by the
/// user in <see cref="Agent.CollectObservations(VectorSensor)"/> or
/// <see cref="Agent.CollectObservations(VectorSensor, ActionMasker)"/>.
/// <see cref="Agent.CollectObservations(VectorSensor, DiscreteActionMasker)"/>.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This overload no longer exists, right?

/// On the other hand, actions are determined by decisions produced by a Policy.
/// Currently, this class is expected to be extended to implement the desired agent behavior.
/// </summary>
Expand Down Expand Up @@ -173,7 +173,7 @@ internal struct AgentParameters
bool m_Initialized;

/// Keeps track of the actions that are masked at each step.
ActionMasker m_ActionMasker;
DiscreteActionMasker m_ActionMasker;

/// <summary>
/// Set of DemonstrationWriters that the Agent will write its step information to.
Expand Down Expand Up @@ -408,7 +408,7 @@ public void RequestAction()
void ResetData()
{
var param = m_PolicyFactory.brainParameters;
m_ActionMasker = new ActionMasker(param);
m_ActionMasker = new DiscreteActionMasker(param);
// If we haven't initialized vectorActions, initialize to 0. This should only
// happen during the creation of the Agent. In subsequent episodes, vectorAction
// should stay the previous action before the Done(), so that it is properly recorded.
Expand Down Expand Up @@ -523,9 +523,16 @@ void SendInfoToBrain()
UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{
CollectObservations(collectObservationsSensor, m_ActionMasker);
CollectObservations(collectObservationsSensor);
}
m_Info.actionMasks = m_ActionMasker.GetMask();
using (TimerStack.Instance.Scoped("CollectDiscreteActionMasks"))
{
if (m_PolicyFactory.brainParameters.vectorActionSpaceType == SpaceType.Discrete)
{
CollectDiscreteActionMasks(m_ActionMasker);
}
}
m_Info.actionMasks = m_ActionMasker.GetDiscreteActionMask();

m_Info.reward = m_Reward;
m_Info.done = false;
Expand Down Expand Up @@ -586,51 +593,22 @@ public virtual void CollectObservations(VectorSensor sensor)
}

/// <summary>
/// Collects the vector observations of the agent alongside the masked actions.
/// The agent observation describes the current environment from the
/// perspective of the agent.
/// Collects the masks for discrete actions.
/// When using discrete actions, the agent will not perform the masked action.
/// </summary>
/// <param name="sensor">
/// The vector observations for the agent.
/// </param>
/// <param name="actionMasker">
/// The masked actions for the agent.
/// The action masker for the agent.
/// </param>
/// <remarks>
/// An agent's observation is any environment information that helps
/// the Agent achieve its goal. For example, for a fighting Agent, its
/// observation could include distances to friends or enemies, or the
/// current level of ammunition at its disposal.
/// Recall that an Agent may attach vector or visual observations.
/// Vector observations are added by calling the provided helper methods
/// on the VectorSensor input:
/// - <see cref="VectorSensor.AddObservation(int)"/>
/// - <see cref="VectorSensor.AddObservation(float)"/>
/// - <see cref="VectorSensor.AddObservation(Vector3)"/>
/// - <see cref="VectorSensor.AddObservation(Vector2)"/>
/// - <see cref="VectorSensor.AddObservation(Quaternion)"/>
/// - <see cref="VectorSensor.AddObservation(bool)"/>
/// - <see cref="VectorSensor.AddObservation(IEnumerable{float})"/>
/// - <see cref="VectorSensor.AddOneHotObservation(int, int)"/>
/// Depending on your environment, any combination of these helpers can
/// be used. They just need to be used in the exact same order each time
/// this method is called and the resulting size of the vector observation
/// needs to match the vectorObservationSize attribute of the linked Brain.
/// Visual observations are implicitly added from the cameras attached to
/// the Agent.
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it. You can call the following method on the ActionMasker
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This still feels a bit redundant; how about something like

/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it with <see cref="DiscreteActionMasker.SetMask(int, IEnumerable{int})"/>

(Remove the "The first argument..." part, since the DiscreteActionMasker docs should cover that.)

/// input :
/// - <see cref="ActionMasker.SetActionMask(int)"/>
/// - <see cref="ActionMasker.SetActionMask(int, int)"/>
/// - <see cref="ActionMasker.SetActionMask(int, IEnumerable{int})"/>
/// - <see cref="ActionMasker.SetActionMask(IEnumerable{int})"/>
/// The branch input is the index of the action, actionIndices are the indices of the
/// invalid options for that action.
/// - <see cref="DiscreteActionMasker.SetDiscreteActionMask(int, IEnumerable{int})"/>
/// The first argument is the branch of the action, the second is an Enumerable
/// of indices corresponding to the invalid options for that action.
/// </remarks>
public virtual void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
{
CollectObservations(sensor);
}

/// <summary>
Expand Down
58 changes: 29 additions & 29 deletions com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public class EditModeTestActionMasker
public void Contruction()
{
var bp = new BrainParameters();
var masker = new ActionMasker(bp);
var masker = new DiscreteActionMasker(bp);
Assert.IsNotNull(masker);
}

Expand All @@ -18,18 +18,18 @@ public void FailsWithContinuous()
var bp = new BrainParameters();
bp.vectorActionSpaceType = SpaceType.Continuous;
bp.vectorActionSize = new[] {4};
var masker = new ActionMasker(bp);
masker.SetActionMask(0, new[] {0});
Assert.Catch<UnityAgentsException>(() => masker.GetMask());
var masker = new DiscreteActionMasker(bp);
masker.SetDiscreteActionMask(0, new[] {0});
Assert.Catch<UnityAgentsException>(() => masker.GetDiscreteActionMask());
}

[Test]
public void NullMask()
{
var bp = new BrainParameters();
bp.vectorActionSpaceType = SpaceType.Discrete;
var masker = new ActionMasker(bp);
var mask = masker.GetMask();
var masker = new DiscreteActionMasker(bp);
var mask = masker.GetDiscreteActionMask();
Assert.IsNull(mask);
}

Expand All @@ -39,11 +39,11 @@ public void FirstBranchMask()
var bp = new BrainParameters();
bp.vectorActionSpaceType = SpaceType.Discrete;
bp.vectorActionSize = new[] {4, 5, 6};
var masker = new ActionMasker(bp);
var mask = masker.GetMask();
var masker = new DiscreteActionMasker(bp);
var mask = masker.GetDiscreteActionMask();
Assert.IsNull(mask);
masker.SetActionMask(0, new[] {1, 2, 3});
mask = masker.GetMask();
masker.SetDiscreteActionMask(0, new[] {1, 2, 3});
mask = masker.GetDiscreteActionMask();
Assert.IsFalse(mask[0]);
Assert.IsTrue(mask[1]);
Assert.IsTrue(mask[2]);
Expand All @@ -60,9 +60,9 @@ public void SecondBranchMask()
vectorActionSpaceType = SpaceType.Discrete,
vectorActionSize = new[] { 4, 5, 6 }
};
var masker = new ActionMasker(bp);
masker.SetActionMask(1, new[] {1, 2, 3});
var mask = masker.GetMask();
var masker = new DiscreteActionMasker(bp);
masker.SetDiscreteActionMask(1, new[] {1, 2, 3});
var mask = masker.GetDiscreteActionMask();
Assert.IsFalse(mask[0]);
Assert.IsFalse(mask[4]);
Assert.IsTrue(mask[5]);
Expand All @@ -80,10 +80,10 @@ public void MaskReset()
vectorActionSpaceType = SpaceType.Discrete,
vectorActionSize = new[] { 4, 5, 6 }
};
var masker = new ActionMasker(bp);
masker.SetActionMask(1, new[] {1, 2, 3});
var masker = new DiscreteActionMasker(bp);
masker.SetDiscreteActionMask(1, new[] {1, 2, 3});
masker.ResetMask();
var mask = masker.GetMask();
var mask = masker.GetDiscreteActionMask();
for (var i = 0; i < 15; i++)
{
Assert.IsFalse(mask[i]);
Expand All @@ -98,20 +98,20 @@ public void ThrowsError()
vectorActionSpaceType = SpaceType.Discrete,
vectorActionSize = new[] { 4, 5, 6 }
};
var masker = new ActionMasker(bp);
var masker = new DiscreteActionMasker(bp);

Assert.Catch<UnityAgentsException>(
() => masker.SetActionMask(0, new[] {5}));
() => masker.SetDiscreteActionMask(0, new[] {5}));
Assert.Catch<UnityAgentsException>(
() => masker.SetActionMask(1, new[] {5}));
masker.SetActionMask(2, new[] {5});
() => masker.SetDiscreteActionMask(1, new[] {5}));
masker.SetDiscreteActionMask(2, new[] {5});
Assert.Catch<UnityAgentsException>(
() => masker.SetActionMask(3, new[] {1}));
masker.GetMask();
() => masker.SetDiscreteActionMask(3, new[] {1}));
masker.GetDiscreteActionMask();
masker.ResetMask();
masker.SetActionMask(0, new[] {0, 1, 2, 3});
masker.SetDiscreteActionMask(0, new[] {0, 1, 2, 3});
Assert.Catch<UnityAgentsException>(
() => masker.GetMask());
() => masker.GetDiscreteActionMask());
}

[Test]
Expand All @@ -120,11 +120,11 @@ public void MultipleMaskEdit()
var bp = new BrainParameters();
bp.vectorActionSpaceType = SpaceType.Discrete;
bp.vectorActionSize = new[] {4, 5, 6};
var masker = new ActionMasker(bp);
masker.SetActionMask(0, new[] {0, 1});
masker.SetActionMask(0, new[] {3});
masker.SetActionMask(2, new[] {1});
var mask = masker.GetMask();
var masker = new DiscreteActionMasker(bp);
masker.SetDiscreteActionMask(0, new[] {0, 1});
masker.SetDiscreteActionMask(0, new[] {3});
masker.SetDiscreteActionMask(2, new[] {1});
var mask = masker.GetDiscreteActionMask();
for (var i = 0; i < 15; i++)
{
if ((i == 0) || (i == 1) || (i == 3) || (i == 10))
Expand Down
Loading