
Commit 0a3ad64

[topi] Add arm_cpu specific pooling schedules (#15311)

This commit:
* Adds specialized `arm_cpu` pooling schedules for both fixed width and scalable vectors.
* Enables topi testing of the new `arm_cpu` schedules.
* Removes a self-import and uses a relative import instead.

Co-authored-by: Jack Frankland <[email protected]>
Change-Id: Ib07fb438ba9ee8ab92fc5bfc438479959411e7db
1 parent 7e830e5 commit 0a3ad64
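
As a quick orientation, here is a minimal sketch (not part of the commit) of driving the new `topi.arm_cpu.schedule_pool` entry point directly through TOPI. The NHWC shape, pool sizes, and the AArch64 target string are illustrative assumptions; any `arm_cpu` target without DSP support is routed to this schedule by the relay strategy change below.

import tvm
from tvm import te, topi

# Assumed NHWC input and 2x2 max pool, purely for illustration.
data = te.placeholder((1, 28, 28, 32), name="data", dtype="float32")
pool = topi.nn.pool2d(
    data,
    kernel=(2, 2),
    stride=(2, 2),
    dilation=(1, 1),
    padding=(0, 0, 0, 0),
    pool_type="max",
    layout="NHWC",
)

# The schedule reads Target.current(), so it must run inside a target scope.
# Assumed AArch64 target string; without +sve it takes the fixed-width path.
target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu")
with target:
    sched = topi.arm_cpu.schedule_pool([pool], layout="NHWC")
func = tvm.build(sched, [data, pool], target=target)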

File tree: 5 files changed, +118 -25 lines changed


python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 5 additions & 12 deletions

@@ -26,6 +26,7 @@
 from ....auto_scheduler import is_auto_scheduler_enabled
 from ....meta_schedule import is_meta_schedule_enabled
 from ....topi.generic import conv2d as conv2d_generic
+from ....topi.arm_cpu.mprofile import dsp
 from .. import op as _op
 from .generic import *

@@ -63,19 +64,11 @@ def concatenate_strategy_arm_cpu(attrs, inputs, out_type, target):
 def schedule_pool_arm_cpu(attrs, outs, target):
     """schedule pooling ops arm cpu"""
     layout = attrs.layout
-    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
     with target:
-        if (
-            avg_pool
-            and target.features.has_dsp
-            and layout in ("NCW", "NCHW")
-            or not avg_pool
-            and target.features.has_dsp
-            and layout in ("NWC", "NHWC")
-        ):
-            return topi.arm_cpu.schedule_pool(outs, layout)
-        logger.warning("pool is not optimized for arm cpu.")
-        return topi.generic.schedule_pool(outs, layout)
+        if target.features.has_dsp:
+            is_avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
+            return dsp.pool.schedule_pool(outs, layout, is_avg_pool)
+        return topi.arm_cpu.schedule_pool(outs, layout)


 def _get_padding_width(padding):
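
For reference, a small sketch of the dispatch this hunk introduces; the target strings below are illustrative assumptions, not taken from the commit. Targets whose features report `has_dsp` are sent to the DSP schedule, everything else falls through to the new `topi.arm_cpu.schedule_pool`.

import tvm

# Assumed example targets; has_dsp is derived from the -mcpu attribute.
m_profile = tvm.target.Target("c -keys=arm_cpu,cpu -mcpu=cortex-m55")
aarch64 = tvm.target.Target("llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu")

for target in (m_profile, aarch64):
    if target.features.has_dsp:
        print(target, "-> dsp.pool.schedule_pool")       # DSP micro-kernel schedule
    else:
        print(target, "-> topi.arm_cpu.schedule_pool")   # new fixed/scalable vector schedule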

python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -14,3 +14,5 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""Schedule for arm_cpu targets supporting DSP"""
+from .pool import schedule_pool

python/tvm/topi/arm_cpu/mprofile/dsp/pool.py

Lines changed: 20 additions & 10 deletions

@@ -23,15 +23,9 @@
 from tvm import te
 from tvm.topi.utils import traverse_inline

-from .micro_kernel.max_pool import (
-    intrin_max,
-    max_impl,
-)
-
-from .micro_kernel.avg_pool import (
-    intrin_sum,
-    sum_impl,
-)
+from .micro_kernel.max_pool import intrin_max, max_impl
+from .micro_kernel.avg_pool import intrin_sum, sum_impl
+from .... import generic

 logger = logging.getLogger("topi")

@@ -100,8 +94,24 @@ def schedule_avgpool_2d_nchw(s, op):
     s[output].pragma(n, "import_c", sum_impl(pool_w, uniq_id))


-def pool_dsp_schedule(outs, layout):
+def schedule_pool(outs, layout, is_avg_pool):
     """Schedule function for v7e-m DSP instructions of pooling."""
+
+    if is_avg_pool and layout not in ["NCW", "NCHW"]:
+        logger.warning(
+            "avg pool not support for NCW or NCHW layouts on DSP"
+            "enabled targets, falling back on generic pool"
+            "implementation"
+        )
+        return generic.schedule_pool(outs, layout)
+    elif not is_avg_pool and layout not in ["NWC", "NHWC"]:
+        logger.warning(
+            "max pool not support for NWC or NHWC layouts on DSP"
+            "enabled targets, falling back on generic pool"
+            "implementation"
+        )
+        return generic.schedule_pool(outs, layout)
+
     s = te.create_schedule([x.op for x in outs])

     def _callback(op):

python/tvm/topi/arm_cpu/pooling.py

Lines changed: 90 additions & 3 deletions

@@ -17,9 +17,96 @@
 # pylint: disable=invalid-name, unused-variable
 """Schedule for pooling operators"""

-from .mprofile.dsp.pool import pool_dsp_schedule
+import logging
+from tvm import te
+from tvm.target import Target
+
+from .. import tag
+from .. import generic


 def schedule_pool(outs, layout):
-    """Create schedule for avgpool/maxpool with dsp"""
-    return pool_dsp_schedule(outs, layout)
+    """Create schedule for avgpool/maxpool"""
+
+    if layout != "NHWC":
+        logger = logging.getLogger("topi")
+        logger.warning(
+            """We currently only support NHWC target specific pools on arm_cpu,
+            falling back on generic pool scheduling"""
+        )
+        return generic.schedule_pool(outs, layout)
+
+    return schedule_pool_2d(outs)
+
+
+def schedule_pool_2d(outs):
+    """Create arm_cpu specific 2D schedule for avgpool/maxpool"""
+
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    schedule_ops = [x.op for x in outs]
+    schedule = te.create_schedule(schedule_ops)
+    scheduled_ops = []
+
+    def traverse(op):
+        # Recursively inline any injective operation that isn't the pooling
+        # operation or hasn't already been scheduled.
+        if tag.is_injective(op.tag):
+            if op not in schedule.outputs:
+                schedule[op].compute_inline()
+            for tensor in op.input_tensors:
+                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
+                    traverse(tensor.op)
+        # schedule the actual pooling operation
+        elif op.tag.startswith("pool"):
+            n, height, width, channel = schedule[op].op.axis
+            # Average pool consists of two parts; a sum then a division.
+            # We can schedule the division loop to parallelize across height and
+            # vectorize across width.
+            enable_explicit_vectorization = not Target.current(allow_none=False).features.has_sve
+            if op != outs[0].op:
+                output = outs[0]
+                output_fused = schedule[output].fuse(output.op.axis[1], output.op.axis[2])
+                schedule[output].parallel(output_fused)
+                vectorization_factor = (
+                    8 if enable_explicit_vectorization else output.op.axis[3].dom.extent
+                )
+                _, inner = schedule[output].split(output.op.axis[3], vectorization_factor)
+                schedule[output].vectorize(inner)
+
+            padded_input = op.input_tensors[0]
+            if isinstance(padded_input.op, te.tensor.ComputeOp):
+                schedule[padded_input].compute_inline()
+
+            # For targets without SVE try explicitly vectorizing the channel
+            # loop, For SVE targets leave the loop in place for LLVM to convert
+            # into a scalable vector loop.
+            vectorization_factor = 8 if enable_explicit_vectorization else channel.dom.extent
+            channel_outer, channel_inner = schedule[op].split(channel, vectorization_factor)
+            schedule[op].vectorize(channel_inner)
+            schedule[op].parallel(height)
+            if len(schedule[op].op.reduce_axis) > 0:
+                filter_height, filter_width = schedule[op].op.reduce_axis
+                # We consider any filter of area < 10 to be small enough to
+                # unroll; 3x3 filters have shown better performance when
+                # unrolled.
+                if filter_height.dom.extent * filter_width.dom.extent <= 9:
+                    # For small filters, unrolling the filter loops allows us to
+                    # vectorize over channels without reordering anything.
+                    schedule[op].unroll(filter_width)
+                    schedule[op].unroll(filter_height)
+                else:
+                    # Reordering so that channels is the fastest moving axis allows
+                    # LLVM to vectorize across contiguous memory in the NHWC
+                    # ordering.
+                    schedule[op].reorder(
+                        n, height, width, filter_height, filter_width, channel_outer, channel_inner
+                    )
+            else:
+                schedule[op].reorder(n, height, width, channel_outer, channel_inner)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % op.tag)

+        scheduled_ops.append(op)
+
+    traverse(outs[0].op)
+    return schedule
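
To make the scheduling decisions above concrete, here is a standalone sketch (shapes and the 3x3 window are assumptions, not from the commit) of the transformation `schedule_pool_2d` applies on a fixed-width target: split the channel loop by 8 and vectorize the inner part, parallelize over height, and unroll the small filter loops.

import tvm
from tvm import te

# Assumed NHWC input and a hand-written 3x3 max-pool compute for illustration.
data = te.placeholder((1, 16, 16, 32), name="data", dtype="float32")
rh = te.reduce_axis((0, 3), name="rh")
rw = te.reduce_axis((0, 3), name="rw")
pool = te.compute(
    (1, 14, 14, 32),
    lambda n, h, w, c: te.max(data[n, h + rh, w + rw, c], axis=[rh, rw]),
    name="pool",
    tag="pool_max",
)

s = te.create_schedule(pool.op)
n, h, w, c = s[pool].op.axis
c_outer, c_inner = s[pool].split(c, 8)  # fixed 8-lane split for non-SVE targets
s[pool].vectorize(c_inner)
s[pool].parallel(h)
s[pool].unroll(rh)  # 3x3 filter area <= 9, so unroll both filter loops
s[pool].unroll(rw)
print(tvm.lower(s, [data, pool], simple_mode=True))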

tests/python/topi/python/test_topi_pooling.py

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@

 _pool_schedule = {
     "generic": topi.generic.schedule_pool,
+    "arm_cpu": topi.arm_cpu.schedule_pool,
     "cpu": topi.x86.schedule_pool,
     "gpu": topi.cuda.schedule_pool,
     "hls": topi.hls.schedule_pool,
