@@ -9,8 +9,10 @@
 from typing import Any, Dict, Optional

 import torch
+from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
 from pytorch3d.implicitron.tools import metric_utils as utils
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
+from pytorch3d.ops import packed_to_padded, padded_to_packed
 from pytorch3d.renderer import utils as rend_utils

 from .renderer.base import RendererOutput
@@ -60,7 +62,7 @@ def __post_init__(self) -> None:
     def forward(
         self,
         raymarched: RendererOutput,
-        xys: torch.Tensor,
+        ray_bundle: ImplicitronRayBundle,
         image_rgb: Optional[torch.Tensor] = None,
         depth_map: Optional[torch.Tensor] = None,
         fg_probability: Optional[torch.Tensor] = None,
@@ -79,10 +81,8 @@ def forward(
                 names of the output metrics `metric_name_i` with their corresponding
                 values `metric_value_i` represented as 0-dimensional float tensors.
             raymarched: Output of the renderer.
-            xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
-                the predictions are defined. All ground truth inputs are sampled at
-                these locations in order to extract values that correspond to the
-                predictions.
+            ray_bundle: ImplicitronRayBundle object which was used to produce the raymarched
+                object
             image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
                 values.
             depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
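A hedged sketch (not part of this diff) of how a call site changes under the new signature: callers that previously passed the sampled `xys` tensor now pass the whole bundle. Here `view_metrics` is assumed to be an already-constructed `ViewMetrics` module used as a regular `torch.nn.Module`, and the remaining tensors are assumed to come from the training batch:

    metrics = view_metrics(
        raymarched=renderer_output,   # RendererOutput produced by the renderer
        ray_bundle=ray_bundle,        # the same ImplicitronRayBundle the renderer consumed
        image_rgb=image_rgb,
        depth_map=depth_map,
        fg_probability=fg_probability,
    )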
@@ -141,7 +141,7 @@ class ViewMetrics(ViewMetricsBase):
     def forward(
         self,
         raymarched: RendererOutput,
-        xys: torch.Tensor,
+        ray_bundle: ImplicitronRayBundle,
         image_rgb: Optional[torch.Tensor] = None,
         depth_map: Optional[torch.Tensor] = None,
         fg_probability: Optional[torch.Tensor] = None,
@@ -165,10 +165,8 @@ def forward(
                 input 3D coordinates used to compute the eikonal loss.
             raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
                 containing a `Hg x Wg x Dg` voxel grid of density values.
-            xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
-                the predictions are defined. All ground truth inputs are sampled at
-                these locations in order to extract values that correspond to the
-                predictions.
+            ray_bundle: ImplicitronRayBundle object which was used to produce the raymarched
+                object
             image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
                 values.
             depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
@@ -209,7 +207,7 @@ def forward(
         """
         metrics = self._calculate_stage(
             raymarched,
-            xys,
+            ray_bundle,
             image_rgb,
             depth_map,
             fg_probability,
@@ -221,7 +219,7 @@ def forward(
             metrics.update(
                 self(
                     raymarched.prev_stage,
-                    xys,
+                    ray_bundle,
                     image_rgb,
                     depth_map,
                     fg_probability,
@@ -235,7 +233,7 @@ def forward(
     def _calculate_stage(
         self,
         raymarched: RendererOutput,
-        xys: torch.Tensor,
+        ray_bundle: ImplicitronRayBundle,
         image_rgb: Optional[torch.Tensor] = None,
         depth_map: Optional[torch.Tensor] = None,
         fg_probability: Optional[torch.Tensor] = None,
@@ -253,6 +251,27 @@ def _calculate_stage(
             _reshape_nongrid_var(x)
             for x in [raymarched.features, raymarched.masks, raymarched.depths]
         ]
+        xys = ray_bundle.xys
+
+        # If ray_bundle is packed then we can sample images in padded state to lower
+        # memory requirements. Instead of having one image for every element in
+        # ray_bundle we can then have one image per unique sampled camera.
+        if ray_bundle.is_packed():
+            # pyre-ignore[6]
+            cumsum = torch.cumsum(ray_bundle.camera_counts, dim=0, dtype=torch.long)
+            first_idxs = torch.cat(
+                (
+                    # pyre-ignore[16]
+                    ray_bundle.camera_counts.new_zeros((1,), dtype=torch.long),
+                    cumsum[:-1],
+                )
+            )
+            # pyre-ignore[16]
+            num_inputs = int(ray_bundle.camera_counts.sum())
+            # pyre-ignore[6]
+            max_size = int(torch.max(ray_bundle.camera_counts))
+            xys = packed_to_padded(xys, first_idxs, max_size)
+
         # reshape the sampling grid as well
         # TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
         # now that we use rend_utils.ndc_grid_sample
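To make the index bookkeeping above concrete, here is a minimal sketch (not from the PR) using hypothetical counts of 2 and 3 rays from two unique cameras; like the code above, it relies on `packed_to_padded` accepting inputs with trailing dimensions:

    import torch
    from pytorch3d.ops import packed_to_padded

    camera_counts = torch.tensor([2, 3])          # rays per unique camera
    cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)  # tensor([2, 5])
    first_idxs = torch.cat(
        (camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
    )                                             # tensor([0, 2]): each camera's first ray
    num_inputs = int(camera_counts.sum())         # 5 rays in total
    max_size = int(torch.max(camera_counts))      # 3
    xys = torch.rand(num_inputs, 2)               # packed 2D sampling locations
    xys_padded = packed_to_padded(xys, first_idxs, max_size)
    print(xys_padded.shape)                       # torch.Size([2, 3, 2]), shorter camera zero-padded

With one padded row per unique camera, the ground-truth images only need to be sampled once per camera instead of once per ray, which is the memory saving the comment above refers to.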
@@ -262,7 +281,20 @@ def _calculate_stage(
         def sample(tensor, mode):
             if tensor is None:
                 return tensor
-            return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
+            if ray_bundle.is_packed():
+                # select images that correspond to sampled cameras if raybundle is packed
+                tensor = tensor[ray_bundle.camera_ids]
+            result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
+            if ray_bundle.is_packed():
+                # Images after sampling are in a form [batch, 3, max_num_rays, 1],
+                # padded_to_packed combines the first two dimensions so we need to swap
+                # the 1st and 2nd dimension. The result is [n_rays_total_training, 1, 3, 1]
+                # (we use keepdim=True).
+                result = result.transpose(1, 2)
+                result = padded_to_packed(result, first_idxs, num_inputs)[:, None]
+                result = result.transpose(1, 2)
+
+            return result

         # eval all results in this size
         image_rgb = sample(image_rgb, mode="bilinear")
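And a matching sketch (again not from the PR, same hypothetical counts) of the padded-to-packed round trip that `sample` performs after `ndc_grid_sample`; the starting shape follows the comment above, `[n_cameras, 3, max_num_rays, 1]`:

    import torch
    from pytorch3d.ops import padded_to_packed

    first_idxs = torch.tensor([0, 2])             # as derived from camera_counts = [2, 3]
    num_inputs = 5                                # total number of rays
    sampled = torch.rand(2, 3, 3, 1)              # [n_cameras, 3, max_num_rays, 1]
    sampled = sampled.transpose(1, 2)             # [n_cameras, max_num_rays, 3, 1]
    packed = padded_to_packed(sampled, first_idxs, num_inputs)[:, None]  # [5, 1, 3, 1]
    packed = packed.transpose(1, 2)               # [5, 3, 1, 1]
    print(packed.shape)

The two transposes are needed because `padded_to_packed` collapses the first two dimensions, so the per-camera and per-ray axes must be adjacent before the call, and the channel axis is moved back into place afterwards.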