7
7
8
8
from pytorch3d .structures .textures import Textures
9
9
10
-
11
- def _clip_barycentric_coordinates (bary ) -> torch .Tensor :
12
- """
13
- Args:
14
- bary: barycentric coordinates of shape (...., 3) where `...` represents
15
- an arbitrary number of dimensions
16
-
17
- Returns:
18
- bary: All barycentric coordinate values clipped to the range [0, 1]
19
- and renormalized. The output is the same shape as the input.
20
- """
21
- if bary .shape [- 1 ] != 3 :
22
- msg = "Expected barycentric coords to have last dim = 3; got %r"
23
- raise ValueError (msg % bary .shape )
24
- clipped = bary .clamp (min = 0 , max = 1 )
25
- clipped_sum = torch .clamp (clipped .sum (dim = - 1 , keepdim = True ), min = 1e-5 )
26
- clipped = clipped / clipped_sum
27
- return clipped
28
-
29
-
30
- def interpolate_face_attributes (
31
- fragments , face_attributes : torch .Tensor , bary_clip : bool = False
32
- ) -> torch .Tensor :
33
- """
34
- Interpolate arbitrary face attributes using the barycentric coordinates
35
- for each pixel in the rasterized output.
36
-
37
- Args:
38
- fragments:
39
- The outputs of rasterization. From this we use
40
-
41
- - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
42
- of the faces (in the packed representation) which
43
- overlap each pixel in the image.
44
- - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
45
- the barycentric coordianates of each pixel
46
- relative to the faces (in the packed
47
- representation) which overlap the pixel.
48
- face_attributes: packed attributes of shape (total_faces, 3, D),
49
- specifying the value of the attribute for each
50
- vertex in the face.
51
- bary_clip: Bool to indicate if barycentric_coords should be clipped
52
- before being used for interpolation.
53
-
54
- Returns:
55
- pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
56
- value of the face attribute for each pixel.
57
- """
58
- pix_to_face = fragments .pix_to_face
59
- barycentric_coords = fragments .bary_coords
60
- F , FV , D = face_attributes .shape
61
- if FV != 3 :
62
- raise ValueError ("Faces can only have three vertices; got %r" % FV )
63
- N , H , W , K , _ = barycentric_coords .shape
64
- if pix_to_face .shape != (N , H , W , K ):
65
- msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
66
- raise ValueError (msg % pix_to_face .shape )
67
- if bary_clip :
68
- barycentric_coords = _clip_barycentric_coordinates (barycentric_coords )
69
-
70
- # Replace empty pixels in pix_to_face with 0 in order to interpolate.
71
- mask = pix_to_face == - 1
72
- pix_to_face = pix_to_face .clone ()
73
- pix_to_face [mask ] = 0
74
- idx = pix_to_face .view (N * H * W * K , 1 , 1 ).expand (N * H * W * K , 3 , D )
75
- pixel_face_vals = face_attributes .gather (0 , idx ).view (N , H , W , K , 3 , D )
76
- pixel_vals = (barycentric_coords [..., None ] * pixel_face_vals ).sum (dim = - 2 )
77
- pixel_vals [mask ] = 0 # Replace masked values in output.
78
- return pixel_vals
10
+ from .utils import interpolate_face_attributes
79
11
80
12
81
13
def interpolate_texture_map (fragments , meshes ) -> torch .Tensor :
@@ -97,8 +29,8 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
97
29
relative to the faces (in the packed
98
30
representation) which overlap the pixel.
99
31
meshes: Meshes representing a batch of meshes. It is expected that
100
- meshes has a textures attribute which is an instance of the
101
- Textures class.
32
+ meshes has a textures attribute which is an instance of the
33
+ Textures class.
102
34
103
35
Returns:
104
36
texels: tensor of shape (N, H, W, K, C) giving the interpolated
@@ -114,7 +46,9 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
114
46
texture_maps = meshes .textures .maps_padded ()
115
47
116
48
# pixel_uvs: (N, H, W, K, 2)
117
- pixel_uvs = interpolate_face_attributes (fragments , faces_verts_uvs )
49
+ pixel_uvs = interpolate_face_attributes (
50
+ fragments .pix_to_face , fragments .bary_coords , faces_verts_uvs
51
+ )
118
52
119
53
N , H_out , W_out , K = fragments .pix_to_face .shape
120
54
N , H_in , W_in , C = texture_maps .shape # 3 for RGB
@@ -178,5 +112,7 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
178
112
vertex_textures = vertex_textures [meshes .verts_padded_to_packed_idx (), :]
179
113
faces_packed = meshes .faces_packed ()
180
114
faces_textures = vertex_textures [faces_packed ] # (F, 3, C)
181
- texels = interpolate_face_attributes (fragments , faces_textures )
115
+ texels = interpolate_face_attributes (
116
+ fragments .pix_to_face , fragments .bary_coords , faces_textures
117
+ )
182
118
return texels
0 commit comments