Skip to content

Commit f5ce619

Browse files
authored
add meta backend for ROIAlign (#1585)
1 parent 8b02d62 commit f5ce619

File tree

4 files changed

+181
-50
lines changed

4 files changed

+181
-50
lines changed

csrc/cpu/aten/ROIAlign.cpp

Lines changed: 100 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@ namespace cpu {
1313
DEFINE_DISPATCH(roi_align_forward_kernel_stub);
1414
DEFINE_DISPATCH(roi_align_backward_kernel_stub);
1515

16-
at::Tensor IPEXROIAlignOp::_forward(
16+
at::Tensor ROIAlign_forward_impl(
1717
const at::Tensor& input,
1818
const at::Tensor& rois,
1919
double spatial_scale,
2020
int64_t pooled_height,
2121
int64_t pooled_width,
2222
int64_t sampling_ratio,
2323
bool aligned) {
24-
RECORD_FUNCTION("IPEXROIAlignOp::_forward", c10::ArrayRef<c10::IValue>({}));
24+
#if defined(IPEX_DISP_OP)
25+
printf("torch_ipex::ROIAlign_forward\n");
26+
#endif
27+
RECORD_FUNCTION(
28+
"torch_ipex::ROIAlign_forward", c10::ArrayRef<c10::IValue>({}));
2529

2630
return roi_align_forward_kernel_stub(
2731
kCPU,
@@ -34,6 +38,66 @@ at::Tensor IPEXROIAlignOp::_forward(
3438
aligned);
3539
}
3640

41+
at::Tensor ROIAlign_backward(
42+
const at::Tensor& grad,
43+
const at::Tensor& rois,
44+
double spatial_scale,
45+
int64_t pooled_height,
46+
int64_t pooled_width,
47+
int64_t batch_size,
48+
int64_t channels,
49+
int64_t height,
50+
int64_t width,
51+
int64_t sampling_ratio,
52+
bool aligned,
53+
bool is_channels_last) {
54+
#if defined(IPEX_DISP_OP)
55+
printf("torch_ipex::ROIAlign_backward\n");
56+
#endif
57+
RECORD_FUNCTION(
58+
"torch_ipex::ROIAlign_backward", c10::ArrayRef<c10::IValue>({}));
59+
60+
return roi_align_backward_kernel_stub(
61+
kCPU,
62+
grad,
63+
rois,
64+
spatial_scale,
65+
pooled_height,
66+
pooled_width,
67+
batch_size,
68+
channels,
69+
height,
70+
width,
71+
sampling_ratio,
72+
aligned,
73+
is_channels_last);
74+
}
75+
76+
at::Tensor IPEXROIAlignOp::_forward(
77+
const at::Tensor& input,
78+
const at::Tensor& rois,
79+
double spatial_scale,
80+
int64_t pooled_height,
81+
int64_t pooled_width,
82+
int64_t sampling_ratio,
83+
bool aligned) {
84+
at::AutoDispatchBelowADInplaceOrView g;
85+
RECORD_FUNCTION("IPEXROIAlignOp::_forward", c10::ArrayRef<c10::IValue>({}));
86+
87+
static auto op = torch::Dispatcher::singleton()
88+
.findSchemaOrThrow("torch_ipex::ROIAlign_forward", "")
89+
.typed<decltype(ROIAlign_forward)>();
90+
91+
return op.call(
92+
input,
93+
rois,
94+
spatial_scale,
95+
pooled_height,
96+
pooled_width,
97+
sampling_ratio,
98+
aligned);
99+
}
100+
37101
at::Tensor IPEXROIAlignOp::forward(
38102
torch::autograd::AutogradContext* ctx,
39103
const at::Tensor& input,
@@ -45,7 +109,7 @@ at::Tensor IPEXROIAlignOp::forward(
45109
bool aligned) {
46110
RECORD_FUNCTION("IPEXROIAlignOp::forward", c10::ArrayRef<c10::IValue>({}));
47111

48-
ctx->saved_data["input_shape"] = input.sizes();
112+
ctx->saved_data["input_shape"] = input.sym_sizes();
49113
ctx->saved_data["spatial_scale"] = spatial_scale;
50114
ctx->saved_data["pooled_height"] = pooled_height;
51115
ctx->saved_data["pooled_width"] = pooled_width;
@@ -55,8 +119,7 @@ at::Tensor IPEXROIAlignOp::forward(
55119
input.is_contiguous(at::MemoryFormat::ChannelsLast);
56120
ctx->save_for_backward({rois});
57121

58-
return roi_align_forward_kernel_stub(
59-
kCPU,
122+
return _forward(
60123
input,
61124
rois,
62125
spatial_scale,
@@ -81,8 +144,11 @@ torch::autograd::variable_list IPEXROIAlignOp::backward(
81144
auto saved = ctx->get_saved_variables();
82145
at::Tensor rois = saved[0];
83146

84-
at::Tensor grad_input = roi_align_backward_kernel_stub(
85-
kCPU,
147+
static auto op = torch::Dispatcher::singleton()
148+
.findSchemaOrThrow("torch_ipex::ROIAlign_backward", "")
149+
.typed<decltype(ROIAlign_backward)>();
150+
151+
auto grad_input = op.call(
86152
grad_outputs[0],
87153
rois,
88154
spatial_scale,
@@ -134,45 +200,26 @@ at::Tensor ROIAlign_forward(
134200
aligned);
135201
}
136202

137-
} // namespace cpu
138-
} // namespace torch_ipex
139-
140-
namespace torch_ipex {
141-
namespace autocast {
142-
143-
at::Tensor roi_align_autocast(
203+
at::Tensor ROIAlign_forward_meta(
144204
const at::Tensor& input,
145205
const at::Tensor& rois,
146206
double spatial_scale,
147207
int64_t pooled_height,
148208
int64_t pooled_width,
149209
int64_t sampling_ratio,
150210
bool aligned) {
151-
c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
152-
static auto op = torch::Dispatcher::singleton()
153-
.findSchemaOrThrow("torchvision::roi_align", "")
154-
.typed<decltype(torch_ipex::cpu::ROIAlign_forward)>();
155-
if (input.scalar_type() == at::ScalarType::BFloat16) {
156-
return op.call(
157-
input,
158-
cpu_cached_cast(at::kFloat, rois),
159-
spatial_scale,
160-
pooled_height,
161-
pooled_width,
162-
sampling_ratio,
163-
aligned);
164-
} else {
165-
return op.call(
166-
input,
167-
cpu_cached_cast(input.scalar_type(), rois),
168-
spatial_scale,
169-
pooled_height,
170-
pooled_width,
171-
sampling_ratio,
172-
aligned);
173-
}
211+
auto num_rois = rois.sym_size(0);
212+
auto channels = input.sym_size(1);
213+
return at::empty_symint(
214+
{num_rois, channels, pooled_height, pooled_width}, input.options());
174215
}
175216

217+
} // namespace cpu
218+
} // namespace torch_ipex
219+
220+
namespace torch_ipex {
221+
namespace autocast {
222+
176223
at::Tensor ROIAlign_forward(
177224
const at::Tensor& input,
178225
const at::Tensor& rois,
@@ -222,6 +269,21 @@ IPEX_TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
222269
"ROIAlign_forward",
223270
c10::DispatchKey::AutocastCPU,
224271
torch_ipex::autocast::ROIAlign_forward);
272+
m.impl(
273+
"ROIAlign_forward",
274+
c10::DispatchKey::CPU,
275+
torch_ipex::cpu::ROIAlign_forward_impl);
276+
m.impl(
277+
"ROIAlign_forward",
278+
c10::DispatchKey::Meta,
279+
torch_ipex::cpu::ROIAlign_forward_meta);
280+
// bw
281+
m.def(
282+
"ROIAlign_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned, bool is_channels_last) -> Tensor");
283+
m.impl(
284+
"ROIAlign_backward",
285+
c10::DispatchKey::CPU,
286+
torch_ipex::cpu::ROIAlign_backward);
225287
}
226288

227289
IPEX_TORCH_LIBRARY_FRAGMENT(torchvision, m) {
@@ -232,7 +294,7 @@ IPEX_TORCH_LIBRARY_FRAGMENT(torchvision, m) {
232294
m.impl(
233295
"roi_align",
234296
c10::DispatchKey::AutocastCPU,
235-
torch_ipex::autocast::roi_align_autocast);
297+
torch_ipex::autocast::ROIAlign_forward);
236298
}
237299

238300
} // namespace

csrc/cpu/aten/ROIAlign.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,29 @@
77
namespace torch_ipex {
88
namespace cpu {
99

10+
at::Tensor ROIAlign_forward_impl(
11+
const at::Tensor& input,
12+
const at::Tensor& rois,
13+
double spatial_scale,
14+
int64_t pooled_height,
15+
int64_t pooled_width,
16+
int64_t sampling_ratio,
17+
bool aligned);
18+
19+
at::Tensor ROIAlign_backward(
20+
const at::Tensor& grad,
21+
const at::Tensor& rois,
22+
double spatial_scale,
23+
int64_t pooled_height,
24+
int64_t pooled_width,
25+
int64_t batch_size,
26+
int64_t channels,
27+
int64_t height,
28+
int64_t width,
29+
int64_t sampling_ratio,
30+
bool aligned,
31+
bool is_channels_last);
32+
1033
class IPEXROIAlignOp : public torch::autograd::Function<IPEXROIAlignOp> {
1134
public:
1235
// forward function without autograd overhead, will go this way when only do
@@ -44,6 +67,15 @@ at::Tensor ROIAlign_forward(
4467
int64_t sampling_ratio,
4568
bool aligned);
4669

70+
at::Tensor ROIAlign_forward_meta(
71+
const at::Tensor& input,
72+
const at::Tensor& rois,
73+
double spatial_scale,
74+
int64_t pooled_height,
75+
int64_t pooled_width,
76+
int64_t sampling_ratio,
77+
bool aligned);
78+
4779
namespace {
4880

4981
template <typename T>

csrc/cpu/aten/kernels/ROIAlignKrnl.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -673,12 +673,6 @@ at::Tensor roi_align_forward_kernel_impl(
673673
int64_t pooled_width,
674674
int64_t sampling_ratio,
675675
bool aligned) {
676-
#if defined(IPEX_DISP_OP)
677-
printf("torch_ipex::ROIAlign_forward\n");
678-
#endif
679-
RECORD_FUNCTION(
680-
"torch_ipex::ROIAlign_forward", c10::ArrayRef<c10::IValue>({}));
681-
682676
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
683677
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
684678
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
@@ -741,12 +735,6 @@ at::Tensor roi_align_backward_kernel_impl(
741735
int64_t sampling_ratio,
742736
bool aligned,
743737
bool is_channels_last) {
744-
#if defined(IPEX_DISP_OP)
745-
printf("torch_ipex::ROIAlign_backward\n");
746-
#endif
747-
RECORD_FUNCTION(
748-
"torch_ipex::ROIAlign_backward", c10::ArrayRef<c10::IValue>({}));
749-
750738
TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
751739
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
752740

tests/cpu/test_roialign.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest, copy
2+
import itertools
23
import torch
34
import intel_extension_for_pytorch as ipex
45
from common_utils import TestCase
@@ -219,6 +220,54 @@ def test_torchvision_roialign(self):
219220
self.assertTrue(x4.grad.dtype == torch.bfloat16)
220221
self.assertTrue(torch.allclose(gt_x.grad.to(x4.dtype), x4.grad, rtol=1e-5, atol=1e-5))
221222

223+
@skipIfNoTorchVision
224+
def test_torchvision_roialign_torchcompile(self):
225+
pool_size = 5
226+
n_channels = 2 * (pool_size ** 2)
227+
x = torch.rand(2, n_channels, 10, 10)
228+
rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
229+
[0, 0, 5, 4, 9],
230+
[0, 5, 5, 9, 9],
231+
[1, 0, 0, 9, 9]])
232+
pool_h, pool_w = pool_size, pool_size
233+
234+
# TODO: add dynamic tests when 'ipex' backend supports it.
235+
for dtype, backend, dynamic in itertools.product([torch.float32, torch.bfloat16], ['ipex', 'inductor'], [False]):
236+
torch._dynamo.reset()
237+
torchcompile_torchvision_fn = torch.compile(torchvision_fn, backend=backend, dynamic=dynamic)
238+
x = x.to(dtype=dtype)
239+
rois = rois.to(dtype=dtype)
240+
# forward
241+
with torch.cpu.amp.autocast(enabled=(dtype==torch.bfloat16)), torch.no_grad():
242+
y0 = torchvision_fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
243+
y1 = torchcompile_torchvision_fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
244+
self.assertEqual(y0, y1)
245+
self.assertTrue(y1.dtype == dtype)
246+
247+
@skipIfNoTorchVision
248+
def test_roialign_torchcompile(self):
249+
pool_size = 5
250+
n_channels = 2 * (pool_size ** 2)
251+
x = torch.rand(2, n_channels, 10, 10)
252+
rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
253+
[0, 0, 5, 4, 9],
254+
[0, 5, 5, 9, 9],
255+
[1, 0, 0, 9, 9]])
256+
pool_h, pool_w = pool_size, pool_size
257+
torch._dynamo.allow_in_graph(ipex.nn.modules._roi_align.RoIAlign)
258+
259+
# TODO: add dynamic tests when 'ipex' backend supports it.
260+
for dtype, backend, dynamic in itertools.product([torch.float32, torch.bfloat16], ['ipex', 'inductor'], [False]):
261+
torch._dynamo.reset()
262+
torchcompile_fn = torch.compile(fn, backend=backend, dynamic=dynamic)
263+
x = x.to(dtype=dtype)
264+
rois = rois.to(dtype=dtype)
265+
# forward
266+
with torch.cpu.amp.autocast(enabled=(dtype==torch.bfloat16)), torch.no_grad():
267+
y0 = fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
268+
y1 = torchcompile_fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
269+
self.assertEqual(y0, y1)
270+
self.assertTrue(y1.dtype == dtype)
222271

223272
if __name__ == '__main__':
224273
test = unittest.main()

0 commit comments

Comments
 (0)