File tree Expand file tree Collapse file tree 1 file changed +8
-8
lines changed Expand file tree Collapse file tree 1 file changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -170,23 +170,23 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
170170
171171
172172@pytest .mark .parametrize ('bsize' , [5 , 10 ])
173- def test_batch_3d_squeeze_batch_dim ( sample_ds_3d , bsize ):
173+ def test_batch_1d_squeeze_batch_dim ( sample_ds_1d , bsize ):
174174 xbsize = 20
175175 bg = BatchGenerator (
176- sample_ds_3d ,
177- input_dims = {'time' : 1 , 'y' : bsize , ' x' : xbsize },
176+ sample_ds_1d ,
177+ input_dims = {'x' : xbsize },
178178 squeeze_batch_dim = False ,
179179 )
180180 for ds_batch in bg :
181- assert ds_batch ['x ' ].shape == [1 , bsize , xbsize ]
181+ assert list ( ds_batch ['foo ' ].shape ) == [1 , xbsize ]
182182
183183 bg2 = BatchGenerator (
184- sample_ds_3d ,
185- input_dims = {'time' : 1 , 'y' : bsize , ' x' : xbsize },
184+ sample_ds_1d ,
185+ input_dims = {'x' : xbsize },
186186 squeeze_batch_dim = True ,
187187 )
188- for ds_batch in bg :
189- assert ds_batch ['x ' ].shape == [bsize , xbsize ]
188+ for ds_batch in bg2 :
189+ assert list ( ds_batch ['foo ' ].shape ) == [xbsize ]
190190
191191
192192def test_preload_batch_false (sample_ds_1d ):
You can’t perform that action at this time.
0 commit comments