1414# KIND, either express or implied. See the License for the
1515# specific language governing permissions and limitations
1616# under the License.
17- # pylint: disable=missing-docstring
18- """A fallback schedule rule for GPU operators."""
19- # pylint: disable=invalid-name
17+ """A rule for DecodeGEMV."""
18+ from typing import List , Optional , Set , Tuple , Union
2019
21- from typing import List , Optional , Union
22-
23- from tvm import tir
24- from tvm ._ffi import get_global_func
25- from tvm .arith import normalize_to_iter_sum
20+ from tvm import arith , tir
2621from tvm .ir import structural_equal
2722from tvm .target import Target
2823
29- from ..base import ScheduleRule , normalize_prim_func , try_inline_contiguous_spatial
24+ from ..base import (
25+ BlockInfo ,
26+ ScheduleRule ,
27+ normalize_prim_func ,
28+ try_inline_contiguous_spatial ,
29+ )
30+ from . import utils
3031
3132
3233def _get_reduction_expr (block : tir .Block ) -> Optional [tir .PrimExpr ]:
@@ -47,13 +48,13 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
4748
4849def _detect_dominant_read (block : tir .Block ) -> tir .PrimExpr :
4950 dominant_read , read_iters = None , None
50- tir_vars = set ()
51+ tir_vars : Set [ tir . Var ] = set ()
5152 for buffer_region in block .reads :
5253 tir_vars .clear ()
5354
54- def _collect_tir_var (e ):
55- if isinstance (e , tir .Var ):
56- tir_vars .add (e )
55+ def _collect_tir_var (expr ):
56+ if isinstance (expr , tir .Var ):
57+ tir_vars .add (expr )
5758
5859 for expr in buffer_region .region :
5960 assert expr .extent == 1
@@ -68,27 +69,18 @@ def _collect_tir_var(e):
6869
6970
7071class DecodeGEMV (ScheduleRule ):
71- def __init__ (self ) -> None :
72- super ().__init__ ()
73- self .get_loop_iter_type = get_global_func ("tir.schedule.GetLoopIterType" )
72+ """A rule for DecodeGEMV."""
7473
75- def apply ( # pylint: disable=too-many-locals
74+ def apply ( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
7675 self ,
7776 func : tir .PrimFunc ,
7877 target : Target ,
7978 _ : bool ,
8079 ) -> Union [None , tir .Schedule , List [tir .Schedule ]]:
8180 if not isinstance (func , tir .PrimFunc ):
8281 return None
83-
84- if target .kind .name == "cuda" :
85- len_tx , len_ty = 16 , 16
86- else :
87- len_tx , len_ty = 8 , 8
88-
8982 sch = tir .Schedule (func )
9083 block_infos = try_inline_contiguous_spatial (sch , normalize_prim_func (sch ))
91-
9284 if block_infos is None or len (block_infos ) > 2 :
9385 return None
9486
@@ -97,96 +89,145 @@ def apply( # pylint: disable=too-many-locals
9789 block_stmt = sch .get (block )
9890
9991 # Step 1. Check reduction block
100- if not block_info .is_reduction ():
92+ if (
93+ (not block_info .is_reduction ())
94+ or len (block_stmt .writes ) != 1
95+ or _get_reduction_expr (block_stmt ) is None
96+ ):
10197 return None
102- if len (block_stmt .writes ) != 1 :
103- return None
104- if _get_reduction_expr (block_stmt ) is None :
105- return None
106-
107- # Step 2. Sort out the spatial and reduction loops
108- sorted_iter_access = normalize_to_iter_sum (
109- _detect_dominant_read (block_stmt ),
110- input_iters = {i .var : i .dom for i in block_stmt .iter_vars },
98+ # Step 2. Normalize the block, merge spatial and reduction iters
99+ is_inner_reduction , c_factor = self ._normalize (
100+ sch ,
101+ block_info ,
102+ arith .normalize_to_iter_sum (
103+ _detect_dominant_read (block_stmt ),
104+ input_iters = {i .var : i .dom for i in block_stmt .iter_vars },
105+ ),
111106 )
112- if sorted_iter_access .base != 0 :
113- return None
114- iter_to_info = {i .var : i for i in block_info .iters }
115- s_loops , r_loops , c_loops = [], [], []
116- for split in sorted_iter_access .args :
117- block_var = split .source .source
118- block_var_info = iter_to_info [block_var ]
119- loop_rv = block_var_info .loop_rv
120- is_inner_reduction = block_var_info .kind == "R"
121- if split .lower_factor > 1 :
122- c_loop_factor = split .lower_factor
123- loop_rv , c_loop = sch .split (loop_rv , factors = [None , c_loop_factor ])
124- c_loops .append (c_loop )
125- is_loop_c_reduction = is_inner_reduction
126- if is_inner_reduction :
127- r_loops .append (loop_rv )
128- else :
129- s_loops .append (loop_rv )
130-
131- if len (c_loops ) > 1 :
132- return None
133- if len (s_loops ) != len ([_ for i in block_info .iters if i .kind == "S" ]):
107+ if is_inner_reduction is None and c_factor is None :
134108 return None
135- if len (s_loops ) == 0 or len (r_loops ) == 0 :
136- return None
137-
138- sch .reorder (* s_loops , * r_loops , * c_loops )
139- s = sch .fuse (* s_loops )
140- r = sch .fuse (* r_loops )
141-
142- if is_inner_reduction :
143- _ , tx = sch .split (r , factors = [None , len_tx * len_ty ])
144- rf = sch .rfactor (tx , 0 )
145- s , r , tx = sch .get_loops (rf )[:3 ]
146- sch .reorder (s , tx , r )
147- sch .reverse_compute_at (block , s , preserve_unit_loops = True )
148- sch .bind (tx , "threadIdx.x" )
149- sch .bind (s , "blockIdx.x" )
150- else :
151- sch .split (s , factors = [None , len_tx ])
152- _ , ty = sch .split (r , factors = [None , len_ty ])
153- rf = sch .rfactor (ty , 0 )
154- bx , tx , r , ty = sch .get_loops (rf )[:4 ]
155- sch .reorder (bx , tx , ty , r )
156- sch .reverse_compute_at (block , bx , preserve_unit_loops = True )
157- sch .bind (tx , "threadIdx.x" )
158- sch .bind (ty , "threadIdx.y" )
159- sch .bind (bx , "blockIdx.x" )
160-
161- s_loops , r_loops = [], []
162- for loop_rv in sch .get_loops (block )[1 :]:
163- iter_type = self .get_loop_iter_type (sch , loop_rv )
164- if iter_type == "S" :
165- s_loops .append (loop_rv )
166- elif iter_type == "R" :
167- r_loops .append (loop_rv )
168- else :
169- raise RuntimeError ("Unknown loop type " + str (iter_type ))
170- sch .reorder (* s_loops , * r_loops )
171- s_ctr = sch .fuse (* s_loops )
172- r_ctr = sch .fuse (* r_loops )
173-
174- if c_loops and not is_loop_c_reduction :
175- s_ctr , inner = sch .split (s_ctr , factors = [None , c_loop_factor ])
176- sch .reorder (s_ctr , r_ctr , inner )
177-
109+ # Step 3. Do the scheduling
178110 if is_inner_reduction :
179- sch .bind (r_ctr , "threadIdx.x" )
180- sch .set_scope (rf , 0 , "local" )
181- sch .decompose_reduction (rf , sch .get_loops (rf )[2 ])
111+ self ._sch_inner_reduction (sch , target , block , c_factor )
182112 else :
183- sch .bind (s_ctr , "threadIdx.x" )
184- sch .bind (r_ctr , "threadIdx.y" )
185- sch .set_scope (rf , 0 , "local" )
186- sch .decompose_reduction (rf , sch .get_loops (rf )[3 ])
187-
113+ self ._sch_inner_spatial (sch , target , block , c_factor )
114+ # Step 4. Schedule epilogue
188115 if len (block_infos ) == 2 :
189116 sch .set_scope (block , 0 , "local" )
190117 sch .reverse_compute_at (block_infos [1 ].block_rv , sch .get_loops (block )[0 ])
191-
192118 return sch
119+
    def _normalize(
        self,
        sch: tir.Schedule,
        block_info: BlockInfo,
        iter_sum: arith.IterSumExpr,
    ) -> Tuple[Optional[bool], Optional[int]]:
        """Merge the block's iters into a canonical (spatial, reduction, inner)
        triple via ``transform_block_layout``.

        Parameters
        ----------
        sch : tir.Schedule
            The schedule containing the block.
        block_info : BlockInfo
            Info of the reduction block whose layout is being normalized.
        iter_sum : arith.IterSumExpr
            The normalized iter-sum of the block's dominant read access.

        Returns
        -------
        Tuple[Optional[bool], Optional[int]]
            ``(is_inner_reduction, c_factor)``. ``is_inner_reduction`` is the
            ``kind == "R"`` flag of the *last* split processed in
            ``iter_sum.args`` (i.e. the innermost access dimension);
            ``c_factor`` is the ``lower_factor`` of a spatial split, if one
            exists. ``(None, None)`` means the pattern is unsupported.
        """
        # A nonzero base means the dominant access has an offset we do not
        # handle — bail out.
        if iter_sum.base != 0:
            return None, None
        iter_to_info = {i.var: i for i in block_info.iters}
        # Accumulated index expressions for the fused spatial ("s"),
        # reduction ("r") and inner/unroll ("c") dimensions of the new layout.
        s_dom, r_dom, c_dom, c_factor = None, None, None, None
        for split in iter_sum.args:
            var = split.source.source
            info = iter_to_info[var]
            dom = info.dom
            is_inner_reduction = info.kind == "R"
            if split.lower_factor > 1:
                # The access strides by ``lower_factor``: peel the low part of
                # this iter off into the "c" dimension. Only one such split is
                # supported.
                if c_dom is not None:
                    return None, None
                c_dom = tir.floormod(var, split.lower_factor)
                var = tir.floordiv(var, split.lower_factor)
                dom = tir.floordiv(dom, split.lower_factor)
                if not is_inner_reduction:
                    # Remember the factor only for spatial splits; it is later
                    # used as the unroll factor of the write-back block.
                    c_factor = split.lower_factor
            if is_inner_reduction:
                # Fuse reduction iters in access order: r = r * dom + var.
                if r_dom is None:
                    r_dom = var
                else:
                    r_dom = r_dom * dom + var
            else:
                # Fuse spatial iters the same way.
                if s_dom is None:
                    s_dom = var
                else:
                    s_dom = s_dom * dom + var

        # NOTE(review): assumes at least one reduction iter appears in
        # ``iter_sum.args`` — with an empty args list this assert (and the
        # ``is_inner_reduction`` reference below) would fail. Confirm callers
        # guarantee a reduction block.
        assert r_dom is not None
        if s_dom is None:
            s_dom = tir.const(1, r_dom.dtype)
        if c_dom is None:
            c_dom = tir.const(1, r_dom.dtype)
        # Rewrite the block so its iteration space is exactly [s, r, c].
        sch.transform_block_layout(
            block_info.block_rv,
            tir.IndexMap(
                [i.var for i in block_info.iters],
                [s_dom, r_dom, c_dom],
                None,
            ),
        )
        return is_inner_reduction, c_factor
168+
    def _sch_inner_reduction(
        self,
        sch: tir.Schedule,
        target: Target,
        block: tir.schedule.BlockRV,
        unroll_spatial_factor: Optional[int],
    ):
        """Schedule the normalized block when the reduction is the innermost
        access dimension: cross-thread reduce over ``threadIdx.x`` via rfactor.

        Parameters
        ----------
        sch : tir.Schedule
            The schedule to mutate in place.
        target : Target
            Compilation target, used to pick the thread-block size.
        block : tir.schedule.BlockRV
            The normalized reduction block (loops are [s, r, c]).
        unroll_spatial_factor : Optional[int]
            If set, the write-back block's spatial loop is split by this
            factor (the ``c_factor`` found during normalization).
        """
        # pylint: disable=invalid-name
        _, r, _ = sch.get_loops(block)
        # Ask the utility for a thread count sized to the reduction extent.
        (len_tx,) = utils.suggest_threads_per_block(  # pylint: disable=unbalanced-tuple-unpacking
            target, [sch.get(r)]
        )

        _, tx = sch.split(r, factors=[None, len_tx])
        # Schedule the RF block: each thread accumulates a partial result in
        # `local` scope, then the original block combines them.
        rf = sch.rfactor(tx, 0)
        bx, r, tx, _ = sch.get_loops(rf)
        sch.reorder(bx, tx, r)
        sch.bind(bx, "blockIdx.x")
        sch.bind(tx, "threadIdx.x")
        sch.set_scope(rf, 0, "local")
        sch.decompose_reduction(rf, r)
        # Schedule the write back block
        sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
        _, tx, *s = sch.get_loops(block)
        s = sch.fuse(*s)
        sch.reorder(s, tx)
        if unroll_spatial_factor:
            # Keep a small inner spatial loop (length = unroll_spatial_factor)
            # below the thread binding.
            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
            sch.reorder(s, tx, inner)
        sch.bind(tx, "threadIdx.x")
        # pylint: enable=invalid-name
201+
    def _sch_inner_spatial(
        self,
        sch: tir.Schedule,
        _: Target,
        block: tir.schedule.BlockRV,
        unroll_spatial_factor: Optional[int],
    ):
        """Schedule the normalized block when a spatial iter is innermost:
        spatial work goes to ``threadIdx.x`` and the reduction is performed
        across ``threadIdx.y`` via rfactor.

        Parameters
        ----------
        sch : tir.Schedule
            The schedule to mutate in place.
        _ : Target
            Unused; kept for signature parity with ``_sch_inner_reduction``.
        block : tir.schedule.BlockRV
            The normalized reduction block (loops are [s, r, c]).
        unroll_spatial_factor : Optional[int]
            If set, the write-back block's spatial loop is split by this
            factor (the ``c_factor`` found during normalization).
        """
        # pylint: disable=invalid-name
        s, r, _ = sch.get_loops(block)
        # Fixed 16x16 thread tile. NOTE(review): unlike the inner-reduction
        # path, this does not consult the target — confirm 256 threads/block
        # is valid for all intended targets.
        len_tx, len_ty = 16, 16
        _, _ = sch.split(s, factors=[None, len_tx])
        _, ty = sch.split(r, factors=[None, len_ty])
        # Schedule the RF block: partial sums per (tx, ty) thread in `local`.
        rf = sch.rfactor(ty, 0)
        bx, tx, r, ty, _ = sch.get_loops(rf)
        sch.reorder(bx, tx, ty, r)
        sch.bind(tx, "threadIdx.x")
        sch.bind(ty, "threadIdx.y")
        sch.bind(bx, "blockIdx.x")
        sch.set_scope(rf, 0, "local")
        sch.decompose_reduction(rf, r)
        # Schedule the write back block
        sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
        _, r, * s = sch.get_loops(block)
        s = sch.fuse(*s)
        sch.reorder(s, r)
        if unroll_spatial_factor:
            # Keep a small inner spatial loop below the thread bindings.
            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
            sch.reorder(s, r, inner)
        sch.bind(s, "threadIdx.x")
        sch.bind(r, "threadIdx.y")
        # pylint: enable=invalid-name
0 commit comments