
Commit b25bd0b

[Dlight] Enhance Decode-GEMV Rules
This PR enhances the Decode-GEMV rule with the following changes:

- Normalize the GEMV iter domain to S-R-C via `transform_block_layout`. This helps with further analysis and scheduling, for example in cases where the original reduction block had no spatial loop.
- Get rid of the ad hoc iter type analysis, including the logic that called into the TVM packed func `tir.schedule.GetLoopIterType` via `tvm._ffi.get_global_func`.
- Split out the scheduling logic into two separate cases, depending on whether the innermost dimension is spatial or reduction.
- Introduce `suggest_threads_per_block` to guess the number of threads to allocate per threadblock. This avoids the previous behavior where dlight allocated 256 threads for a workload whose degree of parallelism is only 128.
- Misc improvements.

The rest of the changes have been split out into separate PRs that are already merged to main:

- [x] Pass hints to the arithmetic analyzer that shape variables should be positive (apache#15210)
- [x] Eliminate unnecessary block predicate generation - should be provable via affine analysis (apache#15193)
- [x] Shrink local memory allocation when only one element `X[threadIdx.x]` is used (apache#15207)
1 parent 780d6e6 commit b25bd0b
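
As a minimal usage sketch (not part of this commit), the enhanced rule can be driven directly through the `apply` entry point shown in the decode_gemv.py diff below. The toy GEMV kernel, its 128x128 shapes, and the plain `cuda` target are illustrative assumptions only:

from tvm import tir
from tvm.dlight.gpu import DecodeGEMV
from tvm.script import tir as T
from tvm.target import Target


@T.prim_func
def gemv(A: T.Buffer((128, 128), "float32"),
         X: T.Buffer((128,), "float32"),
         Y: T.Buffer((128,), "float32")):
    # Toy GEMV: Y = A @ X, with one spatial iter (vi) and one reduction iter (vk).
    for i, k in T.grid(128, 128):
        with T.block("gemv"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                Y[vi] = T.float32(0)
            Y[vi] = Y[vi] + A[vi, vk] * X[vk]


# `apply` returns None when the rule does not match the workload.
sch = DecodeGEMV().apply(gemv, Target("cuda"), False)
if sch is not None:
    sch.mod.show()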

8 files changed, +275 -142 lines


pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,10 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+[tool.isort]
+profile = "black"
+src_paths = ["python", "tests/python"]
+
 
 [tool.black]
 line-length = 100
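
The new `[tool.isort]` table only takes effect when isort is actually run over the tree; a minimal check sketch (assuming isort is installed in the dev environment):

import subprocess

# --check-only reports files whose imports are not sorted without modifying them;
# isort picks up the [tool.isort] settings from pyproject.toml automatically.
subprocess.run(["isort", "--check-only", "python", "tests/python"], check=True)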

python/tvm/dlight/base/analysis.py

Lines changed: 1 addition & 2 deletions
@@ -17,13 +17,12 @@
 """Analysis on TIR blocks, loops and functions."""
 from typing import List, Optional, Union
 
-from typing_extensions import Literal
-
 from tvm import tir
 from tvm._ffi import get_global_func
 from tvm.target.target import Target
 from tvm.tir import Schedule
 from tvm.tir.schedule import BlockRV
+from typing_extensions import Literal
 
 
 class IterInfo:

python/tvm/dlight/gpu/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 GPU-generic schedule rules.
 For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead
 """
-from .fallback import Fallback
 from .decode_gemv import DecodeGEMV
-from .reduction import Reduction
+from .fallback import Fallback
 from .matmul import Matmul
+from .reduction import Reduction

python/tvm/dlight/gpu/decode_gemv.py

Lines changed: 149 additions & 108 deletions
@@ -14,19 +14,20 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=missing-docstring
-"""A fallback schedule rule for GPU operators."""
-# pylint: disable=invalid-name
+"""A rule for DecodeGEMV."""
+from typing import List, Optional, Set, Tuple, Union
 
-from typing import List, Optional, Union
-
-from tvm import tir
-from tvm._ffi import get_global_func
-from tvm.arith import normalize_to_iter_sum
+from tvm import arith, tir
 from tvm.ir import structural_equal
 from tvm.target import Target
 
-from ..base import ScheduleRule, normalize_prim_func, try_inline_contiguous_spatial
+from ..base import (
+    BlockInfo,
+    ScheduleRule,
+    normalize_prim_func,
+    try_inline_contiguous_spatial,
+)
+from . import utils
 
 
 def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
@@ -47,13 +48,13 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
 
 def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
     dominant_read, read_iters = None, None
-    tir_vars = set()
+    tir_vars: Set[tir.Var] = set()
     for buffer_region in block.reads:
         tir_vars.clear()
 
-        def _collect_tir_var(e):
-            if isinstance(e, tir.Var):
-                tir_vars.add(e)
+        def _collect_tir_var(expr):
+            if isinstance(expr, tir.Var):
+                tir_vars.add(expr)
 
         for expr in buffer_region.region:
             assert expr.extent == 1
@@ -68,27 +69,18 @@ def _collect_tir_var(e):
 
 
 class DecodeGEMV(ScheduleRule):
-    def __init__(self) -> None:
-        super().__init__()
-        self.get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType")
+    """A rule for DecodeGEMV."""
 
-    def apply(  # pylint: disable=too-many-locals
+    def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
         self,
        func: tir.PrimFunc,
        target: Target,
        _: bool,
    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
        if not isinstance(func, tir.PrimFunc):
            return None
-
-        if target.kind.name == "cuda":
-            len_tx, len_ty = 16, 16
-        else:
-            len_tx, len_ty = 8, 8
-
        sch = tir.Schedule(func)
        block_infos = try_inline_contiguous_spatial(sch, normalize_prim_func(sch))
-
        if block_infos is None or len(block_infos) > 2:
            return None
 
@@ -97,96 +89,145 @@ def apply(  # pylint: disable=too-many-locals
         block_stmt = sch.get(block)
 
         # Step 1. Check reduction block
-        if not block_info.is_reduction():
+        if (
+            (not block_info.is_reduction())
+            or len(block_stmt.writes) != 1
+            or _get_reduction_expr(block_stmt) is None
+        ):
             return None
-        if len(block_stmt.writes) != 1:
-            return None
-        if _get_reduction_expr(block_stmt) is None:
-            return None
-
-        # Step 2. Sort out the spatial and reduction loops
-        sorted_iter_access = normalize_to_iter_sum(
-            _detect_dominant_read(block_stmt),
-            input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+        # Step 2. Normalize the block, merge spatial and reduction iters
+        is_inner_reduction, c_factor = self._normalize(
+            sch,
+            block_info,
+            arith.normalize_to_iter_sum(
+                _detect_dominant_read(block_stmt),
+                input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+            ),
         )
-        if sorted_iter_access.base != 0:
-            return None
-        iter_to_info = {i.var: i for i in block_info.iters}
-        s_loops, r_loops, c_loops = [], [], []
-        for split in sorted_iter_access.args:
-            block_var = split.source.source
-            block_var_info = iter_to_info[block_var]
-            loop_rv = block_var_info.loop_rv
-            is_inner_reduction = block_var_info.kind == "R"
-            if split.lower_factor > 1:
-                c_loop_factor = split.lower_factor
-                loop_rv, c_loop = sch.split(loop_rv, factors=[None, c_loop_factor])
-                c_loops.append(c_loop)
-                is_loop_c_reduction = is_inner_reduction
-            if is_inner_reduction:
-                r_loops.append(loop_rv)
-            else:
-                s_loops.append(loop_rv)
-
-        if len(c_loops) > 1:
-            return None
-        if len(s_loops) != len([_ for i in block_info.iters if i.kind == "S"]):
+        if is_inner_reduction is None and c_factor is None:
             return None
-        if len(s_loops) == 0 or len(r_loops) == 0:
-            return None
-
-        sch.reorder(*s_loops, *r_loops, *c_loops)
-        s = sch.fuse(*s_loops)
-        r = sch.fuse(*r_loops)
-
-        if is_inner_reduction:
-            _, tx = sch.split(r, factors=[None, len_tx * len_ty])
-            rf = sch.rfactor(tx, 0)
-            s, r, tx = sch.get_loops(rf)[:3]
-            sch.reorder(s, tx, r)
-            sch.reverse_compute_at(block, s, preserve_unit_loops=True)
-            sch.bind(tx, "threadIdx.x")
-            sch.bind(s, "blockIdx.x")
-        else:
-            sch.split(s, factors=[None, len_tx])
-            _, ty = sch.split(r, factors=[None, len_ty])
-            rf = sch.rfactor(ty, 0)
-            bx, tx, r, ty = sch.get_loops(rf)[:4]
-            sch.reorder(bx, tx, ty, r)
-            sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
-            sch.bind(tx, "threadIdx.x")
-            sch.bind(ty, "threadIdx.y")
-            sch.bind(bx, "blockIdx.x")
-
-        s_loops, r_loops = [], []
-        for loop_rv in sch.get_loops(block)[1:]:
-            iter_type = self.get_loop_iter_type(sch, loop_rv)
-            if iter_type == "S":
-                s_loops.append(loop_rv)
-            elif iter_type == "R":
-                r_loops.append(loop_rv)
-            else:
-                raise RuntimeError("Unknown loop type " + str(iter_type))
-        sch.reorder(*s_loops, *r_loops)
-        s_ctr = sch.fuse(*s_loops)
-        r_ctr = sch.fuse(*r_loops)
-
-        if c_loops and not is_loop_c_reduction:
-            s_ctr, inner = sch.split(s_ctr, factors=[None, c_loop_factor])
-            sch.reorder(s_ctr, r_ctr, inner)
-
+        # Step 3. Do the scheduling
         if is_inner_reduction:
-            sch.bind(r_ctr, "threadIdx.x")
-            sch.set_scope(rf, 0, "local")
-            sch.decompose_reduction(rf, sch.get_loops(rf)[2])
+            self._sch_inner_reduction(sch, target, block, c_factor)
         else:
-            sch.bind(s_ctr, "threadIdx.x")
-            sch.bind(r_ctr, "threadIdx.y")
-            sch.set_scope(rf, 0, "local")
-            sch.decompose_reduction(rf, sch.get_loops(rf)[3])
-
+            self._sch_inner_spatial(sch, target, block, c_factor)
+        # Step 4. Schedule epilogue
         if len(block_infos) == 2:
             sch.set_scope(block, 0, "local")
             sch.reverse_compute_at(block_infos[1].block_rv, sch.get_loops(block)[0])
-
         return sch
+
+    def _normalize(
+        self,
+        sch: tir.Schedule,
+        block_info: BlockInfo,
+        iter_sum: arith.IterSumExpr,
+    ) -> Tuple[Optional[bool], Optional[int]]:
+        if iter_sum.base != 0:
+            return None, None
+        iter_to_info = {i.var: i for i in block_info.iters}
+        s_dom, r_dom, c_dom, c_factor = None, None, None, None
+        for split in iter_sum.args:
+            var = split.source.source
+            info = iter_to_info[var]
+            dom = info.dom
+            is_inner_reduction = info.kind == "R"
+            if split.lower_factor > 1:
+                if c_dom is not None:
+                    return None, None
+                c_dom = tir.floormod(var, split.lower_factor)
+                var = tir.floordiv(var, split.lower_factor)
+                dom = tir.floordiv(dom, split.lower_factor)
+                if not is_inner_reduction:
+                    c_factor = split.lower_factor
+            if is_inner_reduction:
+                if r_dom is None:
+                    r_dom = var
+                else:
+                    r_dom = r_dom * dom + var
+            else:
+                if s_dom is None:
+                    s_dom = var
+                else:
+                    s_dom = s_dom * dom + var
+
+        assert r_dom is not None
+        if s_dom is None:
+            s_dom = tir.const(1, r_dom.dtype)
+        if c_dom is None:
+            c_dom = tir.const(1, r_dom.dtype)
+        sch.transform_block_layout(
+            block_info.block_rv,
+            tir.IndexMap(
+                [i.var for i in block_info.iters],
+                [s_dom, r_dom, c_dom],
+                None,
+            ),
+        )
+        return is_inner_reduction, c_factor
+
+    def _sch_inner_reduction(
+        self,
+        sch: tir.Schedule,
+        target: Target,
+        block: tir.schedule.BlockRV,
+        unroll_spatial_factor: Optional[int],
+    ):
+        # pylint: disable=invalid-name
+        _, r, _ = sch.get_loops(block)
+        (len_tx,) = utils.suggest_threads_per_block(  # pylint: disable=unbalanced-tuple-unpacking
+            target, [sch.get(r)]
+        )
+
+        _, tx = sch.split(r, factors=[None, len_tx])
+        # Schedule the RF block
+        rf = sch.rfactor(tx, 0)
+        bx, r, tx, _ = sch.get_loops(rf)
+        sch.reorder(bx, tx, r)
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(tx, "threadIdx.x")
+        sch.set_scope(rf, 0, "local")
+        sch.decompose_reduction(rf, r)
+        # Schedule the write back block
+        sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
+        _, tx, *s = sch.get_loops(block)
+        s = sch.fuse(*s)
+        sch.reorder(s, tx)
+        if unroll_spatial_factor:
+            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
+            sch.reorder(s, tx, inner)
+        sch.bind(tx, "threadIdx.x")
+        # pylint: enable=invalid-name
+
+    def _sch_inner_spatial(
+        self,
+        sch: tir.Schedule,
+        _: Target,
+        block: tir.schedule.BlockRV,
+        unroll_spatial_factor: Optional[int],
+    ):
+        # pylint: disable=invalid-name
+        s, r, _ = sch.get_loops(block)
+        len_tx, len_ty = 16, 16
+        _, _ = sch.split(s, factors=[None, len_tx])
+        _, ty = sch.split(r, factors=[None, len_ty])
+        # Schedule the RF block
+        rf = sch.rfactor(ty, 0)
+        bx, tx, r, ty, _ = sch.get_loops(rf)
+        sch.reorder(bx, tx, ty, r)
+        sch.bind(tx, "threadIdx.x")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(bx, "blockIdx.x")
+        sch.set_scope(rf, 0, "local")
+        sch.decompose_reduction(rf, r)
+        # Schedule the write back block
+        sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
+        _, r, *s = sch.get_loops(block)
+        s = sch.fuse(*s)
+        sch.reorder(s, r)
+        if unroll_spatial_factor:
+            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
+            sch.reorder(s, r, inner)
+        sch.bind(s, "threadIdx.x")
+        sch.bind(r, "threadIdx.y")
+        # pylint: enable=invalid-name
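
To make `_normalize` concrete: below is a standalone sketch (the toy shape and the hand-written index map are assumptions) of the S-R-C normalization on a reduction block that has no spatial loop at all, the case called out in the commit message. The real rule derives its `tir.IndexMap` from `normalize_to_iter_sum` instead of writing the mapping by hand:

from tvm import tir
from tvm.script import tir as T


@T.prim_func
def row_sum(A: T.Buffer((4096,), "float32"), B: T.Buffer((1,), "float32")):
    # A pure reduction block: no spatial iter at all.
    for k in T.serial(4096):
        with T.block("sum"):
            vk = T.axis.reduce(4096, k)
            with T.init():
                B[0] = T.float32(0)
            B[0] = B[0] + A[vk]


sch = tir.Schedule(row_sum)
block = sch.get_block("sum")
# Remap the single reduction iter to an explicit (S, R, C) triple; the spatial
# and C axes become unit-extent placeholders, so the later scheduling steps can
# always assume the S-R-C structure exists.
sch.transform_block_layout(block, lambda vk: (0, vk, 0))
sch.mod.show()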

python/tvm/dlight/gpu/fallback.py

Lines changed: 3 additions & 2 deletions
@@ -21,7 +21,8 @@
 from tvm import tir
 from tvm.target import Target
 
-from ..base import ScheduleRule, analysis, normalize_prim_func, try_inline
+from ..base import ScheduleRule, normalize_prim_func, try_inline
+from . import utils
 
 
 class Fallback(ScheduleRule):
@@ -36,7 +37,7 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         target: Target,
         _: bool,
     ) -> tir.Schedule:
-        max_threads_per_block = analysis.get_max_threads_per_block(target)
+        max_threads_per_block = utils.max_threads_per_block(target)
 
         sch = tir.Schedule(func)
         block_infos = try_inline(sch, normalize_prim_func(sch))
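
`python/tvm/dlight/gpu/utils.py` is among the 8 changed files but is not shown in this excerpt. The sketch below is therefore only a guess at what the two helpers used above could look like, inferred from their call sites (`utils.max_threads_per_block(target)` and `utils.suggest_threads_per_block(target, [loop])`); it is not the actual implementation:

from typing import List

from tvm import tir
from tvm.target import Target


def max_threads_per_block(target: Target) -> int:
    # Assumption: the per-block thread limit comes straight from the target attributes.
    return int(target.max_num_threads)


def suggest_threads_per_block(target: Target, loops: List[tir.For]) -> List[int]:
    # Assumption: never suggest more threads than a loop's extent can fill, so a
    # workload with only 128 parallel iterations is not handed 256 threads.
    budget = max_threads_per_block(target)
    suggestion = []
    for loop in loops:
        extent = loop.extent
        threads = min(budget, int(extent)) if isinstance(extent, tir.IntImm) else budget
        suggestion.append(threads)
        budget = max(budget // threads, 1)
    return suggestion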

python/tvm/dlight/gpu/matmul.py

Lines changed: 2 additions & 2 deletions
@@ -16,14 +16,14 @@
 # under the License.
 # pylint: disable=missing-docstring, invalid-name
 """A GEMM schedule rule for GPU operators."""
-from enum import Enum
 from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, List, Optional, Set, Tuple
 
 from tvm import tir
 from tvm.ir import Range
 from tvm.target import Target
-from tvm.tir import PrimExpr, Var, IterVar
+from tvm.tir import IterVar, PrimExpr, Var
 from tvm.tir.analysis import undefined_vars
 from tvm.tir.schedule.schedule import BlockRV
 
