17 changes: 12 additions & 5 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -26,7 +26,6 @@
 from ....auto_scheduler import is_auto_scheduler_enabled
 from ....meta_schedule import is_meta_schedule_enabled
 from ....topi.generic import conv2d as conv2d_generic
-from ....topi.arm_cpu.mprofile import dsp
 from .. import op as _op
 from .generic import *
 
@@ -64,11 +63,19 @@ def concatenate_strategy_arm_cpu(attrs, inputs, out_type, target):
 def schedule_pool_arm_cpu(attrs, outs, target):
     """schedule pooling ops arm cpu"""
     layout = attrs.layout
+    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
     with target:
-        if target.features.has_dsp:
-            is_avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
-            return dsp.pool.schedule_pool(outs, layout, is_avg_pool)
-        return topi.arm_cpu.schedule_pool(outs, layout)
+        if (
+            avg_pool
+            and target.features.has_dsp
+            and layout in ("NCW", "NCHW")
+            or not avg_pool
+            and target.features.has_dsp
+            and layout in ("NWC", "NHWC")
+        ):
+            return topi.arm_cpu.schedule_pool(outs, layout)
+    logger.warning("pool is not optimized for arm cpu.")
+    return topi.generic.schedule_pool(outs, layout)
 
 
 def _get_padding_width(padding):
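Note: the rewritten condition in `schedule_pool_arm_cpu` relies on `and` binding tighter than `or`. A minimal sketch of an equivalent, easier-to-read formulation (the helper name is illustrative and not part of the patch):

```python
# Sketch only: an equivalent reading of the new dispatch condition.
# `avg_pool`, `layout`, and `target` mirror the names used in the patch;
# the helper itself is hypothetical.
def _use_arm_cpu_pool_schedule(avg_pool, layout, target):
    # Average pool is only scheduled here for NCW/NCHW, max pool for
    # NWC/NHWC, and both require a DSP-capable target; anything else
    # falls back to the generic schedule with a warning.
    supported_layouts = ("NCW", "NCHW") if avg_pool else ("NWC", "NHWC")
    return target.features.has_dsp and layout in supported_layouts
```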
2 changes: 0 additions & 2 deletions python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py
@@ -14,5 +14,3 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Schedule for arm_cpu targets supporting DSP"""
-from .pool import schedule_pool
30 changes: 10 additions & 20 deletions python/tvm/topi/arm_cpu/mprofile/dsp/pool.py
@@ -20,12 +20,18 @@
 
 import tvm
 
-from tvm import te, topi
+from tvm import te
 from tvm.topi.utils import traverse_inline
 
-from .micro_kernel.max_pool import intrin_max, max_impl
+from .micro_kernel.max_pool import (
+    intrin_max,
+    max_impl,
+)
 
-from .micro_kernel.avg_pool import intrin_sum, sum_impl
+from .micro_kernel.avg_pool import (
+    intrin_sum,
+    sum_impl,
+)
 
 logger = logging.getLogger("topi")
 
@@ -94,24 +100,8 @@ def schedule_avgpool_2d_nchw(s, op):
     s[output].pragma(n, "import_c", sum_impl(pool_w, uniq_id))
 
 
-def schedule_pool(outs, layout, is_avg_pool):
+def pool_dsp_schedule(outs, layout):
     """Schedule function for v7e-m DSP instructions of pooling."""
-
-    if is_avg_pool and layout not in ["NCW", "NCHW"]:
-        logger.warning(
-            "avg pool not support for NCW or NCHW layouts on DSP"
-            "enabled targets, falling back on generic pool"
-            "implementation"
-        )
-        return topi.generic.schedule_pool(outs, layout)
-    elif not is_avg_pool and layout not in ["NWC", "NHWC"]:
-        logger.warning(
-            "max pool not support for NWC or NHWC layouts on DSP"
-            "enabled targets, falling back on generic pool"
-            "implementation"
-        )
-        return topi.generic.schedule_pool(outs, layout)
-
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
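With the layout and pool-type checks moved into the strategy, `pool_dsp_schedule` now assumes its caller has already validated the combination. A minimal sketch of a direct call, with illustrative shapes, dtype, and pooling parameters:

```python
# Sketch only: building an NHWC max-pool compute and scheduling it with
# the renamed pool_dsp_schedule. Shapes and dtype are assumptions.
from tvm import te, topi
from tvm.topi.arm_cpu.mprofile.dsp.pool import pool_dsp_schedule

data = te.placeholder((1, 16, 16, 8), name="data", dtype="int8")
out = topi.nn.pool2d(
    data,
    kernel=(2, 2),
    stride=(2, 2),
    dilation=(1, 1),
    padding=(0, 0, 0, 0),
    pool_type="max",
    layout="NHWC",
)
# NHWC max pool is one of the combinations the strategy now routes here.
schedule = pool_dsp_schedule([out], "NHWC")
```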
91 changes: 3 additions & 88 deletions python/tvm/topi/arm_cpu/pooling.py
@@ -17,94 +17,9 @@
 # pylint: disable=invalid-name, unused-variable
 """Schedule for pooling operators"""
 
-import logging
-from tvm import topi, te
-from tvm.target import Target
-from .. import tag
+from .mprofile.dsp.pool import pool_dsp_schedule
 
 
 def schedule_pool(outs, layout):
-    """Create schedule for avgpool/maxpool"""
-
-    if layout != "NHWC":
-        logger = logging.getLogger("topi")
-        logger.warning(
-            """We currently only support NHWC target specific pools on arm_cpu,
-            falling back on generic pool scheduling"""
-        )
-        return topi.generic.schedule_pool(outs, layout)
-
-    return schedule_pool_2d(outs)
-
-
-def schedule_pool_2d(outs):
-    """Create arm_cpu specific 2D schedule for avgpool/maxpool"""
-
-    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    schedule_ops = [x.op for x in outs]
-    schedule = te.create_schedule(schedule_ops)
-    scheduled_ops = []
-
-    def traverse(op):
-        # Recursively inline any injective operation that isn't the pooling
-        # operation or hasn't already been scheduled.
-        if tag.is_injective(op.tag):
-            if op not in schedule.outputs:
-                schedule[op].compute_inline()
-            for tensor in op.input_tensors:
-                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
-        # schedule the actual pooling operation
-        elif op.tag.startswith("pool"):
-            n, height, width, channel = schedule[op].op.axis
-            # Average pool consists of two parts; a sum then a division.
-            # We can schedule the division loop to parallelize across height and
-            # vectorize across width.
-            enable_explicit_vectorization = not Target.current(allow_none=False).features.has_sve
-            if op != outs[0].op:
-                output = outs[0]
-                output_fused = schedule[output].fuse(output.op.axis[1], output.op.axis[2])
-                schedule[output].parallel(output_fused)
-                vectorization_factor = (
-                    8 if enable_explicit_vectorization else output.op.axis[3].dom.extent
-                )
-                _, inner = schedule[output].split(output.op.axis[3], vectorization_factor)
-                schedule[output].vectorize(inner)
-
-            padded_input = op.input_tensors[0]
-            if isinstance(padded_input.op, te.tensor.ComputeOp):
-                schedule[padded_input].compute_inline()
-
-            # For targets without SVE try explicitly vectorizing the channel
-            # loop, For SVE targets leave the loop in place for LLVM to convert
-            # into a scalable vector loop.
-            vectorization_factor = 8 if enable_explicit_vectorization else channel.dom.extent
-            channel_outer, channel_inner = schedule[op].split(channel, vectorization_factor)
-            schedule[op].vectorize(channel_inner)
-            schedule[op].parallel(height)
-            if len(schedule[op].op.reduce_axis) > 0:
-                filter_height, filter_width = schedule[op].op.reduce_axis
-                # We consider any filter of area < 10 to be small enough to
-                # unroll; 3x3 filters have shown better performance when
-                # unrolled.
-                if filter_height.dom.extent * filter_width.dom.extent <= 9:
-                    # For small filters, unrolling the filter loops allows us to
-                    # vectorize over channels without reordering anything.
-                    schedule[op].unroll(filter_width)
-                    schedule[op].unroll(filter_height)
-                else:
-                    # Reordering so that channels is the fastest moving axis allows
-                    # LLVM to vectorize across contiguous memory in the NHWC
-                    # ordering.
-                    schedule[op].reorder(
-                        n, height, width, filter_height, filter_width, channel_outer, channel_inner
-                    )
-            else:
-                schedule[op].reorder(n, height, width, channel_outer, channel_inner)
-        else:
-            raise RuntimeError("Unsupported operator: %s" % op.tag)
-
-        scheduled_ops.append(op)
-
-    traverse(outs[0].op)
-    return schedule
+    """Create schedule for avgpool/maxpool with dsp"""
+    return pool_dsp_schedule(outs, layout)
1 change: 0 additions & 1 deletion tests/python/topi/python/test_topi_pooling.py
@@ -28,7 +28,6 @@
 
 _pool_schedule = {
     "generic": topi.generic.schedule_pool,
-    "arm_cpu": topi.arm_cpu.schedule_pool,
     "cpu": topi.x86.schedule_pool,
     "gpu": topi.cuda.schedule_pool,
     "hls": topi.hls.schedule_pool,
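For context, the table above is consumed by the test's dispatch helper, which walks a target's keys in order and returns the first matching schedule; with the `arm_cpu` row removed, arm_cpu-keyed targets resolve through their remaining keys (e.g. `cpu`) or ultimately `generic`. A minimal sketch of that lookup, with an illustrative target string:

```python
# Sketch only: resolving a pool schedule from the _pool_schedule table
# defined in the test above. "llvm -device=arm_cpu" carries both the
# "arm_cpu" and "cpu" keys, so removing the "arm_cpu" row changes which
# entry the lookup lands on.
import tvm
import tvm.topi.testing

target_str = "llvm -device=arm_cpu"
with tvm.target.Target(target_str):
    schedule_fn = tvm.topi.testing.dispatch(target_str, _pool_schedule)
```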