Skip to content

Commit 8301163

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
transforms 3d convention fix
Summary: Fixed the rotation matrices generated by the RotateAxisAngle class and updated the tests. Added documentation for Transforms3d to clarify the conventions. Reviewed By: gkioxari Differential Revision: D19912903 fbshipit-source-id: c64926ce4e1381b145811557c32b73663d6d92d1
1 parent bdc2bb5 commit 8301163

File tree

4 files changed

+203
-104
lines changed

4 files changed

+203
-104
lines changed

pytorch3d/transforms/rotation_conversions.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,32 @@
55
import torch
66

77

8+
"""
9+
The transformation matrices returned from the functions in this file assume
10+
the points on which the transformation will be applied are column vectors.
11+
i.e. the R matrix is structured as
12+
13+
R = [
14+
[Rxx, Rxy, Rxz],
15+
[Ryx, Ryy, Ryz],
16+
[Rzx, Rzy, Rzz],
17+
] # (3, 3)
18+
19+
This matrix can be applied to column vectors by post multiplication
20+
by the points e.g.
21+
22+
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
23+
transformed_points = R * points
24+
25+
To apply the same matrix to points which are row vectors, the R matrix
26+
can be transposed and pre multiplied by the points:
27+
28+
e.g.
29+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
30+
transformed_points = points * R.transpose(1, 0)
31+
"""
32+
33+
834
def quaternion_to_matrix(quaternions):
935
"""
1036
Convert rotations given as quaternions to rotation matrices.
@@ -80,7 +106,7 @@ def matrix_to_quaternion(matrix):
80106
return torch.stack((o0, o1, o2, o3), -1)
81107

82108

83-
def _primary_matrix(axis: str, angle):
109+
def _axis_angle_rotation(axis: str, angle):
84110
"""
85111
Return the rotation matrices for one of the rotations about an axis
86112
of which Euler angles describe, for each value of the angle given.
@@ -92,17 +118,20 @@ def _primary_matrix(axis: str, angle):
92118
Returns:
93119
Rotation matrices as tensor of shape (..., 3, 3).
94120
"""
121+
95122
cos = torch.cos(angle)
96123
sin = torch.sin(angle)
97124
one = torch.ones_like(angle)
98125
zero = torch.zeros_like(angle)
126+
99127
if axis == "X":
100-
o = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
128+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
101129
if axis == "Y":
102-
o = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
130+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
103131
if axis == "Z":
104-
o = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
105-
return torch.stack(o, -1).reshape(angle.shape + (3, 3))
132+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
133+
134+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
106135

107136

108137
def euler_angles_to_matrix(euler_angles, convention: str):
@@ -126,7 +155,9 @@ def euler_angles_to_matrix(euler_angles, convention: str):
126155
for letter in convention:
127156
if letter not in ("X", "Y", "Z"):
128157
raise ValueError(f"Invalid letter {letter} in convention string.")
129-
matrices = map(_primary_matrix, convention, torch.unbind(euler_angles, -1))
158+
matrices = map(
159+
_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)
160+
)
130161
return functools.reduce(torch.matmul, matrices)
131162

132163

pytorch3d/transforms/transform3d.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import warnings
66
import torch
77

8+
from .rotation_conversions import _axis_angle_rotation
9+
810

911
class Transform3d:
1012
"""
@@ -103,12 +105,35 @@ class Transform3d:
103105
s1_params -= lr * s1_params.grad
104106
t_params -= lr * t_params.grad
105107
s2_params -= lr * s2_params.grad
108+
109+
CONVENTIONS
110+
We adopt a right-hand coordinate system, meaning that rotation about an axis
111+
with a positive angle results in a counter clockwise rotation.
112+
113+
This class assumes that transformations are applied on inputs which
114+
are row vectors. The internal representation of the Nx4x4 transformation
115+
matrix is of the form:
116+
117+
.. code-block:: python
118+
119+
M = [
120+
[Rxx, Ryx, Rzx, 0],
121+
[Rxy, Ryy, Rzy, 0],
122+
[Rxz, Ryz, Rzz, 0],
123+
[Tx, Ty, Tz, 1],
124+
]
125+
126+
To apply the transformation to points which are row vectors, the M matrix
127+
can be pre multiplied by the points:
128+
129+
.. code-block:: python
130+
131+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
132+
transformed_points = points * M
133+
106134
"""
107135

108136
def __init__(self, dtype=torch.float32, device="cpu"):
109-
"""
110-
This class assumes a row major ordering for all matrices.
111-
"""
112137
self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
113138
self._transforms = [] # store transforms to compose
114139
self._lu = None
@@ -493,9 +518,12 @@ def __init__(
493518
Create a new Transform3d representing 3D rotation about an axis
494519
by an angle.
495520
521+
Assuming a right-hand coordinate system, positive rotation angles result
522+
in a counter clockwise rotation.
523+
496524
Args:
497525
angle:
498-
- A torch tensor of shape (N, 1)
526+
- A torch tensor of shape (N,)
499527
- A python scalar
500528
- A torch scalar
501529
axis:
@@ -509,21 +537,11 @@ def __init__(
509537
raise ValueError(msg % axis)
510538
angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle")
511539
angle = (angle / 180.0 * math.pi) if degrees else angle
512-
N = angle.shape[0]
513-
514-
cos = torch.cos(angle)
515-
sin = torch.sin(angle)
516-
one = torch.ones_like(angle)
517-
zero = torch.zeros_like(angle)
518-
519-
if axis == "X":
520-
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
521-
if axis == "Y":
522-
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
523-
if axis == "Z":
524-
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
525-
526-
R = torch.stack(R_flat, -1).reshape((N, 3, 3))
540+
# We assume the points on which this transformation will be applied
541+
# are row vectors. The rotation matrix returned from _axis_angle_rotation
542+
# is for transforming column vectors. Therefore we transpose this matrix.
543+
# R will always be of shape (N, 3, 3)
544+
R = _axis_angle_rotation(axis, angle).transpose(1, 2)
527545
super().__init__(device=device, R=R)
528546

529547

@@ -606,19 +624,16 @@ def _handle_input(
606624
def _handle_angle_input(x, dtype, device: str, name: str):
607625
"""
608626
Helper function for building a rotation function using angles.
609-
The output is always of shape (N, 1).
627+
The output is always of shape (N,).
610628
611629
The input can be one of:
612-
- Torch tensor (N, 1) or (N)
630+
- Torch tensor of shape (N,)
613631
- Python scalar
614632
- Torch scalar
615633
"""
616-
# If x is actually a tensor of shape (N, 1) then just return it
617-
if torch.is_tensor(x) and x.dim() == 2:
618-
if x.shape[1] != 1:
619-
msg = "Expected tensor of shape (N, 1); got %r (in %s)"
620-
raise ValueError(msg % (x.shape, name))
621-
return x
634+
if torch.is_tensor(x) and x.dim() > 1:
635+
msg = "Expected tensor of shape (N,); got %r (in %s)"
636+
raise ValueError(msg % (x.shape, name))
622637
else:
623638
return _handle_coord(x, dtype, device)
624639

tests/test_rotation_conversions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99

1010
from pytorch3d.transforms.rotation_conversions import (
11+
_axis_angle_rotation,
1112
euler_angles_to_matrix,
1213
matrix_to_euler_angles,
1314
matrix_to_quaternion,
@@ -118,7 +119,6 @@ def test_from_euler(self):
118119
def test_to_euler(self):
119120
"""mtx -> euler -> mtx"""
120121
data = random_rotations(13, dtype=torch.float64)
121-
122122
for convention in self._all_euler_angle_conventions():
123123
euler_angles = matrix_to_euler_angles(data, convention)
124124
mdata = euler_angles_to_matrix(euler_angles, convention)

0 commit comments

Comments
 (0)