17 | 17 | # pylint: disable=invalid-name, unused-variable |
18 | 18 | """Schedule for pooling operators""" |
19 | 19 |
20 | | -import logging |
21 | | -from tvm import te |
22 | | -from tvm.target import Target |
23 | | - |
24 | | -from .. import tag |
25 | | -from .. import generic |
| 20 | +from .mprofile.dsp.pool import pool_dsp_schedule |
26 | 21 |
27 | 22 |
28 | 23 | def schedule_pool(outs, layout): |
29 | | - """Create schedule for avgpool/maxpool""" |
30 | | - |
31 | | - if layout != "NHWC": |
32 | | - logger = logging.getLogger("topi") |
33 | | - logger.warning( |
34 | | - """We currently only support NHWC target specific pools on arm_cpu, |
35 | | - falling back on generic pool scheduling""" |
36 | | - ) |
37 | | - return generic.schedule_pool(outs, layout) |
38 | | - |
39 | | - return schedule_pool_2d(outs) |
40 | | - |
41 | | - |
42 | | -def schedule_pool_2d(outs): |
43 | | - """Create arm_cpu specific 2D schedule for avgpool/maxpool""" |
44 | | - |
45 | | - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs |
46 | | - schedule_ops = [x.op for x in outs] |
47 | | - schedule = te.create_schedule(schedule_ops) |
48 | | - scheduled_ops = [] |
49 | | - |
50 | | - def traverse(op): |
51 | | - # Recursively inline any injective operation that isn't the pooling |
52 | | - # operation or hasn't already been scheduled. |
53 | | - if tag.is_injective(op.tag): |
54 | | - if op not in schedule.outputs: |
55 | | - schedule[op].compute_inline() |
56 | | - for tensor in op.input_tensors: |
57 | | - if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops: |
58 | | - traverse(tensor.op) |
59 | | - # schedule the actual pooling operation |
60 | | - elif op.tag.startswith("pool"): |
61 | | - n, height, width, channel = schedule[op].op.axis |
62 | | -            # Average pool consists of two parts: a sum, then a division.
63 | | - # We can schedule the division loop to parallelize across height and |
64 | | - # vectorize across width. |
65 | | - enable_explicit_vectorization = not Target.current(allow_none=False).features.has_sve |
66 | | - if op != outs[0].op: |
67 | | - output = outs[0] |
68 | | - output_fused = schedule[output].fuse(output.op.axis[1], output.op.axis[2]) |
69 | | - schedule[output].parallel(output_fused) |
70 | | - vectorization_factor = ( |
71 | | - 8 if enable_explicit_vectorization else output.op.axis[3].dom.extent |
72 | | - ) |
73 | | - _, inner = schedule[output].split(output.op.axis[3], vectorization_factor) |
74 | | - schedule[output].vectorize(inner) |
75 | | - |
76 | | - padded_input = op.input_tensors[0] |
77 | | - if isinstance(padded_input.op, te.tensor.ComputeOp): |
78 | | - schedule[padded_input].compute_inline() |
79 | | - |
80 | | -            # For targets without SVE, try explicitly vectorizing the channel
81 | | -            # loop. For SVE targets, leave the loop in place for LLVM to convert
82 | | - # into a scalable vector loop. |
83 | | - vectorization_factor = 8 if enable_explicit_vectorization else channel.dom.extent |
84 | | - channel_outer, channel_inner = schedule[op].split(channel, vectorization_factor) |
85 | | - schedule[op].vectorize(channel_inner) |
86 | | - schedule[op].parallel(height) |
87 | | - if len(schedule[op].op.reduce_axis) > 0: |
88 | | - filter_height, filter_width = schedule[op].op.reduce_axis |
89 | | - # We consider any filter of area < 10 to be small enough to |
90 | | - # unroll; 3x3 filters have shown better performance when |
91 | | - # unrolled. |
92 | | - if filter_height.dom.extent * filter_width.dom.extent <= 9: |
93 | | - # For small filters, unrolling the filter loops allows us to |
94 | | - # vectorize over channels without reordering anything. |
95 | | - schedule[op].unroll(filter_width) |
96 | | - schedule[op].unroll(filter_height) |
97 | | - else: |
98 | | - # Reordering so that channels is the fastest moving axis allows |
99 | | - # LLVM to vectorize across contiguous memory in the NHWC |
100 | | - # ordering. |
101 | | - schedule[op].reorder( |
102 | | - n, height, width, filter_height, filter_width, channel_outer, channel_inner |
103 | | - ) |
104 | | - else: |
105 | | - schedule[op].reorder(n, height, width, channel_outer, channel_inner) |
106 | | - else: |
107 | | - raise RuntimeError("Unsupported operator: %s" % op.tag) |
108 | | - |
109 | | - scheduled_ops.append(op) |
110 | | - |
111 | | - traverse(outs[0].op) |
112 | | - return schedule |
| 24 | +    """Create schedule for avgpool/maxpool with DSP"""
| 25 | + return pool_dsp_schedule(outs, layout) |
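
For context, a minimal usage sketch of the rewritten entry point (not part of the diff): it assumes this file lives at `python/tvm/topi/arm_cpu/pooling.py`, that the pooling compute comes from `topi.nn.pool2d`, and that the shapes and dtype below are illustrative. A real build would reach `schedule_pool` through the arm_cpu op strategy under a Cortex-M target rather than calling it by hand.

```python
# Illustrative only: build a small NHWC max-pool compute and hand it to the
# rewritten schedule_pool, which now forwards everything to pool_dsp_schedule.
from tvm import te, topi
from tvm.topi.arm_cpu.pooling import schedule_pool  # assumed module path

data = te.placeholder((1, 16, 16, 8), name="data", dtype="int8")
pool = topi.nn.pool2d(
    data,
    kernel=(2, 2),
    stride=(2, 2),
    dilation=(1, 1),
    padding=(0, 0, 0, 0),
    pool_type="max",
    layout="NHWC",
)

# The layout argument is passed straight through to pool_dsp_schedule.
sched = schedule_pool([pool], "NHWC")
assert isinstance(sched, te.Schedule)
```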
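
For reviewers comparing the two approaches, here is a standalone sketch (not part of the diff; the split factor of 8 and the area-under-10 heuristic come from the removed code, everything else is made up) of the explicit-vectorization pattern this change deletes: split the NHWC channel axis, vectorize the inner part, parallelize over height, and unroll the 3x3 reduction loops because their area falls under the removed heuristic's threshold.

```python
# Illustrative only: the scheduling pattern removed by this change, applied to
# a hand-written NHWC max pool instead of the topi pooling operator.
import tvm
from tvm import te

batch, height, width, channels = 1, 32, 32, 16
kernel_h, kernel_w = 3, 3

data = te.placeholder(
    (batch, height + kernel_h - 1, width + kernel_w - 1, channels), name="data"
)
rh = te.reduce_axis((0, kernel_h), name="rh")
rw = te.reduce_axis((0, kernel_w), name="rw")
pool = te.compute(
    (batch, height, width, channels),
    lambda n, h, w, c: te.max(data[n, h + rh, w + rw, c], axis=[rh, rw]),
    name="max_pool",
)

sched = te.create_schedule(pool.op)
_, h, _, c = sched[pool].op.axis
filter_h, filter_w = sched[pool].op.reduce_axis

# Split channels by the fixed factor of 8 used by the removed schedule,
# vectorize the inner part, and parallelize across height.
_, c_inner = sched[pool].split(c, 8)
sched[pool].vectorize(c_inner)
sched[pool].parallel(h)

# A 3x3 filter has area 9 (< 10), so the removed heuristic unrolls the
# reduction loops rather than reordering channels to be innermost.
sched[pool].unroll(filter_w)
sched[pool].unroll(filter_h)

print(tvm.lower(sched, [data, pool], simple_mode=True))
```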