Skip to content

Commit 06fb02e

Browse files
authored
[LLVM][METASCHEDULE] Add RISCV V-extension v1.0 kernels to metaschedule (#18243)
- Enables high performance kernels covering majority of usual ML datatype inputs - It is currently compliant with RVV specs version v1.0 (does not work with older v0.7.1) - TIR kernels implemented here are using recently added VLA extension support
1 parent a819115 commit 06fb02e

File tree

8 files changed

+325
-1
lines changed

8 files changed

+325
-1
lines changed

include/tvm/meta_schedule/postproc.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ class Postproc : public runtime::ObjectRef {
166166
TVM_DLL static Array<Postproc, void> DefaultLLVM();
167167
/*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */
168168
TVM_DLL static Array<Postproc, void> DefaultCPUTensorization();
169+
/*! \brief Create default postprocessors for RISCV */
170+
TVM_DLL static Array<Postproc, void> DefaultRISCV();
169171
/*! \brief Create default postprocessors for CUDA */
170172
TVM_DLL static Array<Postproc, void> DefaultCUDA();
171173
/*! \brief Create default postprocessors for CUDA with TensorCore */

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ class ScheduleRule : public runtime::ObjectRef {
301301
TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
302302
/*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */
303303
TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);
304+
/*! \brief Create default schedule rules for RISCV CPU (RVV) */
305+
TVM_DLL static Array<ScheduleRule, void> DefaultRISCV(int vlen);
304306

305307
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
306308
};

python/tvm/target/target.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,14 @@ def riscv_cpu(model="sifive-u54", options=None):
637637
"-mabi=lp64d",
638638
# cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=sifive-u74
639639
],
640+
"licheepi3a": [
641+
"-num-cores=8",
642+
"-mtriple=riscv64-unknown-linux-gnu",
643+
"-mcpu=spacemit-x60",
644+
"-mfloat-abi=hard",
645+
"-mabi=lp64d",
646+
# cc: riscv64-unknown-linux-gnu-g++ -march=rv64gcv -mabi=lp64d -mcpu=spacemit-x60
647+
],
640648
}
641649
pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
642650

python/tvm/tir/tensor_intrin/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020
from . import cuda
2121

2222
if enabled("llvm"):
23-
from . import arm_cpu, x86, rocm, hexagon
23+
from . import arm_cpu, x86, rocm, hexagon, riscv_cpu
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name,line-too-long
18+
"""Intrinsics for RISCV tensorization"""
19+
20+
import logging
21+
from tvm.ffi import register_func
22+
from tvm.runtime import DataType
23+
from tvm.script import tir as T
24+
from tvm.target.codegen import llvm_get_vector_width, target_has_features, Target
25+
from .. import TensorIntrin
26+
27+
logger = logging.getLogger(__name__)
28+
29+
30+
def get_max_elems(vlen: int, lmul: int, sew: int) -> int:
31+
"""Returns number of elements of a given data type (SEW)
32+
that fits multiple (LMUL) of the vector registers (VLEN).
33+
34+
Args:
35+
vlen (int): VLEN vector length in bits
36+
lmul (int): LMUL vector lenght multiplier
37+
sew (int): SEW standard (single) element width
38+
39+
Returns:
40+
int: Number of elements
41+
"""
42+
return (vlen // sew) * lmul
43+
44+
45+
def rvv_vec_dot_product_kernels(
46+
n_elems: int,
47+
n_lanes: int,
48+
data_dtype: str,
49+
weight_dtype: str,
50+
out_dtype: str,
51+
lmul: int,
52+
):
53+
"""Dot product of vector and matrix rows using RISC-V vector instructions.
54+
55+
These kernels takes two arrays A[ELEMS] and B[ELEMS][MACS] and computes
56+
dot product of A[ELEMS] with each row of B[LANES], accumulating results
57+
with C[LANES].
58+
59+
The pseudo code is as follows:
60+
.. code-block:: c
61+
void vec_dot_prod(A[ELEMS], B[LANES][ELEMS], C[LANES]){
62+
for (j = 0; j < LANES; j++) {
63+
for (k = 0; k < ELEMS; k++) {
64+
C[j] += A[k] * B[j][k]
65+
}
66+
}
67+
}
68+
"""
69+
70+
@T.prim_func
71+
def rvv_vec_dot_prod_desc(
72+
A: T.Buffer((n_elems,), data_dtype, offset_factor=1),
73+
B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1),
74+
C: T.Buffer((n_lanes,), out_dtype, offset_factor=1),
75+
) -> None:
76+
with T.block("root"):
77+
T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems])
78+
T.writes(C[0:n_lanes])
79+
for j in T.serial(0, n_lanes):
80+
for k in T.serial(0, n_elems):
81+
with T.block("update"):
82+
vj, vk = T.axis.remap("SR", [j, k])
83+
C[vj] = C[vj] + T.cast(A[vk], out_dtype) * T.cast(B[vj, vk], out_dtype)
84+
85+
# LLVM only supports ELEN=32 or ELEN=64
86+
# https://llvm.org/docs//RISCV/RISCVVectorExtension.html
87+
d_dtype_lanes = (64 // DataType(data_dtype).bits) * lmul
88+
w_dtype_lanes = (64 // DataType(weight_dtype).bits) * lmul
89+
# reduction lanes narrows
90+
o_dtype_lanes = (64 // DataType(out_dtype).bits) * lmul // n_lanes
91+
# data type widening case
92+
o_dtype_lanes = max(o_dtype_lanes, 2)
93+
94+
mask_args = () if data_dtype[0] in ("i", "u") else (T.uint64(7),)
95+
96+
wide_dtype = out_dtype
97+
if DataType(out_dtype).bits > DataType(data_dtype).bits:
98+
wide_dtype = "".join(c for c in data_dtype if not c.isdigit())
99+
wide_dtype += str(DataType(data_dtype).bits * 2)
100+
101+
# fmt: off
102+
@T.prim_func
103+
def rvv_vec_dot_prod_impl(
104+
A: T.Buffer((n_elems,), data_dtype, offset_factor=1),
105+
B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1),
106+
C: T.Buffer((n_lanes,), out_dtype, offset_factor=1),
107+
) -> None:
108+
with T.block("root"):
109+
T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems])
110+
T.writes(C[0:n_lanes])
111+
112+
vec_A = T.call_llvm_intrin(
113+
f"{data_dtype}xvscalex{d_dtype_lanes}",
114+
"llvm.riscv.vle",
115+
T.broadcast(T.Cast(data_dtype, 0), T.vscale() * d_dtype_lanes),
116+
T.tvm_access_ptr(T.type_annotation(data_dtype), A.data, 0, n_elems, 1),
117+
T.int64(n_elems))
118+
119+
for i in range(n_lanes):
120+
with T.block("update"):
121+
T.reads(B[i, 0:n_elems])
122+
T.writes(C[i])
123+
124+
vec_B_row = T.call_llvm_intrin(
125+
f"{weight_dtype}xvscalex{w_dtype_lanes}",
126+
"llvm.riscv.vle",
127+
T.broadcast(T.Cast(data_dtype, 0), T.vscale() * w_dtype_lanes),
128+
T.tvm_access_ptr(T.type_annotation(weight_dtype), B.data, i * n_elems, n_elems, 1),
129+
T.int64(n_elems))
130+
131+
product = T.call_llvm_intrin(
132+
f"{wide_dtype}xvscalex{w_dtype_lanes}",
133+
"llvm.riscv.vfmul" if out_dtype[0] == "f" else \
134+
"llvm.riscv.vwmulsu" if (data_dtype[0] != weight_dtype[0]) else \
135+
"llvm.riscv.vwmul",
136+
T.broadcast(T.Cast(wide_dtype, 0), T.vscale() * w_dtype_lanes),
137+
vec_B_row,
138+
vec_A,
139+
*mask_args,
140+
T.uint64(n_elems))
141+
142+
ini_acc = T.call_llvm_intrin(
143+
f"{out_dtype}xvscalex{o_dtype_lanes}",
144+
"llvm.riscv.vle",
145+
T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes),
146+
T.tvm_access_ptr(T.type_annotation(out_dtype), C.data, i, 1, 1),
147+
T.int64(1))
148+
149+
red_sum = T.call_llvm_intrin(
150+
f"{out_dtype}xvscalex{o_dtype_lanes}",
151+
"llvm.riscv.vfredusum" if out_dtype[0] == "f" else \
152+
"llvm.riscv.vwredsum",
153+
T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes),
154+
product,
155+
ini_acc,
156+
*mask_args,
157+
T.uint64(n_elems))
158+
159+
C[i] = T.call_llvm_intrin(
160+
out_dtype,
161+
"llvm.riscv.vfmv.f.s" if out_dtype[0] == "f" else \
162+
"llvm.riscv.vmv.x.s",
163+
red_sum)
164+
# fmt: on
165+
return rvv_vec_dot_prod_desc, rvv_vec_dot_prod_impl
166+
167+
168+
@register_func("tir.tensor_intrin.register_rvv_isa_intrinsics")
169+
def register_rvv_isa_intrinsics(target: Target, inventory_only=False) -> dict():
170+
"""Register RISCV V (vector) intrinsics
171+
[x] Implementation follows version 1.0 vector specifications:
172+
https://github.com/riscvarchive/riscv-v-spec/releases/tag/v1.0
173+
174+
Args:
175+
target (Target): TVM target
176+
inventory_only (bool): No registration inventory only
177+
178+
Returns:
179+
dict(): A catalog with registered kernel names and properties
180+
"""
181+
if not target_has_features("v", target):
182+
raise RuntimeError("Current target does not support `v` extension.")
183+
184+
vlen = llvm_get_vector_width(target)
185+
# get maximum reduction lanes (without grouping)
186+
n_lanes = get_max_elems(vlen, lmul=1, sew=32)
187+
188+
kernels_inventory = {}
189+
190+
data_dtype = ["uint8", "int8", "float16", "float32"]
191+
weight_dtype = ["int8", "int8", "float16", "float32"]
192+
output_dtype = ["int32", "int32", "float16", "float32"]
193+
194+
for d_dtype, w_dtype, o_dtype in zip(data_dtype, weight_dtype, output_dtype):
195+
# max elements to grouped registers
196+
max_elems = get_max_elems(vlen, lmul=8, sew=DataType(d_dtype).bits)
197+
# data widening halves available vector registers
198+
if DataType(o_dtype).bits > DataType(d_dtype).bits:
199+
max_elems //= 2
200+
# compute optimal LMUL for full load
201+
lmul = max_elems // (vlen // DataType(d_dtype).bits)
202+
203+
n_elems = max_elems
204+
while n_elems >= 4:
205+
206+
dt = DataType(d_dtype)
207+
wt = DataType(w_dtype)
208+
ot = DataType(o_dtype)
209+
kernel_name = "rvv_dot"
210+
kernel_name += f"_{n_elems}{dt[0]}{dt.bits}"
211+
kernel_name += f"_{n_lanes}x{n_elems}{wt[0]}{wt.bits}"
212+
kernel_name += f"_{n_lanes}{ot[0]}{ot.bits}"
213+
kernels_inventory[kernel_name] = n_elems
214+
215+
if not inventory_only:
216+
logger.debug(f"Registering kernel {kernel_name}")
217+
desc, impl = rvv_vec_dot_product_kernels(
218+
n_elems, n_lanes, d_dtype, w_dtype, o_dtype, lmul
219+
)
220+
TensorIntrin.register(kernel_name, desc, impl, override=True)
221+
222+
n_elems //= 2
223+
224+
return kernels_inventory
225+
226+
227+
def register_riscv_intrinsics(target: Target):
228+
"""Register RISCV intrinsics
229+
230+
Args:
231+
target (Target): TVM target
232+
"""
233+
234+
# RISCV `v` 1.0 extension templates
235+
_ = register_rvv_isa_intrinsics(target)
236+
logger.debug("Finished registering riscv intrinsics.")

src/meta_schedule/postproc/postproc.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ Array<Postproc> Postproc::DefaultCPUTensorization() {
6969
};
7070
}
7171

72+
Array<Postproc> Postproc::DefaultRISCV() {
73+
return Array<Postproc>{
74+
Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(),
75+
Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/false),
76+
Postproc::RewriteLayout(),
77+
};
78+
}
79+
7280
Array<Postproc> Postproc::DefaultCUDA() {
7381
return Array<Postproc>{
7482
Postproc::DisallowDynamicLoop(),

src/meta_schedule/schedule_rule/schedule_rule.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* under the License.
1818
*/
1919
#include <tvm/ffi/reflection/registry.h>
20+
#include <tvm/runtime/data_type.h>
2021

2122
#include "../utils.h"
2223

@@ -304,6 +305,62 @@ Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
304305
};
305306
}
306307

308+
Array<ScheduleRule> ScheduleRule::DefaultRISCV(const int vlen) {
309+
Array<ScheduleRule> rules;
310+
rules.push_back(ScheduleRule::ApplyCustomRule());
311+
rules.push_back(ScheduleRule::InlineConstantScalars());
312+
rules.push_back(ScheduleRule::AutoInline(
313+
/*into_producer=*/false,
314+
/*into_consumer=*/true,
315+
/*inline_const_tensor=*/true,
316+
/*disallow_if_then_else=*/true,
317+
/*require_injective=*/true,
318+
/*require_ordered=*/true,
319+
/*disallow_op=*/Array<String>{"tir.exp"}));
320+
rules.push_back(ScheduleRule::AddRFactor(
321+
/*max_jobs_per_core=*/16,
322+
/*max_innermost_factor=*/Integer(64)));
323+
auto current_target = tvm::Target::Current();
324+
const auto reg_rvv_intrinsics =
325+
tvm::ffi::Function::GetGlobalRequired("tir.tensor_intrin.register_rvv_isa_intrinsics");
326+
const auto rvv_kernels_inventory =
327+
reg_rvv_intrinsics(current_target, /* inventory_only */ true).cast<Map<String, int>>();
328+
for (const auto& intrin : rvv_kernels_inventory) {
329+
if (!tir::TensorIntrin::Get(intrin.first, /*allow_missing*/ true)) {
330+
// on demand intrinsic register
331+
reg_rvv_intrinsics(current_target, /* inventory_only */ false);
332+
}
333+
rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin(
334+
/*intrin_name=*/intrin.first,
335+
/*structure=*/"SSRSRS",
336+
/*tile_binds=*/std::nullopt,
337+
/*max_innermost_factor=*/Integer(intrin.second),
338+
/*vector_load_lens=*/std::nullopt,
339+
/*reuse_read=*/std::nullopt,
340+
/*reuse_write=*/
341+
Map<String, ffi::Any>{{"req", String("may")},
342+
{"levels", Array<Integer>{1, 2}},
343+
{"scope", String("global")}}));
344+
}
345+
rules.push_back(ScheduleRule::MultiLevelTiling(
346+
/*structure=*/"SSRSRS",
347+
/*tile_binds=*/std::nullopt,
348+
/*max_innermost_factor=*/Integer(64),
349+
/*vector_load_lens=*/std::nullopt,
350+
/*reuse_read=*/std::nullopt,
351+
/*reuse_write=*/
352+
Map<String, ffi::Any>{
353+
{"req", String("may")}, {"levels", Array<Integer>{1, 2}}, {"scope", String("global")}}));
354+
rules.push_back(ScheduleRule::ParallelizeVectorizeUnroll(
355+
/*max_jobs_per_core=*/16,
356+
/*max_vectorize_extent=*/64,
357+
/*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
358+
/*unroll_explicit=*/true));
359+
rules.push_back(ScheduleRule::RandomComputeLocation());
360+
361+
return rules;
362+
}
363+
307364
Array<ScheduleRule> GetARMNeonSpecificRules() {
308365
return {
309366
ScheduleRule::MultiLevelTilingWithIntrin(

src/meta_schedule/space_generator/space_generator.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ String GetRuleKindFromTarget(const Target& target) {
3939
return "avx512";
4040
}
4141
}
42+
bool have_rvv = target_has_feature_fn_ptr("v", target).cast<bool>();
43+
if (have_rvv) {
44+
return "rvv";
45+
}
4246

4347
TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export());
4448
TargetFeatures afeatures = Downcast<TargetFeatures>(target_json.at("features"));
@@ -117,6 +121,13 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
117121
default_sch_rules = ScheduleRule::DefaultX86("avx512");
118122
default_postprocs = Postproc::DefaultCPUTensorization();
119123
default_mutator_probs = Mutator::DefaultLLVM();
124+
} else if (kind == "rvv") {
125+
static auto llvm_get_vector_width =
126+
tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
127+
const int vlen = llvm_get_vector_width(context->target.value()).cast<int>();
128+
default_sch_rules = ScheduleRule::DefaultRISCV(vlen);
129+
default_postprocs = Postproc::DefaultRISCV();
130+
default_mutator_probs = Mutator::DefaultLLVM();
120131
} else if (kind == "asimd") {
121132
default_sch_rules = ScheduleRule::DefaultARM("neon");
122133
default_postprocs = Postproc::DefaultCPUTensorization();

0 commit comments

Comments
 (0)