
Commit f766535

FranklandJack authored and lhutton1 committed
[topi] Add arm_cpu specific pooling schedules
This commit:
* Adds specialized `arm_cpu` pooling schedules for both fixed-width and scalable vectors.
* Enables topi testing of the new `arm_cpu` schedules.

Co-authored-by: Jack Frankland <[email protected]>
Change-Id: Ib07fb438ba9ee8ab92fc5bfc438479959411e7db
1 parent 7392432 commit f766535

File tree

5 files changed: +116 additions, -25 deletions

python/tvm/relay/op/strategy/arm_cpu.py
python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py
python/tvm/topi/arm_cpu/mprofile/dsp/pool.py
python/tvm/topi/arm_cpu/pooling.py
tests/python/topi/python/test_topi_pooling.py

python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 5 additions & 12 deletions
@@ -26,6 +26,7 @@
 from ....auto_scheduler import is_auto_scheduler_enabled
 from ....meta_schedule import is_meta_schedule_enabled
 from ....topi.generic import conv2d as conv2d_generic
+from ....topi.arm_cpu.mprofile import dsp
 from .. import op as _op
 from .generic import *

@@ -63,19 +64,11 @@ def concatenate_strategy_arm_cpu(attrs, inputs, out_type, target):
 def schedule_pool_arm_cpu(attrs, outs, target):
     """schedule pooling ops arm cpu"""
     layout = attrs.layout
-    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
     with target:
-        if (
-            avg_pool
-            and target.features.has_dsp
-            and layout in ("NCW", "NCHW")
-            or not avg_pool
-            and target.features.has_dsp
-            and layout in ("NWC", "NHWC")
-        ):
-            return topi.arm_cpu.schedule_pool(outs, layout)
-        logger.warning("pool is not optimized for arm cpu.")
-        return topi.generic.schedule_pool(outs, layout)
+        if target.features.has_dsp:
+            is_avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
+            return dsp.pool.schedule_pool(outs, layout, is_avg_pool)
+        return topi.arm_cpu.schedule_pool(outs, layout)


 def _get_padding_width(padding):
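
For context, a minimal end-to-end sketch of how this simplified dispatch might be exercised; it is not part of the commit, and the target string, shapes, and pooling parameters are assumptions for illustration. On an arm_cpu target without DSP, lowering a Relay pooling op now routes straight to `topi.arm_cpu.schedule_pool`.

# Illustrative only: compile a Relay NHWC max pool for an assumed AArch64
# arm_cpu target. schedule_pool_arm_cpu runs during lowering; with no DSP
# feature present it falls through to topi.arm_cpu.schedule_pool.
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 28, 28, 8), dtype="float32")
pool = relay.nn.max_pool2d(data, pool_size=(2, 2), strides=(2, 2), layout="NHWC")
mod = tvm.IRModule.from_expr(relay.Function([data], pool))

target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -device=arm_cpu")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)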

python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,3 +14,5 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""Schedule for arm_cpu targets supporting DSP"""
+from .pool import schedule_pool

python/tvm/topi/arm_cpu/mprofile/dsp/pool.py

Lines changed: 20 additions & 10 deletions
@@ -20,18 +20,12 @@

 import tvm

-from tvm import te
+from tvm import te, topi
 from tvm.topi.utils import traverse_inline

-from .micro_kernel.max_pool import (
-    intrin_max,
-    max_impl,
-)
+from .micro_kernel.max_pool import intrin_max, max_impl

-from .micro_kernel.avg_pool import (
-    intrin_sum,
-    sum_impl,
-)
+from .micro_kernel.avg_pool import intrin_sum, sum_impl

 logger = logging.getLogger("topi")

@@ -100,8 +94,24 @@ def schedule_avgpool_2d_nchw(s, op):
     s[output].pragma(n, "import_c", sum_impl(pool_w, uniq_id))


-def pool_dsp_schedule(outs, layout):
+def schedule_pool(outs, layout, is_avg_pool):
     """Schedule function for v7e-m DSP instructions of pooling."""
+
+    if is_avg_pool and layout not in ["NCW", "NCHW"]:
+        logger.warning(
+            "avg pool is only supported for NCW and NCHW layouts on "
+            "DSP-enabled targets, falling back on the generic pool "
+            "implementation"
+        )
+        return topi.generic.schedule_pool(outs, layout)
+    elif not is_avg_pool and layout not in ["NWC", "NHWC"]:
+        logger.warning(
+            "max pool is only supported for NWC and NHWC layouts on "
+            "DSP-enabled targets, falling back on the generic pool "
+            "implementation"
+        )
+        return topi.generic.schedule_pool(outs, layout)
+
     s = te.create_schedule([x.op for x in outs])

     def _callback(op):
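
A hedged usage sketch of the renamed DSP entry point; the shapes, dtype, and pooling parameters below are assumptions rather than values taken from the diff.

# Illustrative only: with is_avg_pool=False and an NHWC layout the DSP max-pool
# schedule is produced; an unsupported layout would log a warning and fall back
# to topi.generic.schedule_pool.
from tvm import te, topi
from tvm.topi.arm_cpu.mprofile import dsp

data = te.placeholder((1, 16, 16, 8), name="data", dtype="int8")
out = topi.nn.pool2d(
    data,
    kernel=(2, 2),
    stride=(2, 2),
    dilation=(1, 1),
    padding=(0, 0, 0, 0),
    pool_type="max",
    layout="NHWC",
)
s = dsp.pool.schedule_pool([out], layout="NHWC", is_avg_pool=False)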

python/tvm/topi/arm_cpu/pooling.py

Lines changed: 88 additions & 3 deletions
@@ -17,9 +17,94 @@
 # pylint: disable=invalid-name, unused-variable
 """Schedule for pooling operators"""

-from .mprofile.dsp.pool import pool_dsp_schedule
+import logging
+from tvm import topi, te
+from tvm.target import Target
+from .. import tag


 def schedule_pool(outs, layout):
-    """Create schedule for avgpool/maxpool with dsp"""
-    return pool_dsp_schedule(outs, layout)
+    """Create schedule for avgpool/maxpool"""
+
+    if layout != "NHWC":
+        logger = logging.getLogger("topi")
+        logger.warning(
+            """We currently only support NHWC target specific pools on arm_cpu,
+            falling back on generic pool scheduling"""
+        )
+        return topi.generic.schedule_pool(outs, layout)
+
+    return schedule_pool_2d(outs)
+
+
+def schedule_pool_2d(outs):
+    """Create arm_cpu specific 2D schedule for avgpool/maxpool"""
+
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    schedule_ops = [x.op for x in outs]
+    schedule = te.create_schedule(schedule_ops)
+    scheduled_ops = []
+
+    def traverse(op):
+        # Recursively inline any injective operation that isn't the pooling
+        # operation or hasn't already been scheduled.
+        if tag.is_injective(op.tag):
+            if op not in schedule.outputs:
+                schedule[op].compute_inline()
+            for tensor in op.input_tensors:
+                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
+                    traverse(tensor.op)
+        # Schedule the actual pooling operation.
+        elif op.tag.startswith("pool"):
+            n, height, width, channel = schedule[op].op.axis
+            # Average pool consists of two parts: a sum, then a division.
+            # We can schedule the division loop to parallelize across height and
+            # vectorize across width.
+            enable_explicit_vectorization = not Target.current(allow_none=False).features.has_sve
+            if op != outs[0].op:
+                output = outs[0]
+                output_fused = schedule[output].fuse(output.op.axis[1], output.op.axis[2])
+                schedule[output].parallel(output_fused)
+                vectorization_factor = (
+                    8 if enable_explicit_vectorization else output.op.axis[3].dom.extent
+                )
+                _, inner = schedule[output].split(output.op.axis[3], vectorization_factor)
+                schedule[output].vectorize(inner)
+
+            padded_input = op.input_tensors[0]
+            if isinstance(padded_input.op, te.tensor.ComputeOp):
+                schedule[padded_input].compute_inline()
+
+            # For targets without SVE, try explicitly vectorizing the channel
+            # loop. For SVE targets, leave the loop in place for LLVM to convert
+            # into a scalable vector loop.
+            vectorization_factor = 8 if enable_explicit_vectorization else channel.dom.extent
+            channel_outer, channel_inner = schedule[op].split(channel, vectorization_factor)
+            schedule[op].vectorize(channel_inner)
+            schedule[op].parallel(height)
+            if len(schedule[op].op.reduce_axis) > 0:
+                filter_height, filter_width = schedule[op].op.reduce_axis
+                # We consider any filter of area < 10 to be small enough to
+                # unroll; 3x3 filters have shown better performance when
+                # unrolled.
+                if filter_height.dom.extent * filter_width.dom.extent <= 9:
+                    # For small filters, unrolling the filter loops allows us to
+                    # vectorize over channels without reordering anything.
+                    schedule[op].unroll(filter_width)
+                    schedule[op].unroll(filter_height)
+                else:
+                    # Reordering so that channels is the fastest moving axis allows
+                    # LLVM to vectorize across contiguous memory in the NHWC
+                    # ordering.
+                    schedule[op].reorder(
+                        n, height, width, filter_height, filter_width, channel_outer, channel_inner
+                    )
+            else:
+                schedule[op].reorder(n, height, width, channel_outer, channel_inner)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % op.tag)
+
+        scheduled_ops.append(op)
+
+    traverse(outs[0].op)
+    return schedule
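
As a rough guide to driving the new schedule at the TE level, here is a sketch with assumed shapes, pooling parameters, and an assumed AArch64 target string. `Target.current()` is read inside `schedule_pool_2d`, so the schedule has to be created under a target context: a target reporting SVE leaves the channel loop for LLVM to turn into a scalable-vector loop, while any other target splits the channel loop by a fixed factor of 8 and vectorizes it explicitly.

# Illustrative only: build an NHWC average pool and schedule it with the new
# arm_cpu schedule under an assumed non-SVE AArch64 target.
import tvm
from tvm import te, topi

data = te.placeholder((1, 32, 32, 16), name="data", dtype="float32")
pool = topi.nn.pool2d(
    data,
    kernel=(3, 3),
    stride=(2, 2),
    dilation=(1, 1),
    padding=(1, 1, 1, 1),
    pool_type="avg",
    layout="NHWC",
)

target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -device=arm_cpu")
with target:
    s = topi.arm_cpu.schedule_pool([pool], layout="NHWC")

# Inspect the scheduled loop nest (parallel height, vectorized channel inner loop).
print(tvm.lower(s, [data, pool], simple_mode=True))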

tests/python/topi/python/test_topi_pooling.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@

 _pool_schedule = {
     "generic": topi.generic.schedule_pool,
+    "arm_cpu": topi.arm_cpu.schedule_pool,
     "cpu": topi.x86.schedule_pool,
     "gpu": topi.cuda.schedule_pool,
     "hls": topi.hls.schedule_pool,
