@@ -13,15 +13,19 @@ namespace cpu {
13
13
DEFINE_DISPATCH (roi_align_forward_kernel_stub);
14
14
DEFINE_DISPATCH (roi_align_backward_kernel_stub);
15
15
16
- at::Tensor IPEXROIAlignOp::_forward (
16
+ at::Tensor ROIAlign_forward_impl (
17
17
const at::Tensor& input,
18
18
const at::Tensor& rois,
19
19
double spatial_scale,
20
20
int64_t pooled_height,
21
21
int64_t pooled_width,
22
22
int64_t sampling_ratio,
23
23
bool aligned) {
24
- RECORD_FUNCTION (" IPEXROIAlignOp::_forward" , c10::ArrayRef<c10::IValue>({}));
24
+ #if defined(IPEX_DISP_OP)
25
+ printf (" torch_ipex::ROIAlign_forward\n " );
26
+ #endif
27
+ RECORD_FUNCTION (
28
+ " torch_ipex::ROIAlign_forward" , c10::ArrayRef<c10::IValue>({}));
25
29
26
30
return roi_align_forward_kernel_stub (
27
31
kCPU ,
@@ -34,6 +38,66 @@ at::Tensor IPEXROIAlignOp::_forward(
34
38
aligned);
35
39
}
36
40
41
+ at::Tensor ROIAlign_backward (
42
+ const at::Tensor& grad,
43
+ const at::Tensor& rois,
44
+ double spatial_scale,
45
+ int64_t pooled_height,
46
+ int64_t pooled_width,
47
+ int64_t batch_size,
48
+ int64_t channels,
49
+ int64_t height,
50
+ int64_t width,
51
+ int64_t sampling_ratio,
52
+ bool aligned,
53
+ bool is_channels_last) {
54
+ #if defined(IPEX_DISP_OP)
55
+ printf (" torch_ipex::ROIAlign_backward\n " );
56
+ #endif
57
+ RECORD_FUNCTION (
58
+ " torch_ipex::ROIAlign_backward" , c10::ArrayRef<c10::IValue>({}));
59
+
60
+ return roi_align_backward_kernel_stub (
61
+ kCPU ,
62
+ grad,
63
+ rois,
64
+ spatial_scale,
65
+ pooled_height,
66
+ pooled_width,
67
+ batch_size,
68
+ channels,
69
+ height,
70
+ width,
71
+ sampling_ratio,
72
+ aligned,
73
+ is_channels_last);
74
+ }
75
+
76
+ at::Tensor IPEXROIAlignOp::_forward (
77
+ const at::Tensor& input,
78
+ const at::Tensor& rois,
79
+ double spatial_scale,
80
+ int64_t pooled_height,
81
+ int64_t pooled_width,
82
+ int64_t sampling_ratio,
83
+ bool aligned) {
84
+ at::AutoDispatchBelowADInplaceOrView g;
85
+ RECORD_FUNCTION (" IPEXROIAlignOp::_forward" , c10::ArrayRef<c10::IValue>({}));
86
+
87
+ static auto op = torch::Dispatcher::singleton ()
88
+ .findSchemaOrThrow (" torch_ipex::ROIAlign_forward" , " " )
89
+ .typed <decltype (ROIAlign_forward)>();
90
+
91
+ return op.call (
92
+ input,
93
+ rois,
94
+ spatial_scale,
95
+ pooled_height,
96
+ pooled_width,
97
+ sampling_ratio,
98
+ aligned);
99
+ }
100
+
37
101
at::Tensor IPEXROIAlignOp::forward (
38
102
torch::autograd::AutogradContext* ctx,
39
103
const at::Tensor& input,
@@ -45,7 +109,7 @@ at::Tensor IPEXROIAlignOp::forward(
45
109
bool aligned) {
46
110
RECORD_FUNCTION (" IPEXROIAlignOp::forward" , c10::ArrayRef<c10::IValue>({}));
47
111
48
- ctx->saved_data [" input_shape" ] = input.sizes ();
112
+ ctx->saved_data [" input_shape" ] = input.sym_sizes ();
49
113
ctx->saved_data [" spatial_scale" ] = spatial_scale;
50
114
ctx->saved_data [" pooled_height" ] = pooled_height;
51
115
ctx->saved_data [" pooled_width" ] = pooled_width;
@@ -55,8 +119,7 @@ at::Tensor IPEXROIAlignOp::forward(
55
119
input.is_contiguous (at::MemoryFormat::ChannelsLast);
56
120
ctx->save_for_backward ({rois});
57
121
58
- return roi_align_forward_kernel_stub (
59
- kCPU ,
122
+ return _forward (
60
123
input,
61
124
rois,
62
125
spatial_scale,
@@ -81,8 +144,11 @@ torch::autograd::variable_list IPEXROIAlignOp::backward(
81
144
auto saved = ctx->get_saved_variables ();
82
145
at::Tensor rois = saved[0 ];
83
146
84
- at::Tensor grad_input = roi_align_backward_kernel_stub (
85
- kCPU ,
147
+ static auto op = torch::Dispatcher::singleton ()
148
+ .findSchemaOrThrow (" torch_ipex::ROIAlign_backward" , " " )
149
+ .typed <decltype (ROIAlign_backward)>();
150
+
151
+ auto grad_input = op.call (
86
152
grad_outputs[0 ],
87
153
rois,
88
154
spatial_scale,
@@ -134,45 +200,26 @@ at::Tensor ROIAlign_forward(
134
200
aligned);
135
201
}
136
202
137
- } // namespace cpu
138
- } // namespace torch_ipex
139
-
140
- namespace torch_ipex {
141
- namespace autocast {
142
-
143
- at::Tensor roi_align_autocast (
203
+ at::Tensor ROIAlign_forward_meta (
144
204
const at::Tensor& input,
145
205
const at::Tensor& rois,
146
206
double spatial_scale,
147
207
int64_t pooled_height,
148
208
int64_t pooled_width,
149
209
int64_t sampling_ratio,
150
210
bool aligned) {
151
- c10::impl::ExcludeDispatchKeyGuard no_autocastCPU (DispatchKey::AutocastCPU);
152
- static auto op = torch::Dispatcher::singleton ()
153
- .findSchemaOrThrow (" torchvision::roi_align" , " " )
154
- .typed <decltype (torch_ipex::cpu::ROIAlign_forward)>();
155
- if (input.scalar_type () == at::ScalarType::BFloat16) {
156
- return op.call (
157
- input,
158
- cpu_cached_cast (at::kFloat , rois),
159
- spatial_scale,
160
- pooled_height,
161
- pooled_width,
162
- sampling_ratio,
163
- aligned);
164
- } else {
165
- return op.call (
166
- input,
167
- cpu_cached_cast (input.scalar_type (), rois),
168
- spatial_scale,
169
- pooled_height,
170
- pooled_width,
171
- sampling_ratio,
172
- aligned);
173
- }
211
+ auto num_rois = rois.sym_size (0 );
212
+ auto channels = input.sym_size (1 );
213
+ return at::empty_symint (
214
+ {num_rois, channels, pooled_height, pooled_width}, input.options ());
174
215
}
175
216
217
+ } // namespace cpu
218
+ } // namespace torch_ipex
219
+
220
+ namespace torch_ipex {
221
+ namespace autocast {
222
+
176
223
at::Tensor ROIAlign_forward (
177
224
const at::Tensor& input,
178
225
const at::Tensor& rois,
@@ -222,6 +269,21 @@ IPEX_TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
222
269
" ROIAlign_forward" ,
223
270
c10::DispatchKey::AutocastCPU,
224
271
torch_ipex::autocast::ROIAlign_forward);
272
+ m.impl (
273
+ " ROIAlign_forward" ,
274
+ c10::DispatchKey::CPU,
275
+ torch_ipex::cpu::ROIAlign_forward_impl);
276
+ m.impl (
277
+ " ROIAlign_forward" ,
278
+ c10::DispatchKey::Meta,
279
+ torch_ipex::cpu::ROIAlign_forward_meta);
280
+ // bw
281
+ m.def (
282
+ " ROIAlign_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned, bool is_channels_last) -> Tensor" );
283
+ m.impl (
284
+ " ROIAlign_backward" ,
285
+ c10::DispatchKey::CPU,
286
+ torch_ipex::cpu::ROIAlign_backward);
225
287
}
226
288
227
289
IPEX_TORCH_LIBRARY_FRAGMENT (torchvision, m) {
@@ -232,7 +294,7 @@ IPEX_TORCH_LIBRARY_FRAGMENT(torchvision, m) {
232
294
m.impl (
233
295
" roi_align" ,
234
296
c10::DispatchKey::AutocastCPU,
235
- torch_ipex::autocast::roi_align_autocast );
297
+ torch_ipex::autocast::ROIAlign_forward );
236
298
}
237
299
238
300
} // namespace
0 commit comments