Skip to content

Commit b9194a5

Browse files
author
Chris Elion
authored
use TensorShape for index calc (#3171)
* use tensorshape for index calc * docstring * dont need shape anymore
1 parent 90659b9 commit b9194a5

File tree

3 files changed

+29
-28
lines changed

3 files changed

+29
-28
lines changed

UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ public void TestWritesToTensor()
4848
valueType = TensorProxy.TensorType.FloatingPoint,
4949
data = new Tensor(2, 3)
5050
};
51-
var shape = new[] { 3 };
52-
writer.SetTarget(t, shape, 0, 0);
51+
52+
writer.SetTarget(t, 0, 0);
5353
Assert.AreEqual(0f, t.data[0, 0]);
5454
writer[0] = 1f;
5555
Assert.AreEqual(1f, t.data[0, 0]);
5656

57-
writer.SetTarget(t, shape, 1, 1);
57+
writer.SetTarget(t, 1, 1);
5858
writer[0] = 2f;
5959
writer[1] = 3f;
6060
// [0, 0] shouldn't change
@@ -69,7 +69,7 @@ public void TestWritesToTensor()
6969
data = new Tensor(2, 3)
7070
};
7171

72-
writer.SetTarget(t, shape, 1, 1);
72+
writer.SetTarget(t, 1, 1);
7373
writer.AddRange(new [] {-1f, -2f});
7474
Assert.AreEqual(0f, t.data[0, 0]);
7575
Assert.AreEqual(0f, t.data[0, 1]);
@@ -91,11 +91,11 @@ public void TestWritesToTensor3D()
9191

9292
var shape = new[] { 2, 2, 3 };
9393

94-
writer.SetTarget(t, shape, 0, 0);
94+
writer.SetTarget(t, 0, 0);
9595
writer[1, 0, 1] = 1f;
9696
Assert.AreEqual(1f, t.data[0, 1, 0, 1]);
9797

98-
writer.SetTarget(t, shape, 0, 1);
98+
writer.SetTarget(t, 0, 1);
9999
writer[1, 0, 0] = 2f;
100100
Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
101101
}

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent>
106106
foreach (var sensorIndex in m_SensorIndices)
107107
{
108108
var sensor = agent.sensors[sensorIndex];
109-
var shape = sensor.GetObservationShape();
110-
m_WriteAdapter.SetTarget(tensorProxy, shape, agentIndex, tensorOffset);
109+
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, tensorOffset);
111110
var numWritten = sensor.Write(m_WriteAdapter);
112111
tensorOffset += numWritten;
113112
}
@@ -355,7 +354,7 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<Agent>
355354
foreach (var agent in agents)
356355
{
357356
var sensor = agent.sensors[m_SensorIndex];
358-
m_WriteAdapter.SetTarget(tensorProxy, sensor.GetObservationShape(), agentIndex, 0);
357+
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, 0);
359358
sensor.Write(m_WriteAdapter);
360359
agentIndex++;
361360
}

UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using Barracuda;
34
using MLAgents.InferenceBrain;
45

56
namespace MLAgents.Sensor
@@ -15,7 +16,7 @@ public class WriteAdapter
1516
TensorProxy m_Proxy;
1617
int m_Batch;
1718

18-
int[] m_Shape;
19+
TensorShape m_TensorShape;
1920

2021
/// <summary>
2122
/// Set the adapter to write to an IList at the given channelOffset.
@@ -29,23 +30,30 @@ public void SetTarget(IList<float> data, int[] shape, int offset)
2930
m_Offset = offset;
3031
m_Proxy = null;
3132
m_Batch = 0;
32-
m_Shape = shape;
33+
34+
if (shape.Length == 1)
35+
{
36+
m_TensorShape = new TensorShape(m_Batch, shape[0]);
37+
}
38+
else
39+
{
40+
m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]);
41+
}
3342
}
3443

3544
/// <summary>
3645
/// Set the adapter to write to a TensorProxy at the given batch and channel offset.
3746
/// </summary>
3847
/// <param name="tensorProxy">Tensor proxy that will be writtent to.</param>
39-
/// <param name="shape">Shape of the observations to be written.</param>
4048
/// <param name="batchIndex">Batch index in the tensor proxy (i.e. the index of the Agent)</param>
4149
/// <param name="channelOffset">Offset from the start of the channel to write to.</param>
42-
public void SetTarget(TensorProxy tensorProxy, int[] shape, int batchIndex, int channelOffset)
50+
public void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
4351
{
4452
m_Proxy = tensorProxy;
4553
m_Batch = batchIndex;
4654
m_Offset = channelOffset;
4755
m_Data = null;
48-
m_Shape = shape;
56+
m_TensorShape = m_Proxy.data.shape;
4957
}
5058

5159
/// <summary>
@@ -56,7 +64,6 @@ public float this[int index]
5664
{
5765
set
5866
{
59-
// TODO check shape is 1D?
6067
if (m_Data != null)
6168
{
6269
m_Data[index + m_Offset] = value;
@@ -80,26 +87,21 @@ public float this[int index]
8087
{
8188
if (m_Data != null)
8289
{
83-
var height = m_Shape[0];
84-
var width = m_Shape[1];
85-
var channels = m_Shape[2];
86-
87-
if (h < 0 || h >= height)
90+
if (h < 0 || h >= m_TensorShape.height)
8891
{
89-
throw new IndexOutOfRangeException($"height value {h} must be in range [0, {height-1}]");
92+
throw new IndexOutOfRangeException($"height value {h} must be in range [0, {m_TensorShape.height-1}]");
9093
}
91-
if (w < 0 || w >= width)
94+
if (w < 0 || w >= m_TensorShape.width)
9295
{
93-
throw new IndexOutOfRangeException($"width value {w} must be in range [0, {width-1}]");
96+
throw new IndexOutOfRangeException($"width value {w} must be in range [0, {m_TensorShape.width-1}]");
9497
}
95-
if (ch < 0 || ch >= channels)
98+
if (ch < 0 || ch >= m_TensorShape.channels)
9699
{
97-
throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {channels-1}]");
100+
throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {m_TensorShape.channels-1}]");
98101
}
99102

100-
// Math copied from TensorShape.Index(). Note that m_Batch should always be 0
101-
var index = m_Batch * height * width * channels + h * width * channels + w * channels + ch;
102-
m_Data[index + m_Offset] = value;
103+
var index = m_TensorShape.Index(m_Batch, h, w, ch + m_Offset);
104+
m_Data[index] = value;
103105
}
104106
else
105107
{

0 commit comments

Comments
 (0)