@@ -65,17 +65,18 @@ def generate_compressed_data(in_array: np.ndarray) -> bytes:
6565 image_arr = (in_array * 255 ).astype (np .uint8 )
6666 bytes_out = bytes ()
6767
68- num_channels = in_array .shape [2 ]
68+ num_channels = in_array .shape [0 ]
6969 num_images = (num_channels + 2 ) // 3
7070 # Split the input image into batches of 3 channels.
7171 for i in range (num_images ):
72- sub_image = image_arr [..., 3 * i : 3 * i + 3 ]
72+ sub_image = image_arr [3 * i : 3 * i + 3 , ... ]
7373 if (i == num_images - 1 ) and (num_channels % 3 ) != 0 :
7474 # Pad zeros
7575 zero_shape = list (in_array .shape )
76- zero_shape [2 ] = 3 - (num_channels % 3 )
76+ zero_shape [0 ] = 3 - (num_channels % 3 )
7777 z = np .zeros (zero_shape , dtype = np .uint8 )
78- sub_image = np .concatenate ([sub_image , z ], axis = 2 )
78+ sub_image = np .concatenate ([sub_image , z ], axis = 0 )
79+ sub_image = np .moveaxis (sub_image , 0 , - 1 )
7980 im = Image .fromarray (sub_image , "RGB" )
8081 byteIO = io .BytesIO ()
8182 im .save (byteIO , format = "PNG" )
@@ -92,7 +93,7 @@ def generate_compressed_proto_obs(
9293 obs_proto .compression_type = PNG
9394 if grayscale :
9495 # grayscale flag is only used for old API without mapping
95- expected_shape = [in_array .shape [0 ], in_array .shape [1 ], 1 ]
96+ expected_shape = [1 , in_array .shape [1 ], in_array .shape [2 ] ]
9697 obs_proto .shape .extend (expected_shape )
9798 else :
9899 obs_proto .shape .extend (in_array .shape )
@@ -109,9 +110,9 @@ def generate_compressed_proto_obs_with_mapping(
109110 if mapping is not None :
110111 obs_proto .compressed_channel_mapping .extend (mapping )
111112 expected_shape = [
112- in_array .shape [0 ],
113- in_array .shape [1 ],
114113 len ({m for m in mapping if m >= 0 }),
114+ in_array .shape [1 ],
115+ in_array .shape [2 ],
115116 ]
116117 obs_proto .shape .extend (expected_shape )
117118 else :
@@ -233,10 +234,10 @@ def proto_from_steps_and_action(
233234
234235
235236def test_process_pixels ():
236- in_array = np .random .rand (128 , 64 , 3 )
237+ in_array = np .random .rand (3 , 128 , 64 )
237238 byte_arr = generate_compressed_data (in_array )
238239 out_array = process_pixels (byte_arr , 3 )
239- assert out_array .shape == (128 , 64 , 3 )
240+ assert out_array .shape == (3 , 128 , 64 )
240241 assert np .sum (in_array - out_array ) / np .prod (in_array .shape ) < 0.01
241242 assert np .allclose (in_array , out_array , atol = 0.01 )
242243
@@ -245,21 +246,21 @@ def test_process_pixels_multi_png():
245246 height = 128
246247 width = 64
247248 num_channels = 7
248- in_array = np .random .rand (height , width , num_channels )
249+ in_array = np .random .rand (num_channels , height , width )
249250 byte_arr = generate_compressed_data (in_array )
250251 out_array = process_pixels (byte_arr , num_channels )
251- assert out_array .shape == (height , width , num_channels )
252+ assert out_array .shape == (num_channels , height , width )
252253 assert np .sum (in_array - out_array ) / np .prod (in_array .shape ) < 0.01
253254 assert np .allclose (in_array , out_array , atol = 0.01 )
254255
255256
256257def test_process_pixels_gray ():
257- in_array = np .random .rand (128 , 64 , 3 )
258+ in_array = np .random .rand (3 , 128 , 64 )
258259 byte_arr = generate_compressed_data (in_array )
259260 out_array = process_pixels (byte_arr , 1 )
260- assert out_array .shape == (128 , 64 , 1 )
261- assert np .mean (in_array .mean (axis = 2 , keepdims = True ) - out_array ) < 0.01
262- assert np .allclose (in_array .mean (axis = 2 , keepdims = True ), out_array , atol = 0.01 )
261+ assert out_array .shape == (1 , 128 , 64 )
262+ assert np .mean (in_array .mean (axis = 0 , keepdims = True ) - out_array ) < 0.01
263+ assert np .allclose (in_array .mean (axis = 0 , keepdims = True ), out_array , atol = 0.01 )
263264
264265
265266def test_vector_observation ():
@@ -276,7 +277,7 @@ def test_vector_observation():
276277
277278
278279def test_process_visual_observation ():
279- shape = (128 , 64 , 3 )
280+ shape = (3 , 128 , 64 )
280281 in_array_1 = np .random .rand (* shape )
281282 proto_obs_1 = generate_compressed_proto_obs (in_array_1 )
282283 in_array_2 = np .random .rand (* shape )
@@ -292,51 +293,51 @@ def test_process_visual_observation():
292293 ap_list = [ap1 , ap2 ]
293294 obs_spec = create_observation_specs_with_shapes ([shape ])[0 ]
294295 arr = _process_maybe_compressed_observation (0 , obs_spec , ap_list )
295- assert list (arr .shape ) == [2 , 128 , 64 , 3 ]
296+ assert list (arr .shape ) == [2 , 3 , 128 , 64 ]
296297 assert np .allclose (arr [0 , :, :, :], in_array_1 , atol = 0.01 )
297298 assert np .allclose (arr [1 , :, :, :], in_array_2 , atol = 0.01 )
298299
299300
300301def test_process_visual_observation_grayscale ():
301- in_array_1 = np .random .rand (128 , 64 , 3 )
302+ in_array_1 = np .random .rand (3 , 128 , 64 )
302303 proto_obs_1 = generate_compressed_proto_obs (in_array_1 , grayscale = True )
303- expected_out_array_1 = np .mean (in_array_1 , axis = 2 , keepdims = True )
304- in_array_2 = np .random .rand (128 , 64 , 3 )
304+ expected_out_array_1 = np .mean (in_array_1 , axis = 0 , keepdims = True )
305+ in_array_2 = np .random .rand (3 , 128 , 64 )
305306 in_array_2_mapping = [0 , 0 , 0 ]
306307 proto_obs_2 = generate_compressed_proto_obs_with_mapping (
307308 in_array_2 , in_array_2_mapping
308309 )
309- expected_out_array_2 = np .mean (in_array_2 , axis = 2 , keepdims = True )
310+ expected_out_array_2 = np .mean (in_array_2 , axis = 0 , keepdims = True )
310311
311312 ap1 = AgentInfoProto ()
312313 ap1 .observations .extend ([proto_obs_1 ])
313314 ap2 = AgentInfoProto ()
314315 ap2 .observations .extend ([proto_obs_2 ])
315316 ap_list = [ap1 , ap2 ]
316- shape = (128 , 64 , 1 )
317+ shape = (1 , 128 , 64 )
317318 obs_spec = create_observation_specs_with_shapes ([shape ])[0 ]
318319 arr = _process_maybe_compressed_observation (0 , obs_spec , ap_list )
319- assert list (arr .shape ) == [2 , 128 , 64 , 1 ]
320+ assert list (arr .shape ) == [2 , 1 , 128 , 64 ]
320321 assert np .allclose (arr [0 , :, :, :], expected_out_array_1 , atol = 0.01 )
321322 assert np .allclose (arr [1 , :, :, :], expected_out_array_2 , atol = 0.01 )
322323
323324
324325def test_process_visual_observation_padded_channels ():
325- in_array_1 = np .random .rand (128 , 64 , 12 )
326+ in_array_1 = np .random .rand (12 , 128 , 64 )
326327 in_array_1_mapping = [0 , 1 , 2 , 3 , - 1 , - 1 , 4 , 5 , 6 , 7 , - 1 , - 1 ]
327328 proto_obs_1 = generate_compressed_proto_obs_with_mapping (
328329 in_array_1 , in_array_1_mapping
329330 )
330- expected_out_array_1 = np .take (in_array_1 , [0 , 1 , 2 , 3 , 6 , 7 , 8 , 9 ], axis = 2 )
331+ expected_out_array_1 = np .take (in_array_1 , [0 , 1 , 2 , 3 , 6 , 7 , 8 , 9 ], axis = 0 )
331332
332333 ap1 = AgentInfoProto ()
333334 ap1 .observations .extend ([proto_obs_1 ])
334335 ap_list = [ap1 ]
335- shape = (128 , 64 , 8 )
336+ shape = (8 , 128 , 64 )
336337 obs_spec = create_observation_specs_with_shapes ([shape ])[0 ]
337338
338339 arr = _process_maybe_compressed_observation (0 , obs_spec , ap_list )
339- assert list (arr .shape ) == [1 , 128 , 64 , 8 ]
340+ assert list (arr .shape ) == [1 , 8 , 128 , 64 ]
340341 assert np .allclose (arr [0 , :, :, :], expected_out_array_1 , atol = 0.01 )
341342
342343
0 commit comments