Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 1 addition & 18 deletions configs/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,6 @@
_base_ = ['../rotated_retinanet/rotated_retinanet_hbb_r50_fpn_1x_dota_oc.py']
_base_ = '../rotated_retinanet/rotated_retinanet_hbb_r50_fpn_1x_dota_oc.py'

model = dict(
bbox_head=dict(
reg_decoded_bbox=True,
loss_bbox=dict(type='GDLoss', loss_type='gwd', loss_weight=5.0)))

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RResize', img_scale=(1024, 1024)),
dict(
type='RRandomFlip',
flip_ratio=[0.25, 0.25, 0.25],
direction=['horizontal', 'vertical', 'diagonal']),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
data = dict(train=dict(pipeline=train_pipeline))
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
_base_ = [
'../rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py'
]
_base_ = '../rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py'

model = dict(
bbox_head=dict(
Expand Down
21 changes: 1 addition & 20 deletions configs/kld/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
_base_ = ['../rotated_retinanet/rotated_retinanet_hbb_r50_fpn_1x_dota_oc.py']
_base_ = '../rotated_retinanet/rotated_retinanet_hbb_r50_fpn_1x_dota_oc.py'

angle_version = 'oc'
model = dict(
bbox_head=dict(
reg_decoded_bbox=True,
Expand All @@ -11,21 +10,3 @@
fun='log1p',
tau=1,
loss_weight=1.0)))

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RResize', img_scale=(1024, 1024)),
dict(
type='RRandomFlip',
flip_ratio=[0.25, 0.25, 0.25],
direction=['horizontal', 'vertical', 'diagonal'],
version=angle_version),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
data = dict(train=dict(pipeline=train_pipeline))
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
_base_ = ['../rotated_retinanet/rotated_retinanet_hbb_r50_fpn_1x_dota_oc.py']
_base_ = '../rotated_retinanet/rotated_retinanet_hbb_r50_fpn_1x_dota_oc.py'

angle_version = 'oc'
model = dict(
bbox_head=dict(
reg_decoded_bbox=True,
Expand All @@ -12,21 +11,3 @@
tau=1,
sqrt=False,
loss_weight=5.5)))

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RResize', img_scale=(1024, 1024)),
dict(
type='RRandomFlip',
flip_ratio=[0.25, 0.25, 0.25],
direction=['horizontal', 'vertical', 'diagonal'],
version=angle_version),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
data = dict(train=dict(pipeline=train_pipeline))
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
_base_ = [
'../rotated_retinanet/rotated_retinanet_hbb_r50_fpn_6x_hrsc_rr_oc.py'
]
_base_ = '../rotated_retinanet/rotated_retinanet_hbb_r50_fpn_6x_hrsc_rr_oc.py'

model = dict(
bbox_head=dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
_base_ = [
'../rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py'
]
_base_ = '../rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py'

model = dict(
bbox_head=dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = ['../rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90.py']
_base_ = '../rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90.py'

model = dict(
bbox_head=dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
_base_ = ['./rotated_retinanet_obb_kld_stable_r50_fpn_1x_dota_le90.py']
_base_ = './rotated_retinanet_obb_kld_stable_r50_fpn_1x_dota_le90.py'

optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
optim_wrapper = dict(
optimizer=dict(
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05))
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = ['../rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90.py']
_base_ = '../rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90.py'

model = dict(
bbox_head=dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=32,
boxlist2tensor=False),
boxtype2tensor=False),
backbone=dict(
type='mmdet.ResNet',
depth=50,
Expand Down
61 changes: 35 additions & 26 deletions mmrotate/core/bbox/coder/delta_xywht_hbbox_coder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from typing import Optional, Sequence, Union

import numpy as np
import torch
from mmdet.models.task_modules.coders.base_bbox_coder import BaseBBoxCoder
from mmdet.structures.bbox import HorizontalBoxes
from mmdet.structures.bbox import HorizontalBoxes, get_box_tensor
from torch import Tensor

from mmrotate.core.bbox.structures import RotatedBoxes
Expand Down Expand Up @@ -35,6 +35,8 @@ class DeltaXYWHTHBBoxCoder(BaseBBoxCoder):
Defaults to False.
ctr_clamp (int): the maximum pixel shift to clamp. Only used by
YOLOF. Defaults to 32.
use_box_type (bool): Whether to wrap decoded boxes with the
box type data structure. Defaults to True.
"""
encode_size = 5

Expand All @@ -46,8 +48,9 @@ def __init__(self,
edge_swap: bool = False,
clip_border: bool = True,
add_ctr_clamp: bool = False,
ctr_clamp: int = 32) -> None:
super().__init__()
ctr_clamp: int = 32,
use_box_type=True) -> None:
super().__init__(use_box_type=use_box_type)
self.means = target_means
self.stds = target_stds
self.angle_version = angle_version
Expand Down Expand Up @@ -81,16 +84,18 @@ def encode(self, bboxes: HorizontalBoxes,
else:
raise NotImplementedError

def decode(self,
bboxes: HorizontalBoxes,
pred_bboxes: Tensor,
max_shape: Optional[Sequence[int]] = None,
wh_ratio_clip: float = 16 / 1000) -> RotatedBoxes:
def decode(
self,
bboxes: Union[HorizontalBoxes, Tensor],
pred_bboxes: Tensor,
max_shape: Optional[Sequence[int]] = None,
wh_ratio_clip: float = 16 / 1000) -> Union[RotatedBoxes, Tensor]:
"""Apply transformation `pred_bboxes` to `bboxes`.

Args:
bboxes (:obj:`HorizontalBoxes`): Basic boxes.
Shape (B, N, 4) or (N, 4)
bboxes (:obj:`HorizontalBoxes` or Tensor): Basic boxes.
Shape (B, N, 4) or (N, 4). In two-stage detectors and
refinement-based single-stage detectors, the bboxes can be a Tensor.
pred_bboxes (Tensor): Encoded offsets with respect to each
roi. Has shape (B, N, num_classes * 5) or (B, N, 5) or
(N, num_classes * 5) or (N, 5). Note N = num_anchors * W * H
Expand All @@ -104,20 +109,25 @@ def decode(self,
width and height.

Returns:
:obj:`RotatedBoxes`: Decoded boxes.
Union[:obj:`RotatedBoxes`, Tensor]: Decoded boxes.
"""
assert pred_bboxes.size(0) == bboxes.size(0)
if pred_bboxes.ndim == 3:
assert pred_bboxes.size(1) == bboxes.size(1)
assert bboxes.size(-1) == 4
assert pred_bboxes.size(-1) == 5
if self.angle_version in ['oc', 'le135', 'le90']:
return delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
wh_ratio_clip, self.add_ctr_clamp,
self.ctr_clamp, self.angle_version,
self.norm_factor, self.edge_swap)
else:
raise NotImplementedError
assert self.angle_version in ['oc', 'le135', 'le90']
bboxes = get_box_tensor(bboxes)
decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
wh_ratio_clip, self.add_ctr_clamp,
self.ctr_clamp, self.angle_version,
self.norm_factor, self.edge_swap)
if self.use_box_type:
assert decoded_bboxes.size(-1) == 5, \
('Cannot warp decoded boxes with box type when decoded'
'boxes have shape of (N, num_classes * 5)')
decoded_bboxes = RotatedBoxes(decoded_bboxes)
return decoded_bboxes


def bbox2delta(proposals: HorizontalBoxes,
Expand Down Expand Up @@ -190,7 +200,7 @@ def bbox2delta(proposals: HorizontalBoxes,
return deltas


def delta2bbox(rois: HorizontalBoxes,
def delta2bbox(rois: Tensor,
deltas: Tensor,
means: Sequence[float] = (0., 0., 0., 0., 0.),
stds: Sequence[float] = (1., 1., 1., 1., 1.),
Expand All @@ -199,14 +209,14 @@ def delta2bbox(rois: HorizontalBoxes,
ctr_clamp: int = 32,
angle_version: str = 'oc',
norm_factor: Optional[float] = None,
edge_swap: bool = False) -> RotatedBoxes:
edge_swap: bool = False) -> Tensor:
"""Apply deltas to shift/scale base boxes. Typically the rois are anchor
or proposed bounding boxes and the deltas are network outputs used to
shift/scale those boxes. This is the inverse function of
:func:`bbox2delta`.

Args:
rois (:obj:`HorizontalBoxes`): Boxes to be transformed.
rois (Tensor): Boxes to be transformed.
Has shape (N, 4).
deltas (Tensor): Encoded offsets relative to each roi.
Has shape (N, num_classes * 5) or (N, 5). Note
Expand All @@ -230,14 +240,13 @@ def delta2bbox(rois: HorizontalBoxes,
Defaults to False.

Returns:
:obj:`RotatedBoxes`: Boxes with shape (N, num_classes * 5) or (N, 5),
Tensor: Boxes with shape (N, num_classes * 5) or (N, 5),
where 5 represent cx, cy, w, h, t.
"""
num_bboxes = deltas.size(0)
if num_bboxes == 0:
return RotatedBoxes(deltas)
return deltas

rois = rois.tensor
means = deltas.new_tensor(means).view(1, -1)
stds = deltas.new_tensor(stds).view(1, -1)
delta_shape = deltas.shape
Expand Down Expand Up @@ -291,4 +300,4 @@ def delta2bbox(rois: HorizontalBoxes,
decoded_bbox = torch.stack([gx, gy, gw, gh, gt],
dim=-1).view_as(deltas)

return RotatedBoxes(decoded_bbox)
return decoded_bbox
Loading