33
44import torch
55from mmdet .models .task_modules .coders .base_bbox_coder import BaseBBoxCoder
6+ from torch import Tensor
67
78from mmrotate .registry import TASK_UTILS
89
@@ -35,44 +36,44 @@ def __init__(self, angle_version, omega=1, window='gaussian', radius=6):
3536 self .omega = omega
3637 self .window = window
3738 self .radius = radius
38- self .coding_len = int (self .angle_range // omega )
39+ self .encoded_size = int (self .angle_range // omega )
3940
40- def encode (self , angle_targets ) :
41+ def encode (self , angle_targets : Tensor ) -> Tensor :
4142 """Circular Smooth Label Encoder.
4243
4344 Args:
4445 angle_targets (Tensor): Angle offset for each scale level
4546 Has shape (num_anchors * H * W, 1)
4647
4748 Returns:
48- list[ Tensor] : The csl encoding of angle offset for each
49- scale level. Has shape (num_anchors * H * W, coding_len )
49+ Tensor: The csl encoding of angle offset for each scale
50+ level. Has shape (num_anchors * H * W, encoded_size )
5051 """
5152
5253 # radius to degree
5354 angle_targets_deg = angle_targets * (180 / math .pi )
5455 # empty label
5556 smooth_label = torch .zeros_like (angle_targets ).repeat (
56- 1 , self .coding_len )
57+ 1 , self .encoded_size )
5758 angle_targets_deg = (angle_targets_deg +
5859 self .angle_offset ) / self .omega
5960 # Float to Int
6061 angle_targets_long = angle_targets_deg .long ()
6162
6263 if self .window == 'pulse' :
63- radius_range = angle_targets_long % self .coding_len
64+ radius_range = angle_targets_long % self .encoded_size
6465 smooth_value = 1.0
6566 elif self .window == 'rect' :
6667 base_radius_range = torch .arange (
6768 - self .radius , self .radius , device = angle_targets_long .device )
6869 radius_range = (base_radius_range +
69- angle_targets_long ) % self .coding_len
70+ angle_targets_long ) % self .encoded_size
7071 smooth_value = 1.0
7172 elif self .window == 'triangle' :
7273 base_radius_range = torch .arange (
7374 - self .radius , self .radius , device = angle_targets_long .device )
7475 radius_range = (base_radius_range +
75- angle_targets_long ) % self .coding_len
76+ angle_targets_long ) % self .encoded_size
7677 smooth_value = 1.0 - torch .abs (
7778 (1 / self .radius ) * base_radius_range )
7879
@@ -83,7 +84,7 @@ def encode(self, angle_targets):
8384 device = angle_targets_long .device )
8485
8586 radius_range = (base_radius_range +
86- angle_targets_long ) % self .coding_len
87+ angle_targets_long ) % self .encoded_size
8788 smooth_value = torch .exp (- torch .pow (base_radius_range , 2 ) /
8889 (2 * self .radius ** 2 ))
8990
@@ -96,19 +97,24 @@ def encode(self, angle_targets):
9697
9798 return smooth_label .scatter (1 , radius_range , smooth_value )
9899
99- def decode (self , angle_preds ) :
100+ def decode (self , angle_preds : Tensor , keepdim : bool = False ) -> Tensor :
100101 """Circular Smooth Label Decoder.
101102
102103 Args:
103- angle_preds (Tensor): The csl encoding of angle offset
104- for each scale level.
105- Has shape (num_anchors * H * W, coding_len)
104+ angle_preds (Tensor): The csl encoding of angle offset for each
105+ scale level. Has shape (num_anchors * H * W, encoded_size) or
106+ (B, num_anchors * H * W, encoded_size)
107+ keepdim (bool): Whether the output tensor has dim retained or not.
108+
106109
107110 Returns:
108- list[Tensor]: Angle offset for each scale level.
109- Has shape (num_anchors * H * W, 1)
111+ Tensor: Angle offset for each scale level. When keepdim is true,
112+ return (num_anchors * H * W, 1) or (B, num_anchors * H * W, 1),
113+ otherwise (num_anchors * H * W,) or (B, num_anchors * H * W)
110114 """
111- angle_cls_inds = torch .argmax (angle_preds , dim = 1 )
115+ if angle_preds .shape [0 ] == 0 :
116+ return angle_preds .new_zeros ((0 ))
117+ angle_cls_inds = torch .argmax (angle_preds , dim = - 1 , keepdim = keepdim )
112118 angle_pred = ((angle_cls_inds + 0.5 ) *
113119 self .omega ) % self .angle_range - self .angle_offset
114120 return angle_pred * (math .pi / 180 )
0 commit comments