6
6
import torch .nn as nn
7
7
from common_testing import TestCaseMixin , get_random_cuda_device
8
8
from pytorch3d .renderer import (
9
+ AlphaCompositor ,
9
10
BlendParams ,
10
11
HardGouraudShader ,
11
12
Materials ,
12
13
MeshRasterizer ,
13
14
MeshRenderer ,
14
15
PointLights ,
16
+ PointsRasterizationSettings ,
17
+ PointsRasterizer ,
18
+ PointsRenderer ,
15
19
RasterizationSettings ,
16
20
SoftPhongShader ,
17
21
TexturesVertex ,
18
22
)
19
23
from pytorch3d .renderer .cameras import FoVPerspectiveCameras , look_at_view_transform
20
- from pytorch3d .structures . meshes import Meshes
24
+ from pytorch3d .structures import Meshes , Pointclouds
21
25
from pytorch3d .utils .ico_sphere import ico_sphere
22
26
23
27
27
31
print ("GPUs: %s" % ", " .join (GPU_LIST ))
28
32
29
33
30
- class TestRenderMultiGPU (TestCaseMixin , unittest .TestCase ):
34
+ class TestRenderMeshesMultiGPU (TestCaseMixin , unittest .TestCase ):
31
35
def _check_mesh_renderer_props_on_device (self , renderer , device ):
32
36
"""
33
37
Helper function to check that all the properties of the mesh
@@ -99,7 +103,7 @@ def test_mesh_renderer_to(self):
99
103
# This also tests that background_color is correctly moved to
100
104
# the new device
101
105
device2 = torch .device ("cuda:0" )
102
- renderer .to (device2 )
106
+ renderer = renderer .to (device2 )
103
107
mesh = mesh .to (device2 )
104
108
self ._check_mesh_renderer_props_on_device (renderer , device2 )
105
109
output_images = renderer (mesh )
@@ -137,7 +141,7 @@ def init_render(self):
137
141
138
142
def forward (self , verts , texs ):
139
143
batch_size = verts .size (0 )
140
- self .renderer .to (verts .device )
144
+ self .renderer = self . renderer .to (verts .device )
141
145
tex = TexturesVertex (verts_features = texs )
142
146
faces = self .faces .expand (batch_size , - 1 , - 1 ).to (verts .device )
143
147
mesh = Meshes (verts , faces , tex ).to (verts .device )
@@ -157,3 +161,53 @@ def forward(self, verts, texs):
157
161
# Test a few iterations
158
162
for _ in range (100 ):
159
163
model (verts , texs )
164
+
165
+
166
+ class TestRenderPointssMultiGPU (TestCaseMixin , unittest .TestCase ):
167
+ def _check_points_renderer_props_on_device (self , renderer , device ):
168
+ """
169
+ Helper function to check that all the properties have
170
+ been moved to the correct device.
171
+ """
172
+ # Cameras
173
+ self .assertEqual (renderer .rasterizer .cameras .device , device )
174
+ self .assertEqual (renderer .rasterizer .cameras .R .device , device )
175
+ self .assertEqual (renderer .rasterizer .cameras .T .device , device )
176
+
177
+ def test_points_renderer_to (self ):
178
+ """
179
+ Test moving all the tensors in the points renderer to a new device.
180
+ """
181
+
182
+ device1 = torch .device ("cpu" )
183
+
184
+ R , T = look_at_view_transform (1500 , 0.0 , 0.0 )
185
+
186
+ raster_settings = PointsRasterizationSettings (
187
+ image_size = 256 , radius = 0.001 , points_per_pixel = 1
188
+ )
189
+ cameras = FoVPerspectiveCameras (
190
+ device = device1 , R = R , T = T , aspect_ratio = 1.0 , fov = 60.0 , zfar = 100
191
+ )
192
+ rasterizer = PointsRasterizer (cameras = cameras , raster_settings = raster_settings )
193
+
194
+ renderer = PointsRenderer (rasterizer = rasterizer , compositor = AlphaCompositor ())
195
+
196
+ mesh = ico_sphere (2 , device1 )
197
+ verts_padded = mesh .verts_padded ()
198
+ pointclouds = Pointclouds (
199
+ points = verts_padded , features = torch .randn_like (verts_padded )
200
+ )
201
+ self ._check_points_renderer_props_on_device (renderer , device1 )
202
+
203
+ # Test rendering on cpu
204
+ output_images = renderer (pointclouds )
205
+ self .assertEqual (output_images .device , device1 )
206
+
207
+ # Move renderer and pointclouds to another device and re render
208
+ device2 = torch .device ("cuda:0" )
209
+ renderer = renderer .to (device2 )
210
+ pointclouds = pointclouds .to (device2 )
211
+ self ._check_points_renderer_props_on_device (renderer , device2 )
212
+ output_images = renderer (pointclouds )
213
+ self .assertEqual (output_images .device , device2 )
0 commit comments