Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ public void TestWritesToTensor()
valueType = TensorProxy.TensorType.FloatingPoint,
data = new Tensor(2, 3)
};
var shape = new[] { 3 };
writer.SetTarget(t, shape, 0, 0);

writer.SetTarget(t, 0, 0);
Assert.AreEqual(0f, t.data[0, 0]);
writer[0] = 1f;
Assert.AreEqual(1f, t.data[0, 0]);

writer.SetTarget(t, shape, 1, 1);
writer.SetTarget(t, 1, 1);
writer[0] = 2f;
writer[1] = 3f;
// [0, 0] shouldn't change
Expand All @@ -69,7 +69,7 @@ public void TestWritesToTensor()
data = new Tensor(2, 3)
};

writer.SetTarget(t, shape, 1, 1);
writer.SetTarget(t, 1, 1);
writer.AddRange(new [] {-1f, -2f});
Assert.AreEqual(0f, t.data[0, 0]);
Assert.AreEqual(0f, t.data[0, 1]);
Expand All @@ -91,11 +91,11 @@ public void TestWritesToTensor3D()

var shape = new[] { 2, 2, 3 };

writer.SetTarget(t, shape, 0, 0);
writer.SetTarget(t, 0, 0);
writer[1, 0, 1] = 1f;
Assert.AreEqual(1f, t.data[0, 1, 0, 1]);

writer.SetTarget(t, shape, 0, 1);
writer.SetTarget(t, 0, 1);
writer[1, 0, 0] = 2f;
Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent>
foreach (var sensorIndex in m_SensorIndices)
{
var sensor = agent.sensors[sensorIndex];
var shape = sensor.GetObservationShape();
m_WriteAdapter.SetTarget(tensorProxy, shape, agentIndex, tensorOffset);
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, tensorOffset);
var numWritten = sensor.Write(m_WriteAdapter);
tensorOffset += numWritten;
}
Expand Down Expand Up @@ -355,7 +354,7 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent>
foreach (var agent in agents)
{
var sensor = agent.sensors[m_SensorIndex];
m_WriteAdapter.SetTarget(tensorProxy, sensor.GetObservationShape(), agentIndex, 0);
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, 0);
sensor.Write(m_WriteAdapter);
agentIndex++;
}
Expand Down
40 changes: 21 additions & 19 deletions UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using Barracuda;
using MLAgents.InferenceBrain;

namespace MLAgents.Sensor
Expand All @@ -15,7 +16,7 @@ public class WriteAdapter
TensorProxy m_Proxy;
int m_Batch;

int[] m_Shape;
TensorShape m_TensorShape;

/// <summary>
/// Set the adapter to write to an IList at the given channelOffset.
Expand All @@ -29,23 +30,30 @@ public void SetTarget(IList<float> data, int[] shape, int offset)
m_Offset = offset;
m_Proxy = null;
m_Batch = 0;
m_Shape = shape;

if (shape.Length == 1)
{
m_TensorShape = new TensorShape(m_Batch, shape[0]);
}
else
{
m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]);
}
}

/// <summary>
/// Set the adapter to write to a TensorProxy at the given batch and channel offset.
/// </summary>
/// <param name="tensorProxy">Tensor proxy that will be writtent to.</param>
/// <param name="shape">Shape of the observations to be written.</param>
/// <param name="batchIndex">Batch index in the tensor proxy (i.e. the index of the Agent)</param>
/// <param name="channelOffset">Offset from the start of the channel to write to.</param>
public void SetTarget(TensorProxy tensorProxy, int[] shape, int batchIndex, int channelOffset)
public void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
{
m_Proxy = tensorProxy;
m_Batch = batchIndex;
m_Offset = channelOffset;
m_Data = null;
m_Shape = shape;
m_TensorShape = m_Proxy.data.shape;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't used in the tensor case, but just set it for consistency.

}

/// <summary>
Expand All @@ -56,7 +64,6 @@ public float this[int index]
{
set
{
// TODO check shape is 1D?
if (m_Data != null)
{
m_Data[index + m_Offset] = value;
Expand All @@ -80,26 +87,21 @@ public float this[int index]
{
if (m_Data != null)
{
var height = m_Shape[0];
var width = m_Shape[1];
var channels = m_Shape[2];

if (h < 0 || h >= height)
if (h < 0 || h >= m_TensorShape.height)
{
throw new IndexOutOfRangeException($"height value {h} must be in range [0, {height-1}]");
throw new IndexOutOfRangeException($"height value {h} must be in range [0, {m_TensorShape.height-1}]");
}
if (w < 0 || w >= width)
if (w < 0 || w >= m_TensorShape.width)
{
throw new IndexOutOfRangeException($"width value {w} must be in range [0, {width-1}]");
throw new IndexOutOfRangeException($"width value {w} must be in range [0, {m_TensorShape.width-1}]");
}
if (ch < 0 || ch >= channels)
if (ch < 0 || ch >= m_TensorShape.channels)
{
throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {channels-1}]");
throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {m_TensorShape.channels-1}]");
}

// Math copied from TensorShape.Index(). Note that m_Batch should always be 0
var index = m_Batch * height * width * channels + h * width * channels + w * channels + ch;
m_Data[index + m_Offset] = value;
var index = m_TensorShape.Index(m_Batch, h, w, ch + m_Offset);
m_Data[index] = value;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could compute the index for both TensorProxy and float array cases, and use m_Proxy.data[index] = value; but I don't see a strong reason for doing that. Same for the vector case.

}
else
{
Expand Down