|
17 | 17 | # pylint: disable=invalid-name, unused-variable |
18 | 18 | """Schedule for pooling operators""" |
19 | 19 |
|
20 | | -from .mprofile.dsp.pool import pool_dsp_schedule |
| 20 | +import logging |
| 21 | +from tvm import te |
| 22 | +from tvm.target import Target |
| 23 | + |
| 24 | +from .. import tag |
| 25 | +from .. import generic |
21 | 26 |
|
22 | 27 |
|
def schedule_pool(outs, layout):
    """Create schedule for avgpool/maxpool"""

    # Only NHWC has an arm_cpu specific schedule; any other layout is
    # delegated to the generic TOPI pool schedule with a warning.
    if layout == "NHWC":
        return schedule_pool_2d(outs)

    logging.getLogger("topi").warning(
        """We currently only support NHWC target specific pools on arm_cpu,
            falling back on generic pool scheduling"""
    )
    return generic.schedule_pool(outs, layout)
| 40 | + |
| 41 | + |
def schedule_pool_2d(outs):
    """Create arm_cpu specific 2D schedule for avgpool/maxpool.

    Parameters
    ----------
    outs : te.tensor.Tensor or list of te.tensor.Tensor
        Output tensor(s) of the pooling compute.  The axis unpacking below
        assumes a 4D NHWC output (n, height, width, channel).

    Returns
    -------
    schedule : Schedule
        The created TE schedule for the given outputs.
    """

    # Normalize a single tensor into a list so the rest of the code can
    # treat both call styles uniformly.
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    schedule_ops = [x.op for x in outs]
    schedule = te.create_schedule(schedule_ops)
    scheduled_ops = []

    def traverse(op):
        # Recursively inline any injective operation that isn't the pooling
        # operation or hasn't already been scheduled.
        if tag.is_injective(op.tag):
            if op not in schedule.outputs:
                schedule[op].compute_inline()
            for tensor in op.input_tensors:
                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        # schedule the actual pooling operation
        elif op.tag.startswith("pool"):
            # NOTE(review): assumes NHWC axis order — confirm against the
            # dispatching caller, which only routes NHWC layouts here.
            n, height, width, channel = schedule[op].op.axis
            # Average pool consists of two parts; a sum then a division.
            # We can schedule the division loop to parallelize across height and
            # vectorize across width.
            # On SVE targets we avoid a fixed vectorization factor so LLVM can
            # emit a scalable vector loop instead.
            enable_explicit_vectorization = not Target.current(allow_none=False).features.has_sve
            if op != outs[0].op:
                # The pool op is not the final output (e.g. average pool's sum
                # stage): schedule the real output stage here as well.
                output = outs[0]
                output_fused = schedule[output].fuse(output.op.axis[1], output.op.axis[2])
                schedule[output].parallel(output_fused)
                vectorization_factor = (
                    8 if enable_explicit_vectorization else output.op.axis[3].dom.extent
                )
                _, inner = schedule[output].split(output.op.axis[3], vectorization_factor)
                schedule[output].vectorize(inner)

            # Inline the padding stage (if any) into the pooling computation.
            padded_input = op.input_tensors[0]
            if isinstance(padded_input.op, te.tensor.ComputeOp):
                schedule[padded_input].compute_inline()

            # For targets without SVE try explicitly vectorizing the channel
            # loop, For SVE targets leave the loop in place for LLVM to convert
            # into a scalable vector loop.
            vectorization_factor = 8 if enable_explicit_vectorization else channel.dom.extent
            channel_outer, channel_inner = schedule[op].split(channel, vectorization_factor)
            schedule[op].vectorize(channel_inner)
            schedule[op].parallel(height)
            # Max pool has reduce axes (the filter window); a pure elementwise
            # stage would not.
            if len(schedule[op].op.reduce_axis) > 0:
                filter_height, filter_width = schedule[op].op.reduce_axis
                # We consider any filter of area < 10 to be small enough to
                # unroll; 3x3 filters have shown better performance when
                # unrolled.
                if filter_height.dom.extent * filter_width.dom.extent <= 9:
                    # For small filters, unrolling the filter loops allows us to
                    # vectorize over channels without reordering anything.
                    schedule[op].unroll(filter_width)
                    schedule[op].unroll(filter_height)
                else:
                    # Reordering so that channels is the fastest moving axis allows
                    # LLVM to vectorize across contiguous memory in the NHWC
                    # ordering.
                    schedule[op].reorder(
                        n, height, width, filter_height, filter_width, channel_outer, channel_inner
                    )
            else:
                schedule[op].reorder(n, height, width, channel_outer, channel_inner)
        else:
            raise RuntimeError("Unsupported operator: %s" % op.tag)

        scheduled_ops.append(op)

    traverse(outs[0].op)
    return schedule
0 commit comments