@@ -96,7 +96,7 @@ def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
9696 the channels number is num_anchors * num_classes.
9797 - bbox_pred (Tensor): Box energies / deltas for a single scale
9898 level, the channels number is num_anchors * 5.
99- - angle_cls (Tensor): Angle for a single scale level the channels
99+ - angle_pred (Tensor): Angle for a single scale level the channels
100100 number is num_anchors * coding_len.
101101 """
102102 cls_feat = x
@@ -107,8 +107,8 @@ def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
107107 reg_feat = reg_conv (reg_feat )
108108 cls_score = self .retina_cls (cls_feat )
109109 bbox_pred = self .retina_reg (reg_feat )
110- angle_cls = self .retina_angle_cls (reg_feat )
111- return cls_score , bbox_pred , angle_cls
110+ angle_pred = self .retina_angle_cls (reg_feat )
111+ return cls_score , bbox_pred , angle_pred
112112
113113 def loss_by_feat_single (self , cls_score : Tensor , bbox_pred : Tensor ,
114114 angle_pred : Tensor , anchors : Tensor ,
@@ -145,6 +145,11 @@ def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
145145 Returns:
146146 tuple: loss components.
147147 """
148+ # Equivalent substitution of ``@force_fp32()``
149+ cls_score = cls_score .float ()
150+ bbox_pred = bbox_pred .float ()
151+ angle_pred = angle_pred .float ()
152+
148153 # classification loss
149154 labels = labels .reshape (- 1 )
150155 label_weights = label_weights .reshape (- 1 )
@@ -171,7 +176,6 @@ def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
171176 bbox_pred = get_box_tensor (bbox_pred )
172177 loss_bbox = self .loss_bbox (
173178 bbox_pred , bbox_targets , bbox_weights , avg_factor = avg_factor )
174-
175179 angle_pred = angle_pred .permute (0 , 2 , 3 ,
176180 1 ).reshape (- 1 , self .coding_len )
177181 angle_targets = angle_targets .reshape (- 1 , self .coding_len )
@@ -549,7 +553,6 @@ def _predict_by_feat_single(self,
549553 mlvl_valid_priors = []
550554 mlvl_scores = []
551555 mlvl_labels = []
552- mlvl_angle_preds = []
553556 if with_score_factors :
554557 mlvl_score_factors = []
555558 else :
@@ -558,6 +561,11 @@ def _predict_by_feat_single(self,
558561 enumerate (zip (cls_score_list , bbox_pred_list , angle_pred_list ,
559562 score_factor_list , mlvl_priors )):
560563
564+ # Equivalent substitution of ``@force_fp32()``
565+ cls_score = cls_score .float ()
566+ bbox_pred = bbox_pred .float ()
567+ angle_pred = angle_pred .float ()
568+
561569 assert cls_score .size ()[- 2 :] == bbox_pred .size ()[- 2 :]
562570
563571 dim = self .bbox_coder .encode_size
@@ -596,27 +604,28 @@ def _predict_by_feat_single(self,
596604 if with_score_factors :
597605 score_factor = score_factor [keep_idxs ]
598606
607+ # Angle decoder
608+ angle_pred = self .angle_coder .decode (angle_pred )
609+
610+ if self .use_encoded_angle :
611+ bbox_pred [..., - 1 ] = angle_pred
612+ bbox_pred = self .bbox_coder .decode (
613+ priors , bbox_pred , max_shape = img_shape )
614+ else :
615+ bbox_pred = self .bbox_coder .decode (
616+ priors , bbox_pred , max_shape = img_shape )
617+ bbox_pred [..., - 1 ] = angle_pred
618+
599619 mlvl_bbox_preds .append (bbox_pred )
600620 mlvl_valid_priors .append (priors )
601621 mlvl_scores .append (scores )
602622 mlvl_labels .append (labels )
603- mlvl_angle_preds .append (angle_pred )
604623
605624 if with_score_factors :
606625 mlvl_score_factors .append (score_factor )
607626
608- bbox_pred = torch . cat (mlvl_bbox_preds )
627+ bboxes = cat_boxes (mlvl_bbox_preds )
609628 priors = cat_boxes (mlvl_valid_priors )
610- angle_pred = torch .cat (mlvl_angle_preds )
611-
612- if self .use_encoded_angle :
613- bbox_pred [..., - 1 ] = angle_pred
614- bboxes = self .bbox_coder .decode (
615- priors , bbox_pred , max_shape = img_shape )
616- else :
617- bboxes = self .bbox_coder .decode (
618- priors , bbox_pred , max_shape = img_shape )
619- bboxes [..., - 1 ] = angle_pred
620629
621630 results = InstanceData ()
622631 results .bboxes = bboxes
0 commit comments