
Commit 48fe2f3

Revert "[topi] Add arm_cpu specific pooling schedules" (#15371)
Revert "[topi] Add `arm_cpu` specific pooling schedules (#15311)" This reverts commit 0a3ad64, due to cyclic importing issue reported in #15311.
1 parent 0fadb98 commit 48fe2f3
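
For context on the failure mode: a cyclic import happens when two modules need each other at import time, so one is only partially initialized when the other looks it up. A minimal two-file sketch (the module names are hypothetical, not the actual TVM modules in the cycle):

    # strategy_mod.py (hypothetical)
    from schedule_mod import make_schedule

    def pick_schedule():
        return make_schedule()

    # schedule_mod.py (hypothetical)
    from strategy_mod import pick_schedule  # completes the cycle: importing either
                                            # module now raises ImportError ("cannot
                                            # import name ... from partially
                                            # initialized module")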

5 files changed: +25, -118 lines

python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 12 additions & 5 deletions
@@ -26,7 +26,6 @@
 from ....auto_scheduler import is_auto_scheduler_enabled
 from ....meta_schedule import is_meta_schedule_enabled
 from ....topi.generic import conv2d as conv2d_generic
-from ....topi.arm_cpu.mprofile import dsp
 from .. import op as _op
 from .generic import *

@@ -64,11 +63,19 @@ def concatenate_strategy_arm_cpu(attrs, inputs, out_type, target):
 def schedule_pool_arm_cpu(attrs, outs, target):
     """schedule pooling ops arm cpu"""
     layout = attrs.layout
+    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
     with target:
-        if target.features.has_dsp:
-            is_avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
-            return dsp.pool.schedule_pool(outs, layout, is_avg_pool)
-        return topi.arm_cpu.schedule_pool(outs, layout)
+        if (
+            avg_pool
+            and target.features.has_dsp
+            and layout in ("NCW", "NCHW")
+            or not avg_pool
+            and target.features.has_dsp
+            and layout in ("NWC", "NHWC")
+        ):
+            return topi.arm_cpu.schedule_pool(outs, layout)
+        logger.warning("pool is not optimized for arm cpu.")
+        return topi.generic.schedule_pool(outs, layout)


 def _get_padding_width(padding):
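
A note for readers of the restored condition above: Python's `and` binds tighter than `or`, so the un-parenthesized test groups into two cases. A self-contained sketch of the equivalent, explicitly parenthesized predicate (the helper name is ours, not TVM's):

    def _uses_arm_cpu_pool_schedule(avg_pool, has_dsp, layout):
        # Equivalent grouping of the condition in schedule_pool_arm_cpu:
        # average pools need an NCW/NCHW layout, max pools need NWC/NHWC,
        # and the target must have DSP features either way.
        return (avg_pool and has_dsp and layout in ("NCW", "NCHW")) or (
            not avg_pool and has_dsp and layout in ("NWC", "NHWC")
        )

    assert _uses_arm_cpu_pool_schedule(True, True, "NCHW")
    assert not _uses_arm_cpu_pool_schedule(True, True, "NHWC")
    assert not _uses_arm_cpu_pool_schedule(False, False, "NHWC")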

python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -14,5 +14,3 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Schedule for arm_cpu targets supporting DSP"""
-from .pool import schedule_pool

python/tvm/topi/arm_cpu/mprofile/dsp/pool.py

Lines changed: 10 additions & 20 deletions
@@ -23,9 +23,15 @@
 from tvm import te
 from tvm.topi.utils import traverse_inline

-from .micro_kernel.max_pool import intrin_max, max_impl
-from .micro_kernel.avg_pool import intrin_sum, sum_impl
-from .... import generic
+from .micro_kernel.max_pool import (
+    intrin_max,
+    max_impl,
+)
+
+from .micro_kernel.avg_pool import (
+    intrin_sum,
+    sum_impl,
+)

 logger = logging.getLogger("topi")

@@ -94,24 +100,8 @@ def schedule_avgpool_2d_nchw(s, op):
     s[output].pragma(n, "import_c", sum_impl(pool_w, uniq_id))


-def schedule_pool(outs, layout, is_avg_pool):
+def pool_dsp_schedule(outs, layout):
     """Schedule function for v7e-m DSP instructions of pooling."""
-
-    if is_avg_pool and layout not in ["NCW", "NCHW"]:
-        logger.warning(
-            "avg pool not support for NCW or NCHW layouts on DSP"
-            "enabled targets, falling back on generic pool"
-            "implementation"
-        )
-        return generic.schedule_pool(outs, layout)
-    elif not is_avg_pool and layout not in ["NWC", "NHWC"]:
-        logger.warning(
-            "max pool not support for NWC or NHWC layouts on DSP"
-            "enabled targets, falling back on generic pool"
-            "implementation"
-        )
-        return generic.schedule_pool(outs, layout)
-
     s = te.create_schedule([x.op for x in outs])

     def _callback(op):
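
For orientation, the surviving pool_dsp_schedule follows the usual TOPI pattern: build one schedule over the outputs, then walk the graph with traverse_inline and act on ops whose tag marks them as pooling. A stripped-down sketch of that pattern (the handler body is illustrative, not the actual DSP tensorization):

    from tvm import te
    from tvm.topi.utils import traverse_inline

    def sketch_pool_schedule(outs):
        # One schedule covering every output tensor.
        s = te.create_schedule([x.op for x in outs])

        def _callback(op):
            # TOPI pooling computes carry tags starting with "pool"
            # (e.g. "pool_max"); only those ops get special treatment.
            if op.tag.startswith("pool"):
                n, h, w, c = s[op].op.axis
                s[op].parallel(h)  # stand-in for the real DSP scheduling

        traverse_inline(s, outs[0].op, _callback)
        return s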

python/tvm/topi/arm_cpu/pooling.py

Lines changed: 3 additions & 90 deletions
@@ -17,96 +17,9 @@
 # pylint: disable=invalid-name, unused-variable
 """Schedule for pooling operators"""

-import logging
-from tvm import te
-from tvm.target import Target
-
-from .. import tag
-from .. import generic
+from .mprofile.dsp.pool import pool_dsp_schedule


 def schedule_pool(outs, layout):
-    """Create schedule for avgpool/maxpool"""
-
-    if layout != "NHWC":
-        logger = logging.getLogger("topi")
-        logger.warning(
-            """We currently only support NHWC target specific pools on arm_cpu,
-            falling back on generic pool scheduling"""
-        )
-        return generic.schedule_pool(outs, layout)
-
-    return schedule_pool_2d(outs)
-
-
-def schedule_pool_2d(outs):
-    """Create arm_cpu specific 2D schedule for avgpool/maxpool"""
-
-    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    schedule_ops = [x.op for x in outs]
-    schedule = te.create_schedule(schedule_ops)
-    scheduled_ops = []
-
-    def traverse(op):
-        # Recursively inline any injective operation that isn't the pooling
-        # operation or hasn't already been scheduled.
-        if tag.is_injective(op.tag):
-            if op not in schedule.outputs:
-                schedule[op].compute_inline()
-            for tensor in op.input_tensors:
-                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
-        # schedule the actual pooling operation
-        elif op.tag.startswith("pool"):
-            n, height, width, channel = schedule[op].op.axis
-            # Average pool consists of two parts; a sum then a division.
-            # We can schedule the division loop to parallelize across height and
-            # vectorize across width.
-            enable_explicit_vectorization = not Target.current(allow_none=False).features.has_sve
-            if op != outs[0].op:
-                output = outs[0]
-                output_fused = schedule[output].fuse(output.op.axis[1], output.op.axis[2])
-                schedule[output].parallel(output_fused)
-                vectorization_factor = (
-                    8 if enable_explicit_vectorization else output.op.axis[3].dom.extent
-                )
-                _, inner = schedule[output].split(output.op.axis[3], vectorization_factor)
-                schedule[output].vectorize(inner)
-
-            padded_input = op.input_tensors[0]
-            if isinstance(padded_input.op, te.tensor.ComputeOp):
-                schedule[padded_input].compute_inline()
-
-            # For targets without SVE try explicitly vectorizing the channel
-            # loop, For SVE targets leave the loop in place for LLVM to convert
-            # into a scalable vector loop.
-            vectorization_factor = 8 if enable_explicit_vectorization else channel.dom.extent
-            channel_outer, channel_inner = schedule[op].split(channel, vectorization_factor)
-            schedule[op].vectorize(channel_inner)
-            schedule[op].parallel(height)
-            if len(schedule[op].op.reduce_axis) > 0:
-                filter_height, filter_width = schedule[op].op.reduce_axis
-                # We consider any filter of area < 10 to be small enough to
-                # unroll; 3x3 filters have shown better performance when
-                # unrolled.
-                if filter_height.dom.extent * filter_width.dom.extent <= 9:
-                    # For small filters, unrolling the filter loops allows us to
-                    # vectorize over channels without reordering anything.
-                    schedule[op].unroll(filter_width)
-                    schedule[op].unroll(filter_height)
-                else:
-                    # Reordering so that channels is the fastest moving axis allows
-                    # LLVM to vectorize across contiguous memory in the NHWC
-                    # ordering.
-                    schedule[op].reorder(
-                        n, height, width, filter_height, filter_width, channel_outer, channel_inner
-                    )
-            else:
-                schedule[op].reorder(n, height, width, channel_outer, channel_inner)
-        else:
-            raise RuntimeError("Unsupported operator: %s" % op.tag)
-
-        scheduled_ops.append(op)
-
-    traverse(outs[0].op)
-    return schedule
+    """Create schedule for avgpool/maxpool with dsp"""
+    return pool_dsp_schedule(outs, layout)
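
The schedule removed above encodes two ideas worth noting: split the NHWC channel loop by a fixed factor and vectorize the inner part on targets without SVE, and unroll small (e.g. 3x3) filter loops so channel vectorization works without reordering. A self-contained sketch of those primitives on a stand-in 3x3 max-pool compute (the shapes and the factor 8 mirror the removed code, but none of this is the actual TVM schedule):

    import tvm
    from tvm import te

    data = te.placeholder((1, 16, 16, 32), name="data")  # NHWC input
    rh = te.reduce_axis((0, 3), name="rh")
    rw = te.reduce_axis((0, 3), name="rw")
    # Stand-in for a 3x3 max pool (valid padding, stride 1).
    pool = te.compute(
        (1, 14, 14, 32),
        lambda n, h, w, c: te.max(data[n, h + rh, w + rw, c], axis=[rh, rw]),
        name="pool",
    )

    s = te.create_schedule(pool.op)
    n, h, w, c = s[pool].op.axis
    _, c_inner = s[pool].split(c, factor=8)  # fixed-width lanes for non-SVE targets
    s[pool].vectorize(c_inner)
    s[pool].parallel(h)
    s[pool].unroll(rh)  # small filter: unroll so channel vectorization survives
    s[pool].unroll(rw)
    print(tvm.lower(s, [data, pool], simple_mode=True))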

tests/python/topi/python/test_topi_pooling.py

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@

 _pool_schedule = {
     "generic": topi.generic.schedule_pool,
-    "arm_cpu": topi.arm_cpu.schedule_pool,
     "cpu": topi.x86.schedule_pool,
     "gpu": topi.cuda.schedule_pool,
     "hls": topi.hls.schedule_pool,
