5
5
import warnings
6
6
import torch
7
7
8
+ from .rotation_conversions import _axis_angle_rotation
9
+
8
10
9
11
class Transform3d :
10
12
"""
@@ -103,12 +105,35 @@ class Transform3d:
103
105
s1_params -= lr * s1_params.grad
104
106
t_params -= lr * t_params.grad
105
107
s2_params -= lr * s2_params.grad
108
+
109
+ CONVENTIONS
110
+ We adopt a right-hand coordinate system, meaning that rotation about an axis
111
+ with a positive angle results in a counter clockwise rotation.
112
+
113
+ This class assumes that transformations are applied on inputs which
114
+ are row vectors. The internal representation of the Nx4x4 transformation
115
+ matrix is of the form:
116
+
117
+ .. code-block:: python
118
+
119
+ M = [
120
+ [Rxx, Ryx, Rzx, 0],
121
+ [Rxy, Ryy, Rzy, 0],
122
+ [Rxz, Ryz, Rzz, 0],
123
+ [Tx, Ty, Tz, 1],
124
+ ]
125
+
126
+ To apply the transformation to points which are row vectors, the M matrix
127
+ can be pre multiplied by the points:
128
+
129
+ .. code-block:: python
130
+
131
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
132
+ transformed_points = points * M
133
+
106
134
"""
107
135
108
136
def __init__ (self , dtype = torch .float32 , device = "cpu" ):
109
- """
110
- This class assumes a row major ordering for all matrices.
111
- """
112
137
self ._matrix = torch .eye (4 , dtype = dtype , device = device ).view (1 , 4 , 4 )
113
138
self ._transforms = [] # store transforms to compose
114
139
self ._lu = None
@@ -493,9 +518,12 @@ def __init__(
493
518
Create a new Transform3d representing 3D rotation about an axis
494
519
by an angle.
495
520
521
+ Assuming a right-hand coordinate system, positive rotation angles result
522
+ in a counter clockwise rotation.
523
+
496
524
Args:
497
525
angle:
498
- - A torch tensor of shape (N, 1 )
526
+ - A torch tensor of shape (N,)
499
527
- A python scalar
500
528
- A torch scalar
501
529
axis:
@@ -509,21 +537,11 @@ def __init__(
509
537
raise ValueError (msg % axis )
510
538
angle = _handle_angle_input (angle , dtype , device , "RotateAxisAngle" )
511
539
angle = (angle / 180.0 * math .pi ) if degrees else angle
512
- N = angle .shape [0 ]
513
-
514
- cos = torch .cos (angle )
515
- sin = torch .sin (angle )
516
- one = torch .ones_like (angle )
517
- zero = torch .zeros_like (angle )
518
-
519
- if axis == "X" :
520
- R_flat = (one , zero , zero , zero , cos , - sin , zero , sin , cos )
521
- if axis == "Y" :
522
- R_flat = (cos , zero , sin , zero , one , zero , - sin , zero , cos )
523
- if axis == "Z" :
524
- R_flat = (cos , - sin , zero , sin , cos , zero , zero , zero , one )
525
-
526
- R = torch .stack (R_flat , - 1 ).reshape ((N , 3 , 3 ))
540
+ # We assume the points on which this transformation will be applied
541
+ # are row vectors. The rotation matrix returned from _axis_angle_rotation
542
+ # is for transforming column vectors. Therefore we transpose this matrix.
543
+ # R will always be of shape (N, 3, 3)
544
+ R = _axis_angle_rotation (axis , angle ).transpose (1 , 2 )
527
545
super ().__init__ (device = device , R = R )
528
546
529
547
@@ -606,19 +624,16 @@ def _handle_input(
606
624
def _handle_angle_input (x , dtype , device : str , name : str ):
607
625
"""
608
626
Helper function for building a rotation function using angles.
609
- The output is always of shape (N, 1 ).
627
+ The output is always of shape (N,).
610
628
611
629
The input can be one of:
612
- - Torch tensor (N, 1) or (N )
630
+ - Torch tensor of shape (N, )
613
631
- Python scalar
614
632
- Torch scalar
615
633
"""
616
- # If x is actually a tensor of shape (N, 1) then just return it
617
- if torch .is_tensor (x ) and x .dim () == 2 :
618
- if x .shape [1 ] != 1 :
619
- msg = "Expected tensor of shape (N, 1); got %r (in %s)"
620
- raise ValueError (msg % (x .shape , name ))
621
- return x
634
+ if torch .is_tensor (x ) and x .dim () > 1 :
635
+ msg = "Expected tensor of shape (N,); got %r (in %s)"
636
+ raise ValueError (msg % (x .shape , name ))
622
637
else :
623
638
return _handle_coord (x , dtype , device )
624
639
0 commit comments