Skip to content

Commit 9eb4831

Browse files
ppwwyyxx authored and facebook-github-bot committed
configurable for dataset_mapper
Reviewed By: rbgirshick
Differential Revision: D22251484
fbshipit-source-id: 2d2cfb99f40e10b7af4e87a99bea14b3ee98a48c
1 parent 45a4d97 commit 9eb4831

File tree

2 files changed

+78
-38
lines changed

2 files changed

+78
-38
lines changed

detectron2/data/dataset_mapper.py

Lines changed: 78 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import copy
33
import logging
44
import numpy as np
5+
from typing import List, Optional, Union
56
import torch
67

8+
from detectron2.config import configurable
9+
710
from . import detection_utils as utils
811
from . import transforms as T
912

@@ -31,38 +34,81 @@ class DatasetMapper:
3134
3. Prepare data and annotations to Tensor and :class:`Instances`
3235
"""
3336

34-
def __init__(self, cfg, is_train=True):
35-
self.augmentation = utils.build_augmentation(cfg, is_train)
36-
if cfg.INPUT.CROP.ENABLED and is_train:
37-
self.augmentation.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
38-
logging.getLogger(__name__).info(
39-
"Cropping used in training: " + str(self.augmentation[0])
40-
)
41-
self.compute_tight_boxes = True
42-
else:
43-
self.compute_tight_boxes = False
37+
@configurable
38+
def __init__(
39+
self,
40+
is_train: bool,
41+
*,
42+
augmentations: List[Union[T.Augmentation, T.Transform]],
43+
image_format: str,
44+
use_instance_mask: bool = False,
45+
use_keypoint: bool = False,
46+
instance_mask_format: str = "polygon",
47+
keypoint_hflip_indices: Optional[np.ndarray] = None,
48+
precomputed_proposal_topk: Optional[int] = None,
49+
recompute_boxes: bool = False
50+
):
51+
"""
52+
NOTE: this interface is experimental.
4453
54+
Args:
55+
is_train: whether it's used in training or inference
56+
augmentations: a list of augmentations or deterministic transforms to apply
57+
image_format: an image format supported by :func:`detection_utils.read_image`.
58+
use_instance_mask: whether to process instance segmentation annotations, if available
59+
use_keypoint: whether to process keypoint annotations if available
60+
instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
61+
masks into this format.
62+
keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
63+
precomputed_proposal_topk: if given, will load pre-computed
64+
proposals from dataset_dict and keep the top k proposals for each image.
65+
recompute_boxes: whether to overwrite bounding box annotations
66+
by computing tight bounding boxes from instance mask annotations.
67+
"""
68+
if recompute_boxes:
69+
assert use_instance_mask, "recompute_boxes requires instance masks"
4570
# fmt: off
46-
self.img_format = cfg.INPUT.FORMAT
47-
self.mask_on = cfg.MODEL.MASK_ON
48-
self.mask_format = cfg.INPUT.MASK_FORMAT
49-
self.keypoint_on = cfg.MODEL.KEYPOINT_ON
50-
self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
71+
self.is_train = is_train
72+
self.augmentations = augmentations
73+
self.image_format = image_format
74+
self.use_instance_mask = use_instance_mask
75+
self.instance_mask_format = instance_mask_format
76+
self.use_keypoint = use_keypoint
77+
self.keypoint_hflip_indices = keypoint_hflip_indices
78+
self.proposal_topk = precomputed_proposal_topk
79+
self.recompute_boxes = recompute_boxes
5180
# fmt: on
52-
if self.keypoint_on and is_train:
53-
# Flip only makes sense in training
54-
self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
55-
else:
56-
self.keypoint_hflip_indices = None
81+
logger = logging.getLogger(__name__)
82+
logger.info("Augmentations used in training: " + str(augmentations))
5783

58-
if self.load_proposals:
59-
self.proposal_min_box_size = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
60-
self.proposal_topk = (
84+
@classmethod
85+
def from_config(cls, cfg, is_train: bool = True):
86+
augs = utils.build_augmentation(cfg, is_train)
87+
if cfg.INPUT.CROP.ENABLED and is_train:
88+
augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
89+
recompute_boxes = cfg.MODEL.MASK_ON
90+
else:
91+
recompute_boxes = False
92+
93+
ret = {
94+
"is_train": is_train,
95+
"augmentations": augs,
96+
"image_format": cfg.INPUT.FORMAT,
97+
"use_instance_mask": cfg.MODEL.MASK_ON,
98+
"instance_mask_format": cfg.INPUT.MASK_FORMAT,
99+
"use_keypoint": cfg.MODEL.KEYPOINT_ON,
100+
"recompute_boxes": recompute_boxes,
101+
}
102+
if cfg.MODEL.KEYPOINT_ON:
103+
ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
104+
105+
if cfg.MODEL.LOAD_PROPOSALS:
106+
ret["precomputed_proposal_topk"] = (
61107
cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
62108
if is_train
63109
else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
64110
)
65-
self.is_train = is_train
111+
return ret
66112

67113
def __call__(self, dataset_dict):
68114
"""
@@ -74,7 +120,7 @@ def __call__(self, dataset_dict):
74120
"""
75121
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
76122
# USER: Write your own image loading if it's not from a file
77-
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
123+
image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
78124
utils.check_image_size(dataset_dict, image)
79125

80126
# USER: Remove if you don't do semantic/panoptic segmentation.
@@ -84,7 +130,7 @@ def __call__(self, dataset_dict):
84130
sem_seg_gt = None
85131

86132
aug_input = T.StandardAugInput(image, sem_seg=sem_seg_gt)
87-
transforms = aug_input.apply_augmentations(self.augmentation)
133+
transforms = aug_input.apply_augmentations(self.augmentations)
88134
image, sem_seg_gt = aug_input.image, aug_input.sem_seg
89135

90136
image_shape = image.shape[:2] # h, w
@@ -97,13 +143,9 @@ def __call__(self, dataset_dict):
97143

98144
# USER: Remove if you don't use pre-computed proposals.
99145
# Most users would not need this feature.
100-
if self.load_proposals:
146+
if self.proposal_topk is not None:
101147
utils.transform_proposals(
102-
dataset_dict,
103-
image_shape,
104-
transforms,
105-
proposal_topk=self.proposal_topk,
106-
min_box_size=self.proposal_min_box_size,
148+
dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
107149
)
108150

109151
if not self.is_train:
@@ -115,9 +157,9 @@ def __call__(self, dataset_dict):
115157
if "annotations" in dataset_dict:
116158
# USER: Modify this if you want to keep them for some reason.
117159
for anno in dataset_dict["annotations"]:
118-
if not self.mask_on:
160+
if not self.use_instance_mask:
119161
anno.pop("segmentation", None)
120-
if not self.keypoint_on:
162+
if not self.use_keypoint:
121163
anno.pop("keypoints", None)
122164

123165
# USER: Implement additional transformations if you have other types of data
@@ -129,15 +171,15 @@ def __call__(self, dataset_dict):
129171
if obj.get("iscrowd", 0) == 0
130172
]
131173
instances = utils.annotations_to_instances(
132-
annos, image_shape, mask_format=self.mask_format
174+
annos, image_shape, mask_format=self.instance_mask_format
133175
)
134176

135177
# After transforms such as cropping are applied, the bounding box may no longer
136178
# tightly bound the object. As an example, imagine a triangle object
137179
# [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
138180
# bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
139181
# the intersection of original bounding box and the cropping box.
140-
if self.compute_tight_boxes and instances.has("gt_masks"):
182+
if self.recompute_boxes:
141183
instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
142184
dataset_dict["instances"] = utils.filter_empty_instances(instances)
143185
return dataset_dict

detectron2/data/detection_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,12 +579,10 @@ def build_augmentation(cfg, is_train):
579579
len(min_size)
580580
)
581581

582-
logger = logging.getLogger(__name__)
583582
augmentation = []
584583
augmentation.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
585584
if is_train:
586585
augmentation.append(T.RandomFlip())
587-
logger.info("Augmentations used in training: " + str(augmentation))
588586
return augmentation
589587

590588

0 commit comments

Comments
 (0)