Skip to content

Commit fa87c7c

Browse files
authored
[Refactor] CSL (#535)
* add UT * fix lint * update * Update test_oriented_reppoints.py * fix UT * init * Update angle_branch_retina_head.py * add ut * Update test_angle_coder.py * add UT * Update test_angle_branch_retina_head.py * fix * fix
1 parent d311e33 commit fa87c7c

File tree

9 files changed

+823
-612
lines changed

9 files changed

+823
-612
lines changed

configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,15 @@
44
angle_version = 'le90'
55
model = dict(
66
bbox_head=dict(
7-
type='CSLRRetinaHead',
7+
type='AngleBranchRetinaHead',
88
angle_coder=dict(
99
type='CSLCoder',
1010
angle_version=angle_version,
1111
omega=4,
1212
window='gaussian',
1313
radius=3),
14-
loss_cls=dict(
15-
type='FocalLoss',
16-
use_sigmoid=True,
14+
loss_angle=dict(
15+
type='mmdet.SmoothFocalLoss',
1716
gamma=2.0,
1817
alpha=0.25,
19-
loss_weight=1.0),
20-
loss_bbox=dict(type='L1Loss', loss_weight=1.0),
21-
loss_angle=dict(
22-
type='SmoothFocalLoss', gamma=2.0, alpha=0.25, loss_weight=0.8)))
18+
loss_weight=0.8)))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
_base_ = './oriented_rcnn_r50_fpn_1x_dota_le90.py'
22

3-
optim_wrapper = dict(type='AmpOptimWrapper')
3+
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
_base_ = ['./rotated_retinanet_obb_r50_fpn_1x_dota_le90.py']
22

3-
fp16 = dict(loss_scale='dynamic')
3+
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')

mmrotate/core/bbox/coder/angle_coder.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
from mmdet.models.task_modules.coders.base_bbox_coder import BaseBBoxCoder
6+
from torch import Tensor
67

78
from mmrotate.registry import TASK_UTILS
89

@@ -35,44 +36,44 @@ def __init__(self, angle_version, omega=1, window='gaussian', radius=6):
3536
self.omega = omega
3637
self.window = window
3738
self.radius = radius
38-
self.coding_len = int(self.angle_range // omega)
39+
self.encoded_size = int(self.angle_range // omega)
3940

40-
def encode(self, angle_targets):
41+
def encode(self, angle_targets: Tensor) -> Tensor:
4142
"""Circular Smooth Label Encoder.
4243
4344
Args:
4445
angle_targets (Tensor): Angle offset for each scale level
4546
Has shape (num_anchors * H * W, 1)
4647
4748
Returns:
48-
list[Tensor]: The csl encoding of angle offset for each
49-
scale level. Has shape (num_anchors * H * W, coding_len)
49+
Tensor: The csl encoding of angle offset for each scale
50+
level. Has shape (num_anchors * H * W, encoded_size)
5051
"""
5152

5253
# radius to degree
5354
angle_targets_deg = angle_targets * (180 / math.pi)
5455
# empty label
5556
smooth_label = torch.zeros_like(angle_targets).repeat(
56-
1, self.coding_len)
57+
1, self.encoded_size)
5758
angle_targets_deg = (angle_targets_deg +
5859
self.angle_offset) / self.omega
5960
# Float to Int
6061
angle_targets_long = angle_targets_deg.long()
6162

6263
if self.window == 'pulse':
63-
radius_range = angle_targets_long % self.coding_len
64+
radius_range = angle_targets_long % self.encoded_size
6465
smooth_value = 1.0
6566
elif self.window == 'rect':
6667
base_radius_range = torch.arange(
6768
-self.radius, self.radius, device=angle_targets_long.device)
6869
radius_range = (base_radius_range +
69-
angle_targets_long) % self.coding_len
70+
angle_targets_long) % self.encoded_size
7071
smooth_value = 1.0
7172
elif self.window == 'triangle':
7273
base_radius_range = torch.arange(
7374
-self.radius, self.radius, device=angle_targets_long.device)
7475
radius_range = (base_radius_range +
75-
angle_targets_long) % self.coding_len
76+
angle_targets_long) % self.encoded_size
7677
smooth_value = 1.0 - torch.abs(
7778
(1 / self.radius) * base_radius_range)
7879

@@ -83,7 +84,7 @@ def encode(self, angle_targets):
8384
device=angle_targets_long.device)
8485

8586
radius_range = (base_radius_range +
86-
angle_targets_long) % self.coding_len
87+
angle_targets_long) % self.encoded_size
8788
smooth_value = torch.exp(-torch.pow(base_radius_range, 2) /
8889
(2 * self.radius**2))
8990

@@ -96,19 +97,24 @@ def encode(self, angle_targets):
9697

9798
return smooth_label.scatter(1, radius_range, smooth_value)
9899

99-
def decode(self, angle_preds):
100+
def decode(self, angle_preds: Tensor, keepdim: bool = False) -> Tensor:
100101
"""Circular Smooth Label Decoder.
101102
102103
Args:
103-
angle_preds (Tensor): The csl encoding of angle offset
104-
for each scale level.
105-
Has shape (num_anchors * H * W, coding_len)
104+
angle_preds (Tensor): The csl encoding of angle offset for each
105+
scale level. Has shape (num_anchors * H * W, encoded_size) or
106+
(B, num_anchors * H * W, encoded_size)
107+
keepdim (bool): Whether the output tensor has dim retained or not.
108+
106109
107110
Returns:
108-
list[Tensor]: Angle offset for each scale level.
109-
Has shape (num_anchors * H * W, 1)
111+
Tensor: Angle offset for each scale level. When keepdim is true,
112+
return (num_anchors * H * W, 1) or (B, num_anchors * H * W, 1),
113+
otherwise (num_anchors * H * W,) or (B, num_anchors * H * W)
110114
"""
111-
angle_cls_inds = torch.argmax(angle_preds, dim=1)
115+
if angle_preds.shape[0] == 0:
116+
return angle_preds.new_zeros((0))
117+
angle_cls_inds = torch.argmax(angle_preds, dim=-1, keepdim=keepdim)
112118
angle_pred = ((angle_cls_inds + 0.5) *
113119
self.omega) % self.angle_range - self.angle_offset
114120
return angle_pred * (math.pi / 180)

mmrotate/models/dense_heads/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .angle_branch_retina_head import AngleBranchRetinaHead
23
from .cfa_head import CFAHead
34
from .csl_rotated_fcos_head import CSLRFCOSHead
4-
from .csl_rotated_retina_head import CSLRRetinaHead
55
from .kfiou_odm_refine_head import KFIoUODMRefineHead
66
from .kfiou_rotate_retina_head import KFIoURRetinaHead
77
from .kfiou_rotate_retina_refine_head import KFIoURRetinaRefineHead
@@ -23,7 +23,7 @@
2323
'RotatedRetinaHead', 'RotatedRPNHead', 'OrientedRPNHead',
2424
'RotatedRetinaRefineHead', 'ODMRefineHead', 'KFIoURRetinaHead',
2525
'KFIoURRetinaRefineHead', 'KFIoUODMRefineHead', 'RotatedRepPointsHead',
26-
'SAMRepPointsHead', 'CSLRRetinaHead', 'RotatedATSSHead',
26+
'SAMRepPointsHead', 'AngleBranchRetinaHead', 'RotatedATSSHead',
2727
'RotatedAnchorFreeHead', 'RotatedFCOSHead', 'CSLRFCOSHead',
2828
'OrientedRepPointsHead', 'R3Head', 'R3RefineHead', 'S2AHead',
2929
'S2ARefineHead', 'CFAHead'

0 commit comments

Comments
 (0)