Skip to content

Commit 1934046

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Return self in the to method for the renderer classes
Summary: Add `return self` to the `to` function for the renderer classes. Reviewed By: bottler Differential Revision: D25534487 fbshipit-source-id: e8dbd35524f0bd40e835439e93184b5a1f1532ca
1 parent 831e64e commit 1934046

File tree

6 files changed

+77
-4
lines changed

6 files changed

+77
-4
lines changed

pytorch3d/renderer/mesh/rasterizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(self, cameras=None, raster_settings=None):
7979
def to(self, device):
8080
# Manually move to device cameras as it is not a subclass of nn.Module
8181
self.cameras = self.cameras.to(device)
82+
return self
8283

8384
def transform(self, meshes_world, **kwargs) -> torch.Tensor:
8485
"""

pytorch3d/renderer/mesh/renderer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def to(self, device):
3737
# Rasterizer and shader have submodules which are not of type nn.Module
3838
self.rasterizer.to(device)
3939
self.shader.to(device)
40+
return self
4041

4142
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
4243
"""

pytorch3d/renderer/mesh/shader.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def to(self, device):
5555
self.cameras = self.cameras.to(device)
5656
self.materials = self.materials.to(device)
5757
self.lights = self.lights.to(device)
58+
return self
5859

5960
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
6061
cameras = kwargs.get("cameras", self.cameras)
@@ -109,6 +110,7 @@ def to(self, device):
109110
self.cameras = self.cameras.to(device)
110111
self.materials = self.materials.to(device)
111112
self.lights = self.lights.to(device)
113+
return self
112114

113115
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
114116
cameras = kwargs.get("cameras", self.cameras)
@@ -168,6 +170,7 @@ def to(self, device):
168170
self.cameras = self.cameras.to(device)
169171
self.materials = self.materials.to(device)
170172
self.lights = self.lights.to(device)
173+
return self
171174

172175
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
173176
cameras = kwargs.get("cameras", self.cameras)
@@ -226,6 +229,7 @@ def to(self, device):
226229
self.cameras = self.cameras.to(device)
227230
self.materials = self.materials.to(device)
228231
self.lights = self.lights.to(device)
232+
return self
229233

230234
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
231235
cameras = kwargs.get("cameras", self.cameras)
@@ -301,6 +305,7 @@ def to(self, device):
301305
self.cameras = self.cameras.to(device)
302306
self.materials = self.materials.to(device)
303307
self.lights = self.lights.to(device)
308+
return self
304309

305310
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
306311
cameras = kwargs.get("cameras", self.cameras)

pytorch3d/renderer/points/rasterizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def transform(self, point_clouds, **kwargs) -> torch.Tensor:
9999
point_clouds = point_clouds.update_padded(pts_screen)
100100
return point_clouds
101101

102+
def to(self, device):
103+
# Manually move to device cameras as it is not a subclass of nn.Module
104+
self.cameras = self.cameras.to(device)
105+
return self
106+
102107
def forward(self, point_clouds, **kwargs) -> PointFragments:
103108
"""
104109
Args:

pytorch3d/renderer/points/renderer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ def __init__(self, rasterizer, compositor):
3232
self.rasterizer = rasterizer
3333
self.compositor = compositor
3434

35+
def to(self, device):
36+
# Manually move to device rasterizer as the cameras
37+
# within the class are not of type nn.Module
38+
self.rasterizer = self.rasterizer.to(device)
39+
self.compositor = self.compositor.to(device)
40+
return self
41+
3542
def forward(self, point_clouds, **kwargs) -> torch.Tensor:
3643
fragments = self.rasterizer(point_clouds, **kwargs)
3744

tests/test_render_multigpu.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,22 @@
66
import torch.nn as nn
77
from common_testing import TestCaseMixin, get_random_cuda_device
88
from pytorch3d.renderer import (
9+
AlphaCompositor,
910
BlendParams,
1011
HardGouraudShader,
1112
Materials,
1213
MeshRasterizer,
1314
MeshRenderer,
1415
PointLights,
16+
PointsRasterizationSettings,
17+
PointsRasterizer,
18+
PointsRenderer,
1519
RasterizationSettings,
1620
SoftPhongShader,
1721
TexturesVertex,
1822
)
1923
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
20-
from pytorch3d.structures.meshes import Meshes
24+
from pytorch3d.structures import Meshes, Pointclouds
2125
from pytorch3d.utils.ico_sphere import ico_sphere
2226

2327

@@ -27,7 +31,7 @@
2731
print("GPUs: %s" % ", ".join(GPU_LIST))
2832

2933

30-
class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase):
34+
class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
3135
def _check_mesh_renderer_props_on_device(self, renderer, device):
3236
"""
3337
Helper function to check that all the properties of the mesh
@@ -99,7 +103,7 @@ def test_mesh_renderer_to(self):
99103
# This also tests that background_color is correctly moved to
100104
# the new device
101105
device2 = torch.device("cuda:0")
102-
renderer.to(device2)
106+
renderer = renderer.to(device2)
103107
mesh = mesh.to(device2)
104108
self._check_mesh_renderer_props_on_device(renderer, device2)
105109
output_images = renderer(mesh)
@@ -137,7 +141,7 @@ def init_render(self):
137141

138142
def forward(self, verts, texs):
139143
batch_size = verts.size(0)
140-
self.renderer.to(verts.device)
144+
self.renderer = self.renderer.to(verts.device)
141145
tex = TexturesVertex(verts_features=texs)
142146
faces = self.faces.expand(batch_size, -1, -1).to(verts.device)
143147
mesh = Meshes(verts, faces, tex).to(verts.device)
@@ -157,3 +161,53 @@ def forward(self, verts, texs):
157161
# Test a few iterations
158162
for _ in range(100):
159163
model(verts, texs)
164+
165+
166+
class TestRenderPointssMultiGPU(TestCaseMixin, unittest.TestCase):
167+
def _check_points_renderer_props_on_device(self, renderer, device):
168+
"""
169+
Helper function to check that all the properties have
170+
been moved to the correct device.
171+
"""
172+
# Cameras
173+
self.assertEqual(renderer.rasterizer.cameras.device, device)
174+
self.assertEqual(renderer.rasterizer.cameras.R.device, device)
175+
self.assertEqual(renderer.rasterizer.cameras.T.device, device)
176+
177+
def test_points_renderer_to(self):
178+
"""
179+
Test moving all the tensors in the points renderer to a new device.
180+
"""
181+
182+
device1 = torch.device("cpu")
183+
184+
R, T = look_at_view_transform(1500, 0.0, 0.0)
185+
186+
raster_settings = PointsRasterizationSettings(
187+
image_size=256, radius=0.001, points_per_pixel=1
188+
)
189+
cameras = FoVPerspectiveCameras(
190+
device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
191+
)
192+
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
193+
194+
renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor())
195+
196+
mesh = ico_sphere(2, device1)
197+
verts_padded = mesh.verts_padded()
198+
pointclouds = Pointclouds(
199+
points=verts_padded, features=torch.randn_like(verts_padded)
200+
)
201+
self._check_points_renderer_props_on_device(renderer, device1)
202+
203+
# Test rendering on cpu
204+
output_images = renderer(pointclouds)
205+
self.assertEqual(output_images.device, device1)
206+
207+
# Move renderer and pointclouds to another device and re render
208+
device2 = torch.device("cuda:0")
209+
renderer = renderer.to(device2)
210+
pointclouds = pointclouds.to(device2)
211+
self._check_points_renderer_props_on_device(renderer, device2)
212+
output_images = renderer(pointclouds)
213+
self.assertEqual(output_images.device, device2)

0 commit comments

Comments
 (0)