 import copy
 import logging
 import numpy as np
+from typing import List, Optional, Union
 import torch
 
+from detectron2.config import configurable
+
 from . import detection_utils as utils
 from . import transforms as T
 
@@ -31,38 +34,81 @@ class DatasetMapper:
     3. Prepare data and annotations to Tensor and :class:`Instances`
     """
 
-    def __init__(self, cfg, is_train=True):
-        self.augmentation = utils.build_augmentation(cfg, is_train)
-        if cfg.INPUT.CROP.ENABLED and is_train:
-            self.augmentation.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
-            logging.getLogger(__name__).info(
-                "Cropping used in training: " + str(self.augmentation[0])
-            )
-            self.compute_tight_boxes = True
-        else:
-            self.compute_tight_boxes = False
+    @configurable
+    def __init__(
+        self,
+        is_train: bool,
+        *,
+        augmentations: List[Union[T.Augmentation, T.Transform]],
+        image_format: str,
+        use_instance_mask: bool = False,
+        use_keypoint: bool = False,
+        instance_mask_format: str = "polygon",
+        keypoint_hflip_indices: Optional[np.ndarray] = None,
+        precomputed_proposal_topk: Optional[int] = None,
+        recompute_boxes: bool = False
+    ):
+        """
+        NOTE: this interface is experimental.
 
+        Args:
+            is_train: whether it's used in training or inference
+            augmentations: a list of augmentations or deterministic transforms to apply
+            image_format: an image format supported by :func:`detection_utils.read_image`.
+            use_instance_mask: whether to process instance segmentation annotations, if available
+            use_keypoint: whether to process keypoint annotations if available
+            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
+                masks into this format.
+            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
+            precomputed_proposal_topk: if given, will load pre-computed
+                proposals from dataset_dict and keep the top k proposals for each image.
+            recompute_boxes: whether to overwrite bounding box annotations
+                by computing tight bounding boxes from instance mask annotations.
+        """
+        if recompute_boxes:
+            assert use_instance_mask, "recompute_boxes requires instance masks"
         # fmt: off
-        self.img_format     = cfg.INPUT.FORMAT
-        self.mask_on        = cfg.MODEL.MASK_ON
-        self.mask_format    = cfg.INPUT.MASK_FORMAT
-        self.keypoint_on    = cfg.MODEL.KEYPOINT_ON
-        self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
+        self.is_train               = is_train
+        self.augmentations          = augmentations
+        self.image_format           = image_format
+        self.use_instance_mask      = use_instance_mask
+        self.instance_mask_format   = instance_mask_format
+        self.use_keypoint           = use_keypoint
+        self.keypoint_hflip_indices = keypoint_hflip_indices
+        self.proposal_topk          = precomputed_proposal_topk
+        self.recompute_boxes        = recompute_boxes
         # fmt: on
-        if self.keypoint_on and is_train:
-            # Flip only makes sense in training
-            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
-        else:
-            self.keypoint_hflip_indices = None
+        logger = logging.getLogger(__name__)
+        logger.info("Augmentations used in training: " + str(augmentations))
 
-        if self.load_proposals:
-            self.proposal_min_box_size = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
-            self.proposal_topk = (
+    @classmethod
+    def from_config(cls, cfg, is_train: bool = True):
+        augs = utils.build_augmentation(cfg, is_train)
+        if cfg.INPUT.CROP.ENABLED and is_train:
+            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
+            recompute_boxes = cfg.MODEL.MASK_ON
+        else:
+            recompute_boxes = False
+
+        ret = {
+            "is_train": is_train,
+            "augmentations": augs,
+            "image_format": cfg.INPUT.FORMAT,
+            "use_instance_mask": cfg.MODEL.MASK_ON,
+            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
+            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
+            "recompute_boxes": recompute_boxes,
+        }
+        if cfg.MODEL.KEYPOINT_ON:
+            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
+
+        if cfg.MODEL.LOAD_PROPOSALS:
+            ret["precomputed_proposal_topk"] = (
                 cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                 if is_train
                 else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
             )
-        self.is_train = is_train
+        return ret
 
     def __call__(self, dataset_dict):
         """
@@ -74,7 +120,7 @@ def __call__(self, dataset_dict):
         """
         dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
         # USER: Write your own image loading if it's not from a file
-        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
+        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
         utils.check_image_size(dataset_dict, image)
 
         # USER: Remove if you don't do semantic/panoptic segmentation.
@@ -84,7 +130,7 @@ def __call__(self, dataset_dict):
             sem_seg_gt = None
 
         aug_input = T.StandardAugInput(image, sem_seg=sem_seg_gt)
-        transforms = aug_input.apply_augmentations(self.augmentation)
+        transforms = aug_input.apply_augmentations(self.augmentations)
         image, sem_seg_gt = aug_input.image, aug_input.sem_seg
 
         image_shape = image.shape[:2]  # h, w
@@ -97,13 +143,9 @@ def __call__(self, dataset_dict):
 
         # USER: Remove if you don't use pre-computed proposals.
         # Most users would not need this feature.
-        if self.load_proposals:
+        if self.proposal_topk is not None:
             utils.transform_proposals(
-                dataset_dict,
-                image_shape,
-                transforms,
-                proposal_topk=self.proposal_topk,
-                min_box_size=self.proposal_min_box_size,
+                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
             )
 
         if not self.is_train:
@@ -115,9 +157,9 @@ def __call__(self, dataset_dict):
         if "annotations" in dataset_dict:
             # USER: Modify this if you want to keep them for some reason.
             for anno in dataset_dict["annotations"]:
-                if not self.mask_on:
+                if not self.use_instance_mask:
                     anno.pop("segmentation", None)
-                if not self.keypoint_on:
+                if not self.use_keypoint:
                     anno.pop("keypoints", None)
 
             # USER: Implement additional transformations if you have other types of data
@@ -129,15 +171,15 @@ def __call__(self, dataset_dict):
                 if obj.get("iscrowd", 0) == 0
             ]
             instances = utils.annotations_to_instances(
-                annos, image_shape, mask_format=self.mask_format
+                annos, image_shape, mask_format=self.instance_mask_format
             )
 
             # After transforms such as cropping are applied, the bounding box may no longer
             # tightly bound the object. As an example, imagine a triangle object
             # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
             # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
             # the intersection of original bounding box and the cropping box.
-            if self.compute_tight_boxes and instances.has("gt_masks"):
+            if self.recompute_boxes:
                 instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
             dataset_dict["instances"] = utils.filter_empty_instances(instances)
         return dataset_dict
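
Usage sketch (illustrative, not part of the patch): with the @configurable decorator, the mapper can now be built either from a CfgNode, which is routed through from_config(), or with explicit keyword arguments that bypass the config entirely. The config object and the specific augmentations below are assumptions for the example, not taken from the diff.

    from detectron2.config import get_cfg
    from detectron2.data import DatasetMapper, build_detection_train_loader
    from detectron2.data import transforms as T

    cfg = get_cfg()  # illustrative default config; merge your own yaml/opts as needed

    # Config-driven construction: the cfg is translated into keyword args by from_config().
    mapper = DatasetMapper(cfg, is_train=True)

    # Explicit construction: no cfg involved, every option is passed directly.
    mapper = DatasetMapper(
        is_train=True,
        augmentations=[T.ResizeShortestEdge(short_edge_length=800, max_size=1333), T.RandomFlip()],
        image_format="BGR",
        use_instance_mask=True,
    )

    # Either mapper can be handed to the data loader (assumes cfg.DATASETS.TRAIN is registered).
    loader = build_detection_train_loader(cfg, mapper=mapper)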