11using System ;
22using System . Collections . Generic ;
3+ using Barracuda ;
34using MLAgents . InferenceBrain ;
45
56namespace MLAgents . Sensor
@@ -15,7 +16,7 @@ public class WriteAdapter
1516 TensorProxy m_Proxy ;
1617 int m_Batch ;
1718
18- int [ ] m_Shape ;
19+ TensorShape m_TensorShape ;
1920
2021 /// <summary>
2122 /// Set the adapter to write to an IList at the given channelOffset.
@@ -29,23 +30,30 @@ public void SetTarget(IList<float> data, int[] shape, int offset)
2930 m_Offset = offset ;
3031 m_Proxy = null ;
3132 m_Batch = 0 ;
32- m_Shape = shape ;
33+
34+ if ( shape . Length == 1 )
35+ {
36+ m_TensorShape = new TensorShape ( m_Batch , shape [ 0 ] ) ;
37+ }
38+ else
39+ {
40+ m_TensorShape = new TensorShape ( m_Batch , shape [ 0 ] , shape [ 1 ] , shape [ 2 ] ) ;
41+ }
3342 }
3443
3544 /// <summary>
3645 /// Set the adapter to write to a TensorProxy at the given batch and channel offset.
3746 /// </summary>
3847 /// <param name="tensorProxy">Tensor proxy that will be writtent to.</param>
39- /// <param name="shape">Shape of the observations to be written.</param>
4048 /// <param name="batchIndex">Batch index in the tensor proxy (i.e. the index of the Agent)</param>
4149 /// <param name="channelOffset">Offset from the start of the channel to write to.</param>
42- public void SetTarget ( TensorProxy tensorProxy , int [ ] shape , int batchIndex , int channelOffset )
50+ public void SetTarget ( TensorProxy tensorProxy , int batchIndex , int channelOffset )
4351 {
4452 m_Proxy = tensorProxy ;
4553 m_Batch = batchIndex ;
4654 m_Offset = channelOffset ;
4755 m_Data = null ;
48- m_Shape = shape ;
56+ m_TensorShape = m_Proxy . data . shape ;
4957 }
5058
5159 /// <summary>
@@ -56,7 +64,6 @@ public float this[int index]
5664 {
5765 set
5866 {
59- // TODO check shape is 1D?
6067 if ( m_Data != null )
6168 {
6269 m_Data [ index + m_Offset ] = value ;
@@ -80,26 +87,21 @@ public float this[int index]
8087 {
8188 if ( m_Data != null )
8289 {
83- var height = m_Shape [ 0 ] ;
84- var width = m_Shape [ 1 ] ;
85- var channels = m_Shape [ 2 ] ;
86-
87- if ( h < 0 || h >= height )
90+ if ( h < 0 || h >= m_TensorShape . height )
8891 {
89- throw new IndexOutOfRangeException ( $ "height value { h } must be in range [0, { height - 1 } ]") ;
92+ throw new IndexOutOfRangeException ( $ "height value { h } must be in range [0, { m_TensorShape . height - 1 } ]") ;
9093 }
91- if ( w < 0 || w >= width )
94+ if ( w < 0 || w >= m_TensorShape . width )
9295 {
93- throw new IndexOutOfRangeException ( $ "width value { w } must be in range [0, { width - 1 } ]") ;
96+ throw new IndexOutOfRangeException ( $ "width value { w } must be in range [0, { m_TensorShape . width - 1 } ]") ;
9497 }
95- if ( ch < 0 || ch >= channels )
98+ if ( ch < 0 || ch >= m_TensorShape . channels )
9699 {
97- throw new IndexOutOfRangeException ( $ "channel value { ch } must be in range [0, { channels - 1 } ]") ;
100+ throw new IndexOutOfRangeException ( $ "channel value { ch } must be in range [0, { m_TensorShape . channels - 1 } ]") ;
98101 }
99102
100- // Math copied from TensorShape.Index(). Note that m_Batch should always be 0
101- var index = m_Batch * height * width * channels + h * width * channels + w * channels + ch ;
102- m_Data [ index + m_Offset ] = value ;
103+ var index = m_TensorShape . Index ( m_Batch , h , w , ch + m_Offset ) ;
104+ m_Data [ index ] = value ;
103105 }
104106 else
105107 {
0 commit comments