Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using UnityEngine;
using Barracuda;
using MLAgents.Sensors;
Expand Down Expand Up @@ -691,6 +692,17 @@ public virtual void CollectObservations(VectorSensor sensor)
{
}

/// <summary>
/// Returns a read-only view of the observations that were generated in
/// <see cref="CollectObservations(VectorSensor)"/>. This is mainly useful inside of a
/// <see cref="Heuristic(float[])"/> method to avoid recomputing the observations.
/// </summary>
/// <returns>A read-only view of the observations list.</returns>
public ReadOnlyCollection<float> GetObservations()
{
return collectObservationsSensor.GetObservations();
}

/// <summary>
/// Collects the masks for discrete actions.
/// When using discrete actions, the agent will not perform the masked action.
Expand Down
10 changes: 10 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System.Collections.ObjectModel;
using UnityEngine;

namespace MLAgents.Sensors
Expand Down Expand Up @@ -60,6 +61,15 @@ public int Write(WriteAdapter adapter)
return expectedObservations;
}

/// <summary>
/// Returns a read-only view of the observations that added.
/// </summary>
/// <returns>A read-only view of the observations list.</returns>
internal ReadOnlyCollection<float> GetObservations()
{
return m_Observations.AsReadOnly();
}

/// <inheritdoc/>
public void Update()
{
Expand Down
5 changes: 5 additions & 0 deletions com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ public override void OnEpisodeBegin()

public override void Heuristic(float[] actionsOut)
{
var obs = GetObservations();
actionsOut[0] = obs[0];
heuristicCalls++;
}
}
Expand Down Expand Up @@ -667,6 +669,9 @@ public void TestHeuristicPolicyStepsSensors()
Assert.AreEqual(numSteps, agent1.heuristicCalls);
Assert.AreEqual(numSteps, agent1.sensor1.numWriteCalls);
Assert.AreEqual(numSteps, agent1.sensor2.numCompressedCalls);

// Make sure the Heuristic method read the observation and set the action
Assert.AreEqual(agent1.collectObservationsCallsForEpisode, agent1.GetAction()[0]);
}
}

Expand Down