6
6
import torch
7
7
from omegaconf import DictConfig
8
8
from pytorch3d .common .linear_with_repeat import LinearWithRepeat
9
+ from pytorch3d .implicitron .models .renderer .base import ImplicitronRayBundle
9
10
from pytorch3d .implicitron .third_party import hyperlayers , pytorch_prototyping
10
11
from pytorch3d .implicitron .tools .config import Configurable , registry , run_auto_creation
11
- from pytorch3d .renderer import ray_bundle_to_ray_points , RayBundle
12
+ from pytorch3d .renderer import ray_bundle_to_ray_points
12
13
from pytorch3d .renderer .cameras import CamerasBase
13
14
from pytorch3d .renderer .implicit import HarmonicEmbedding
14
15
@@ -68,15 +69,15 @@ def __post_init__(self):
68
69
69
70
def forward (
70
71
self ,
71
- ray_bundle : RayBundle ,
72
+ ray_bundle : ImplicitronRayBundle ,
72
73
fun_viewpool = None ,
73
74
camera : Optional [CamerasBase ] = None ,
74
75
global_code = None ,
75
76
** kwargs ,
76
77
):
77
78
"""
78
79
Args:
79
- ray_bundle: A RayBundle object containing the following variables:
80
+ ray_bundle: An ImplicitronRayBundle object containing the following variables:
80
81
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
81
82
origins of the sampling rays in world coords.
82
83
directions: A tensor of shape `(minibatch, ..., 3)`
@@ -96,10 +97,11 @@ def forward(
96
97
"""
97
98
# We first convert the ray parametrizations to world
98
99
# coordinates with `ray_bundle_to_ray_points`.
100
+ # pyre-ignore[6]
99
101
rays_points_world = ray_bundle_to_ray_points (ray_bundle )
100
102
101
103
embeds = create_embeddings_for_implicit_function (
102
- xyz_world = ray_bundle_to_ray_points ( ray_bundle ) ,
104
+ xyz_world = rays_points_world ,
103
105
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
104
106
# for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
105
107
xyz_embedding_function = self ._harmonic_embedding ,
@@ -175,15 +177,15 @@ def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
175
177
def forward (
176
178
self ,
177
179
raymarch_features : torch .Tensor ,
178
- ray_bundle : RayBundle ,
180
+ ray_bundle : ImplicitronRayBundle ,
179
181
camera : Optional [CamerasBase ] = None ,
180
182
** kwargs ,
181
183
):
182
184
"""
183
185
Args:
184
186
raymarch_features: Features from the raymarching network of shape
185
187
`(minibatch, ..., self.in_features)`
186
- ray_bundle: A RayBundle object containing the following variables:
188
+ ray_bundle: An ImplicitronRayBundle object containing the following variables:
187
189
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
188
190
origins of the sampling rays in world coords.
189
191
directions: A tensor of shape `(minibatch, ..., 3)`
@@ -297,7 +299,7 @@ def _run_hypernet(self, global_code: torch.Tensor) -> Tuple[SRNRaymarchFunction]
297
299
298
300
def forward (
299
301
self ,
300
- ray_bundle : RayBundle ,
302
+ ray_bundle : ImplicitronRayBundle ,
301
303
fun_viewpool = None ,
302
304
camera : Optional [CamerasBase ] = None ,
303
305
global_code = None ,
@@ -350,7 +352,7 @@ def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None:
350
352
def forward (
351
353
self ,
352
354
* ,
353
- ray_bundle : RayBundle ,
355
+ ray_bundle : ImplicitronRayBundle ,
354
356
fun_viewpool = None ,
355
357
camera : Optional [CamerasBase ] = None ,
356
358
global_code = None ,
@@ -410,7 +412,7 @@ def hypernet_tweak_args(cls, type, args: DictConfig) -> None:
410
412
def forward (
411
413
self ,
412
414
* ,
413
- ray_bundle : RayBundle ,
415
+ ray_bundle : ImplicitronRayBundle ,
414
416
fun_viewpool = None ,
415
417
camera : Optional [CamerasBase ] = None ,
416
418
global_code = None ,
0 commit comments