Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions mmdet/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2938,11 +2938,13 @@ def __init__(
bbox_occluded_thr: int = 10,
mask_occluded_thr: int = 300,
selected: bool = True,
paste_by_box: bool = False,
) -> None:
self.max_num_pasted = max_num_pasted
self.bbox_occluded_thr = bbox_occluded_thr
self.mask_occluded_thr = mask_occluded_thr
self.selected = selected
self.paste_by_box = paste_by_box

@cache_randomness
def get_indexes(self, dataset: BaseDataset) -> int:
Expand Down Expand Up @@ -2981,11 +2983,28 @@ def _get_selected_inds(self, num_bboxes: int) -> np.ndarray:
num_pasted = np.random.randint(0, max_num_pasted)
return np.random.choice(num_bboxes, size=num_pasted, replace=False)

def get_gt_masks(self, results):
"""Get gt_masks originally or generated based on bboxes.

If gt_masks is not contained in results,
it will be generated based on gt_bboxes.
Args:
results (dict): Result dict.
Returns:
BitmapMasks: gt_masks, originally or generated based on bboxes.
"""
if results.get('gt_masks', None) is not None:
return results['gt_masks']
else:
if not self.paste_by_box:
raise RuntimeError('results does not contain masks.')
return results['gt_bboxes'].create_masks(results['img'].shape[:2])

def _select_object(self, results: dict) -> dict:
"""Select some objects from the source results."""
bboxes = results['gt_bboxes']
labels = results['gt_bboxes_labels']
masks = results['gt_masks']
masks = self.get_gt_masks(results)
ignore_flags = results['gt_ignore_flags']

selected_inds = self._get_selected_inds(bboxes.shape[0])
Expand Down Expand Up @@ -3013,7 +3032,7 @@ def _copy_paste(self, dst_results: dict, src_results: dict) -> dict:
dst_img = dst_results['img']
dst_bboxes = dst_results['gt_bboxes']
dst_labels = dst_results['gt_bboxes_labels']
dst_masks = dst_results['gt_masks']
dst_masks = self.get_gt_masks(dst_results)
dst_ignore_flags = dst_results['gt_ignore_flags']

src_img = src_results['img']
Expand Down
11 changes: 11 additions & 0 deletions mmdet/structures/bbox/base_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,17 @@ def find_inside_points(self,
"""
pass

@abstractmethod
def create_masks(self, img_shape: Tuple[int, int]) -> Tensor:
"""
Args:
img_shape (Tuple[int, int]): A tuple of image height and width.

Returns:
:obj:`BitmapMasks`: Converted masks
"""
pass

@abstractstaticmethod
def overlaps(boxes1: 'BaseBoxes',
boxes2: 'BaseBoxes',
Expand Down
20 changes: 20 additions & 0 deletions mmdet/structures/bbox/horizontal_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,26 @@ def find_inside_points(self,
return (points[..., 0] >= x_min) & (points[..., 0] <= x_max) & \
(points[..., 1] >= y_min) & (points[..., 1] <= y_max)

def create_masks(self, img_shape: Tuple[int, int]) -> BitmapMasks:
"""
Args:
img_shape (Tuple[int, int]): A tuple of image height and width.

Returns:
:obj:`BitmapMasks`: Converted masks
"""
img_h, img_w = img_shape
boxes = self.tensor

xmin, ymin = boxes[:, 0:1], boxes[:, 1:2]
xmax, ymax = boxes[:, 2:3], boxes[:, 3:4]
gt_masks = np.zeros((len(boxes), img_h, img_w), dtype=np.uint8)
for i in range(len(boxes)):
gt_masks[i,
int(ymin[i]):int(ymax[i]),
int(xmin[i]):int(xmax[i])] = 1
return BitmapMasks(gt_masks, img_h, img_w)

@staticmethod
def overlaps(boxes1: BaseBoxes,
boxes2: BaseBoxes,
Expand Down
3 changes: 3 additions & 0 deletions tests/test_structures/test_bbox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def is_inside(self, img_shape):
def find_inside_points(self, points, is_aligned=False):
pass

def create_masks(self, img_shape):
pass

def overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
pass

Expand Down