Skip to content

Commit a237bf2

Browse files
committed
fix
1 parent 8135b86 commit a237bf2

File tree

4 files changed

+30
-19
lines changed

4 files changed

+30
-19
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
_base_ = './oriented_rcnn_r50_fpn_1x_dota_le90.py'
22

3-
optim_wrapper = dict(type='AmpOptimWrapper')
3+
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
_base_ = ['./rotated_retinanet_obb_r50_fpn_1x_dota_le90.py']
22

3-
fp16 = dict(loss_scale='dynamic')
3+
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')

mmrotate/core/bbox/coder/angle_coder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def decode(self, angle_preds: Tensor) -> Tensor:
109109
Tensor: Angle offset for each scale level.
110110
Has shape (num_anchors * H * W, 1)
111111
"""
112+
if angle_preds.shape[0] == 0:
113+
return angle_preds.new_zeros((0))
112114
angle_cls_inds = torch.argmax(angle_preds, dim=1)
113115
angle_pred = ((angle_cls_inds + 0.5) *
114116
self.omega) % self.angle_range - self.angle_offset

mmrotate/models/dense_heads/angle_branch_retina_head.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
9696
the channels number is num_anchors * num_classes.
9797
- bbox_pred (Tensor): Box energies / deltas for a single scale
9898
level, the channels number is num_anchors * 5.
99-
- angle_cls (Tensor): Angle for a single scale level the channels
99+
- angle_pred (Tensor): Angle for a single scale level the channels
100100
number is num_anchors * coding_len.
101101
"""
102102
cls_feat = x
@@ -107,8 +107,8 @@ def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
107107
reg_feat = reg_conv(reg_feat)
108108
cls_score = self.retina_cls(cls_feat)
109109
bbox_pred = self.retina_reg(reg_feat)
110-
angle_cls = self.retina_angle_cls(reg_feat)
111-
return cls_score, bbox_pred, angle_cls
110+
angle_pred = self.retina_angle_cls(reg_feat)
111+
return cls_score, bbox_pred, angle_pred
112112

113113
def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
114114
angle_pred: Tensor, anchors: Tensor,
@@ -145,6 +145,11 @@ def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
145145
Returns:
146146
tuple: loss components.
147147
"""
148+
# Equivalent substitution of ``@force_fp32()``
149+
cls_score = cls_score.float()
150+
bbox_pred = bbox_pred.float()
151+
angle_pred = angle_pred.float()
152+
148153
# classification loss
149154
labels = labels.reshape(-1)
150155
label_weights = label_weights.reshape(-1)
@@ -171,7 +176,6 @@ def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
171176
bbox_pred = get_box_tensor(bbox_pred)
172177
loss_bbox = self.loss_bbox(
173178
bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor)
174-
175179
angle_pred = angle_pred.permute(0, 2, 3,
176180
1).reshape(-1, self.coding_len)
177181
angle_targets = angle_targets.reshape(-1, self.coding_len)
@@ -549,7 +553,6 @@ def _predict_by_feat_single(self,
549553
mlvl_valid_priors = []
550554
mlvl_scores = []
551555
mlvl_labels = []
552-
mlvl_angle_preds = []
553556
if with_score_factors:
554557
mlvl_score_factors = []
555558
else:
@@ -558,6 +561,11 @@ def _predict_by_feat_single(self,
558561
enumerate(zip(cls_score_list, bbox_pred_list, angle_pred_list,
559562
score_factor_list, mlvl_priors)):
560563

564+
# Equivalent substitution of ``@force_fp32()``
565+
cls_score = cls_score.float()
566+
bbox_pred = bbox_pred.float()
567+
angle_pred = angle_pred.float()
568+
561569
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
562570

563571
dim = self.bbox_coder.encode_size
@@ -596,27 +604,28 @@ def _predict_by_feat_single(self,
596604
if with_score_factors:
597605
score_factor = score_factor[keep_idxs]
598606

607+
# Angle decoder
608+
angle_pred = self.angle_coder.decode(angle_pred)
609+
610+
if self.use_encoded_angle:
611+
bbox_pred[..., -1] = angle_pred
612+
bbox_pred = self.bbox_coder.decode(
613+
priors, bbox_pred, max_shape=img_shape)
614+
else:
615+
bbox_pred = self.bbox_coder.decode(
616+
priors, bbox_pred, max_shape=img_shape)
617+
bbox_pred[..., -1] = angle_pred
618+
599619
mlvl_bbox_preds.append(bbox_pred)
600620
mlvl_valid_priors.append(priors)
601621
mlvl_scores.append(scores)
602622
mlvl_labels.append(labels)
603-
mlvl_angle_preds.append(angle_pred)
604623

605624
if with_score_factors:
606625
mlvl_score_factors.append(score_factor)
607626

608-
bbox_pred = torch.cat(mlvl_bbox_preds)
627+
bboxes = cat_boxes(mlvl_bbox_preds)
609628
priors = cat_boxes(mlvl_valid_priors)
610-
angle_pred = torch.cat(mlvl_angle_preds)
611-
612-
if self.use_encoded_angle:
613-
bbox_pred[..., -1] = angle_pred
614-
bboxes = self.bbox_coder.decode(
615-
priors, bbox_pred, max_shape=img_shape)
616-
else:
617-
bboxes = self.bbox_coder.decode(
618-
priors, bbox_pred, max_shape=img_shape)
619-
bboxes[..., -1] = angle_pred
620629

621630
results = InstanceData()
622631
results.bboxes = bboxes

0 commit comments

Comments
 (0)