Skip to content

Commit ee81d99

Browse files
author
Chris Elion
authored
[MLA-16] add filter mask to ray perception (#3111)
* add filter mask to ray perception * use LayerMask type
1 parent dfe9c11 commit ee81d99

File tree

4 files changed

+26
-8
lines changed

4 files changed

+26
-8
lines changed

UnitySDK/Assets/ML-Agents/Editor/Tests/RayPerceptionTests.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,11 @@ public void TestPerception2D()
3030
tags);
3131
Assert.IsTrue(result.Count == angles.Length * (tags.Length + 2));
3232
}
33+
34+
[Test]
35+
public void TestConstants()
36+
{
37+
Assert.AreEqual(Physics.DefaultRaycastLayers, Physics2D.DefaultRaycastLayers);
38+
}
3339
}
3440
}

UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensor.cs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public enum CastType
2525
float m_CastRadius;
2626
CastType m_CastType;
2727
Transform m_Transform;
28+
int m_LayerMask;
2829

2930
/// <summary>
3031
/// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
@@ -67,7 +68,8 @@ public DebugDisplayInfo debugDisplayInfo
6768
}
6869

6970
public RayPerceptionSensor(string name, float rayDistance, List<string> detectableObjects, float[] angles,
70-
Transform transform, float startOffset, float endOffset, float castRadius, CastType castType)
71+
Transform transform, float startOffset, float endOffset, float castRadius, CastType castType,
72+
int rayLayerMask)
7173
{
7274
var numObservations = (detectableObjects.Count + 2) * angles.Length;
7375
m_Shape = new[] { numObservations };
@@ -84,6 +86,7 @@ public RayPerceptionSensor(string name, float rayDistance, List<string> detectab
8486
m_EndOffset = endOffset;
8587
m_CastRadius = castRadius;
8688
m_CastType = castType;
89+
m_LayerMask = rayLayerMask;
8790

8891
if (Application.isEditor)
8992
{
@@ -97,7 +100,8 @@ public int Write(WriteAdapter adapter)
97100
{
98101
PerceiveStatic(
99102
m_RayDistance, m_Angles, m_DetectableObjects, m_StartOffset, m_EndOffset,
100-
m_CastRadius, m_Transform, m_CastType, m_Observations, false, m_DebugDisplayInfo
103+
m_CastRadius, m_Transform, m_CastType, m_Observations, false, m_LayerMask,
104+
m_DebugDisplayInfo
101105
);
102106
adapter.AddRange(m_Observations);
103107
}
@@ -164,6 +168,7 @@ public static void PerceiveStatic(float rayLength,
164168
float startOffset, float endOffset, float castRadius,
165169
Transform transform, CastType castType, float[] perceptionBuffer,
166170
bool legacyHitFractionBehavior = false,
171+
int layerMask = Physics.DefaultRaycastLayers,
167172
DebugDisplayInfo debugInfo = null)
168173
{
169174
Array.Clear(perceptionBuffer, 0, perceptionBuffer.Length);
@@ -221,11 +226,13 @@ public static void PerceiveStatic(float rayLength,
221226
RaycastHit rayHit;
222227
if (castRadius > 0f)
223228
{
224-
castHit = Physics.SphereCast(startPositionWorld, castRadius, rayDirection, out rayHit, rayLength);
229+
castHit = Physics.SphereCast(startPositionWorld, castRadius, rayDirection, out rayHit,
230+
rayLength, layerMask);
225231
}
226232
else
227233
{
228-
castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit, rayLength);
234+
castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit,
235+
rayLength, layerMask);
229236
}
230237

231238
hitFraction = castHit ? rayHit.distance / rayLength : 1.0f;
@@ -236,11 +243,12 @@ public static void PerceiveStatic(float rayLength,
236243
RaycastHit2D rayHit;
237244
if (castRadius > 0f)
238245
{
239-
rayHit = Physics2D.CircleCast(startPositionWorld, castRadius, rayDirection, rayLength);
246+
rayHit = Physics2D.CircleCast(startPositionWorld, castRadius, rayDirection,
247+
rayLength, layerMask);
240248
}
241249
else
242250
{
243-
rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, rayLength);
251+
rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, rayLength, layerMask);
244252
}
245253

246254
castHit = rayHit;

UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensorComponentBase.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ public abstract class RayPerceptionSensorComponentBase : SensorComponent
2727
[Tooltip("Length of the rays to cast.")]
2828
public float rayLength = 20f;
2929

30+
[Tooltip("Controls which layers the rays can hit.")]
31+
public LayerMask rayLayerMask = Physics.DefaultRaycastLayers;
32+
3033
[Range(1, 50)]
3134
[Tooltip("Whether to stack previous observations. Using 1 means no previous observations.")]
3235
public int observationStacks = 1;
@@ -57,7 +60,8 @@ public override ISensor CreateSensor()
5760
{
5861
var rayAngles = GetRayAngles(raysPerDirection, maxRayDegrees);
5962
m_RaySensor = new RayPerceptionSensor(sensorName, rayLength, detectableTags, rayAngles,
60-
transform, GetStartVerticalOffset(), GetEndVerticalOffset(), sphereCastRadius, GetCastType()
63+
transform, GetStartVerticalOffset(), GetEndVerticalOffset(), sphereCastRadius, GetCastType(),
64+
rayLayerMask
6165
);
6266

6367
if (observationStacks != 1)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
m_EditorVersion: 2017.4.32f1
1+
m_EditorVersion: 2017.4.33f1

0 commit comments

Comments
 (0)