Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions mmdet/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import inspect
import math
import warnings
from typing import List, Optional, Sequence, Tuple, Union

import cv2
Expand Down Expand Up @@ -2930,6 +2931,9 @@ class CopyPaste(BaseTransform):
all objects of the source image will be pasted to the
destination image.
Defaults to True.
paste_by_box (bool): Whether to use boxes as masks when masks are
not available.
Defaults to False.
"""

def __init__(
Expand All @@ -2938,11 +2942,13 @@ def __init__(
bbox_occluded_thr: int = 10,
mask_occluded_thr: int = 300,
selected: bool = True,
paste_by_box: bool = False,
) -> None:
self.max_num_pasted = max_num_pasted
self.bbox_occluded_thr = bbox_occluded_thr
self.mask_occluded_thr = mask_occluded_thr
self.selected = selected
self.paste_by_box = paste_by_box

@cache_randomness
def get_indexes(self, dataset: BaseDataset) -> int:
Expand Down Expand Up @@ -2981,11 +2987,31 @@ def _get_selected_inds(self, num_bboxes: int) -> np.ndarray:
num_pasted = np.random.randint(0, max_num_pasted)
return np.random.choice(num_bboxes, size=num_pasted, replace=False)

def get_gt_masks(self, results: dict) -> BitmapMasks:
    """Return the instance masks that should be used for pasting.

    Masks stored under ``results['gt_masks']`` always take precedence.
    When that entry is missing or ``None`` and ``self.paste_by_box`` is
    enabled, masks are built from ``results['gt_bboxes']`` through its
    ``create_masks`` method, using the first two dimensions of
    ``results['img'].shape`` as the mask size.

    Args:
        results (dict): Result dict.

    Returns:
        BitmapMasks: gt_masks, originally or generated based on bboxes.

    Raises:
        RuntimeError: If ``results`` holds no masks and ``paste_by_box``
            is disabled.
    """
    existing_masks = results.get('gt_masks', None)
    use_boxes = self.paste_by_box

    if existing_masks is None:
        # No real masks: either fall back to box-shaped masks or fail.
        if use_boxes:
            bboxes = results['gt_bboxes']
            return bboxes.create_masks(results['img'].shape[:2])
        raise RuntimeError('results does not contain masks.')

    # Real masks win; tell the user their paste_by_box request is ignored.
    if use_boxes:
        warnings.warn('gt_masks is already contained in results, '
                      'so paste_by_box is disabled.')
    return results['gt_masks']

def _select_object(self, results: dict) -> dict:
"""Select some objects from the source results."""
bboxes = results['gt_bboxes']
labels = results['gt_bboxes_labels']
masks = results['gt_masks']
masks = self.get_gt_masks(results)
ignore_flags = results['gt_ignore_flags']

selected_inds = self._get_selected_inds(bboxes.shape[0])
Expand Down Expand Up @@ -3013,7 +3039,7 @@ def _copy_paste(self, dst_results: dict, src_results: dict) -> dict:
dst_img = dst_results['img']
dst_bboxes = dst_results['gt_bboxes']
dst_labels = dst_results['gt_bboxes_labels']
dst_masks = dst_results['gt_masks']
dst_masks = self.get_gt_masks(dst_results)
dst_ignore_flags = dst_results['gt_ignore_flags']

src_img = src_results['img']
Expand Down Expand Up @@ -3071,7 +3097,8 @@ def __repr__(self):
repr_str += f'(max_num_pasted={self.max_num_pasted}, '
repr_str += f'bbox_occluded_thr={self.bbox_occluded_thr}, '
repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
repr_str += f'selected={self.selected})'
repr_str += f'selected={self.selected}), '
repr_str += f'paste_by_box={self.paste_by_box})'
return repr_str


Expand Down
20 changes: 20 additions & 0 deletions mmdet/structures/bbox/horizontal_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,26 @@ def find_inside_points(self,
return (points[..., 0] >= x_min) & (points[..., 0] <= x_max) & \
(points[..., 1] >= y_min) & (points[..., 1] <= y_max)

def create_masks(self, img_shape: Tuple[int, int]) -> BitmapMasks:
    """Convert the boxes into rectangular bitmap masks.

    One ``(img_h, img_w)`` mask is produced per box; pixels inside the
    box are set to 1 and all others stay 0. Box coordinates are
    truncated to integers and clamped to be non-negative before being
    used as slice bounds, so a box that extends past the top or left
    image border is cropped to the image instead of being read as a
    from-the-end index.

    Args:
        img_shape (Tuple[int, int]): A tuple of image height and width.

    Returns:
        :obj:`BitmapMasks`: Converted masks
    """
    img_h, img_w = img_shape
    boxes = self.tensor

    gt_masks = np.zeros((len(boxes), img_h, img_w), dtype=np.uint8)
    for mask, box in zip(gt_masks, boxes):
        x_min, y_min, x_max, y_max = (int(coord) for coord in box[:4])
        # Negative slice bounds count from the end of the axis, which
        # would yield an empty or misplaced rectangle; clamp them to 0.
        # Bounds beyond the image size are already truncated by slicing.
        top, bottom = max(y_min, 0), max(y_max, 0)
        left, right = max(x_min, 0), max(x_max, 0)
        # ``mask`` is a view into ``gt_masks``, so this writes in place.
        mask[top:bottom, left:right] = 1
    return BitmapMasks(gt_masks, img_h, img_w)

@staticmethod
def overlaps(boxes1: BaseBoxes,
boxes2: BaseBoxes,
Expand Down
23 changes: 22 additions & 1 deletion tests/test_datasets/test_transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,26 @@ def test_transform(self):
}]
results = transform(results)

# test copypaste with an empty mask results
transform = CopyPaste()
results = copy.deepcopy(self.dst_results)
results = {k: v for k, v in results.items() if 'mask' not in k}
results['mix_results'] = [copy.deepcopy(self.src_results)]
with self.assertRaises(RuntimeError):
results = transform(results)

# test copypaste with boxes as masks
transform = CopyPaste(paste_by_box=True)
results = copy.deepcopy(self.dst_results)
results = {k: v for k, v in results.items() if 'mask' not in k}
src_results = copy.deepcopy(self.src_results)
src_results = {k: v for k, v in src_results.items() if 'mask' not in k}
results['mix_results'] = [src_results]
results = transform(results)

self.assertEqual(results['img'].shape[:2],
self.dst_results['img'].shape[:2])

def test_transform_use_box_type(self):
src_results = copy.deepcopy(self.src_results)
src_results['gt_bboxes'] = HorizontalBoxes(src_results['gt_bboxes'])
Expand Down Expand Up @@ -1524,7 +1544,8 @@ def test_repr(self):
repr(transform), ('CopyPaste(max_num_pasted=100, '
'bbox_occluded_thr=10, '
'mask_occluded_thr=300, '
'selected=True)'))
'selected=True), '
'paste_by_box=False)'))


class TestAlbu(unittest.TestCase):
Expand Down