|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=invalid-name,line-too-long |
| 18 | +"""Intrinsics for RISCV tensorization""" |
| 19 | + |
| 20 | +import logging |
| 21 | +from tvm.ffi import register_func |
| 22 | +from tvm.runtime import DataType |
| 23 | +from tvm.script import tir as T |
| 24 | +from tvm.target.codegen import llvm_get_vector_width, target_has_features, Target |
| 25 | +from .. import TensorIntrin |
| 26 | + |
| 27 | +logger = logging.getLogger(__name__) |
| 28 | + |
| 29 | + |
| 30 | +def get_max_elems(vlen: int, lmul: int, sew: int) -> int: |
| 31 | + """Returns number of elements of a given data type (SEW) |
| 32 | + that fits multiple (LMUL) of the vector registers (VLEN). |
| 33 | +
|
| 34 | + Args: |
| 35 | + vlen (int): VLEN vector length in bits |
| 36 | + lmul (int): LMUL vector lenght multiplier |
| 37 | + sew (int): SEW standard (single) element width |
| 38 | +
|
| 39 | + Returns: |
| 40 | + int: Number of elements |
| 41 | + """ |
| 42 | + return (vlen // sew) * lmul |
| 43 | + |
| 44 | + |
| 45 | +def rvv_vec_dot_product_kernels( |
| 46 | + n_elems: int, |
| 47 | + n_lanes: int, |
| 48 | + data_dtype: str, |
| 49 | + weight_dtype: str, |
| 50 | + out_dtype: str, |
| 51 | + lmul: int, |
| 52 | +): |
| 53 | + """Dot product of vector and matrix rows using RISC-V vector instructions. |
| 54 | +
|
| 55 | + These kernels takes two arrays A[ELEMS] and B[ELEMS][MACS] and computes |
| 56 | + dot product of A[ELEMS] with each row of B[LANES], accumulating results |
| 57 | + with C[LANES]. |
| 58 | +
|
| 59 | + The pseudo code is as follows: |
| 60 | + .. code-block:: c |
| 61 | + void vec_dot_prod(A[ELEMS], B[LANES][ELEMS], C[LANES]){ |
| 62 | + for (j = 0; j < LANES; j++) { |
| 63 | + for (k = 0; k < ELEMS; k++) { |
| 64 | + C[j] += A[k] * B[j][k] |
| 65 | + } |
| 66 | + } |
| 67 | + } |
| 68 | + """ |
| 69 | + |
| 70 | + @T.prim_func |
| 71 | + def rvv_vec_dot_prod_desc( |
| 72 | + A: T.Buffer((n_elems,), data_dtype, offset_factor=1), |
| 73 | + B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1), |
| 74 | + C: T.Buffer((n_lanes,), out_dtype, offset_factor=1), |
| 75 | + ) -> None: |
| 76 | + with T.block("root"): |
| 77 | + T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems]) |
| 78 | + T.writes(C[0:n_lanes]) |
| 79 | + for j in T.serial(0, n_lanes): |
| 80 | + for k in T.serial(0, n_elems): |
| 81 | + with T.block("update"): |
| 82 | + vj, vk = T.axis.remap("SR", [j, k]) |
| 83 | + C[vj] = C[vj] + T.cast(A[vk], out_dtype) * T.cast(B[vj, vk], out_dtype) |
| 84 | + |
| 85 | + # LLVM only supports ELEN=32 or ELEN=64 |
| 86 | + # https://llvm.org/docs//RISCV/RISCVVectorExtension.html |
| 87 | + d_dtype_lanes = (64 // DataType(data_dtype).bits) * lmul |
| 88 | + w_dtype_lanes = (64 // DataType(weight_dtype).bits) * lmul |
| 89 | + # reduction lanes narrows |
| 90 | + o_dtype_lanes = (64 // DataType(out_dtype).bits) * lmul // n_lanes |
| 91 | + # data type widening case |
| 92 | + o_dtype_lanes = max(o_dtype_lanes, 2) |
| 93 | + |
| 94 | + mask_args = () if data_dtype[0] in ("i", "u") else (T.uint64(7),) |
| 95 | + |
| 96 | + wide_dtype = out_dtype |
| 97 | + if DataType(out_dtype).bits > DataType(data_dtype).bits: |
| 98 | + wide_dtype = "".join(c for c in data_dtype if not c.isdigit()) |
| 99 | + wide_dtype += str(DataType(data_dtype).bits * 2) |
| 100 | + |
| 101 | + # fmt: off |
| 102 | + @T.prim_func |
| 103 | + def rvv_vec_dot_prod_impl( |
| 104 | + A: T.Buffer((n_elems,), data_dtype, offset_factor=1), |
| 105 | + B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1), |
| 106 | + C: T.Buffer((n_lanes,), out_dtype, offset_factor=1), |
| 107 | + ) -> None: |
| 108 | + with T.block("root"): |
| 109 | + T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems]) |
| 110 | + T.writes(C[0:n_lanes]) |
| 111 | + |
| 112 | + vec_A = T.call_llvm_intrin( |
| 113 | + f"{data_dtype}xvscalex{d_dtype_lanes}", |
| 114 | + "llvm.riscv.vle", |
| 115 | + T.broadcast(T.Cast(data_dtype, 0), T.vscale() * d_dtype_lanes), |
| 116 | + T.tvm_access_ptr(T.type_annotation(data_dtype), A.data, 0, n_elems, 1), |
| 117 | + T.int64(n_elems)) |
| 118 | + |
| 119 | + for i in range(n_lanes): |
| 120 | + with T.block("update"): |
| 121 | + T.reads(B[i, 0:n_elems]) |
| 122 | + T.writes(C[i]) |
| 123 | + |
| 124 | + vec_B_row = T.call_llvm_intrin( |
| 125 | + f"{weight_dtype}xvscalex{w_dtype_lanes}", |
| 126 | + "llvm.riscv.vle", |
| 127 | + T.broadcast(T.Cast(data_dtype, 0), T.vscale() * w_dtype_lanes), |
| 128 | + T.tvm_access_ptr(T.type_annotation(weight_dtype), B.data, i * n_elems, n_elems, 1), |
| 129 | + T.int64(n_elems)) |
| 130 | + |
| 131 | + product = T.call_llvm_intrin( |
| 132 | + f"{wide_dtype}xvscalex{w_dtype_lanes}", |
| 133 | + "llvm.riscv.vfmul" if out_dtype[0] == "f" else \ |
| 134 | + "llvm.riscv.vwmulsu" if (data_dtype[0] != weight_dtype[0]) else \ |
| 135 | + "llvm.riscv.vwmul", |
| 136 | + T.broadcast(T.Cast(wide_dtype, 0), T.vscale() * w_dtype_lanes), |
| 137 | + vec_B_row, |
| 138 | + vec_A, |
| 139 | + *mask_args, |
| 140 | + T.uint64(n_elems)) |
| 141 | + |
| 142 | + ini_acc = T.call_llvm_intrin( |
| 143 | + f"{out_dtype}xvscalex{o_dtype_lanes}", |
| 144 | + "llvm.riscv.vle", |
| 145 | + T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes), |
| 146 | + T.tvm_access_ptr(T.type_annotation(out_dtype), C.data, i, 1, 1), |
| 147 | + T.int64(1)) |
| 148 | + |
| 149 | + red_sum = T.call_llvm_intrin( |
| 150 | + f"{out_dtype}xvscalex{o_dtype_lanes}", |
| 151 | + "llvm.riscv.vfredusum" if out_dtype[0] == "f" else \ |
| 152 | + "llvm.riscv.vwredsum", |
| 153 | + T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes), |
| 154 | + product, |
| 155 | + ini_acc, |
| 156 | + *mask_args, |
| 157 | + T.uint64(n_elems)) |
| 158 | + |
| 159 | + C[i] = T.call_llvm_intrin( |
| 160 | + out_dtype, |
| 161 | + "llvm.riscv.vfmv.f.s" if out_dtype[0] == "f" else \ |
| 162 | + "llvm.riscv.vmv.x.s", |
| 163 | + red_sum) |
| 164 | + # fmt: on |
| 165 | + return rvv_vec_dot_prod_desc, rvv_vec_dot_prod_impl |
| 166 | + |
| 167 | + |
| 168 | +@register_func("tir.tensor_intrin.register_rvv_isa_intrinsics") |
| 169 | +def register_rvv_isa_intrinsics(target: Target, inventory_only=False) -> dict(): |
| 170 | + """Register RISCV V (vector) intrinsics |
| 171 | + [x] Implementation follows version 1.0 vector specifications: |
| 172 | + https://github.com/riscvarchive/riscv-v-spec/releases/tag/v1.0 |
| 173 | +
|
| 174 | + Args: |
| 175 | + target (Target): TVM target |
| 176 | + inventory_only (bool): No registration inventory only |
| 177 | +
|
| 178 | + Returns: |
| 179 | + dict(): A catalog with registered kernel names and properties |
| 180 | + """ |
| 181 | + if not target_has_features("v", target): |
| 182 | + raise RuntimeError("Current target does not support `v` extension.") |
| 183 | + |
| 184 | + vlen = llvm_get_vector_width(target) |
| 185 | + # get maximum reduction lanes (without grouping) |
| 186 | + n_lanes = get_max_elems(vlen, lmul=1, sew=32) |
| 187 | + |
| 188 | + kernels_inventory = {} |
| 189 | + |
| 190 | + data_dtype = ["uint8", "int8", "float16", "float32"] |
| 191 | + weight_dtype = ["int8", "int8", "float16", "float32"] |
| 192 | + output_dtype = ["int32", "int32", "float16", "float32"] |
| 193 | + |
| 194 | + for d_dtype, w_dtype, o_dtype in zip(data_dtype, weight_dtype, output_dtype): |
| 195 | + # max elements to grouped registers |
| 196 | + max_elems = get_max_elems(vlen, lmul=8, sew=DataType(d_dtype).bits) |
| 197 | + # data widening halves available vector registers |
| 198 | + if DataType(o_dtype).bits > DataType(d_dtype).bits: |
| 199 | + max_elems //= 2 |
| 200 | + # compute optimal LMUL for full load |
| 201 | + lmul = max_elems // (vlen // DataType(d_dtype).bits) |
| 202 | + |
| 203 | + n_elems = max_elems |
| 204 | + while n_elems >= 4: |
| 205 | + |
| 206 | + dt = DataType(d_dtype) |
| 207 | + wt = DataType(w_dtype) |
| 208 | + ot = DataType(o_dtype) |
| 209 | + kernel_name = "rvv_dot" |
| 210 | + kernel_name += f"_{n_elems}{dt[0]}{dt.bits}" |
| 211 | + kernel_name += f"_{n_lanes}x{n_elems}{wt[0]}{wt.bits}" |
| 212 | + kernel_name += f"_{n_lanes}{ot[0]}{ot.bits}" |
| 213 | + kernels_inventory[kernel_name] = n_elems |
| 214 | + |
| 215 | + if not inventory_only: |
| 216 | + logger.debug(f"Registering kernel {kernel_name}") |
| 217 | + desc, impl = rvv_vec_dot_product_kernels( |
| 218 | + n_elems, n_lanes, d_dtype, w_dtype, o_dtype, lmul |
| 219 | + ) |
| 220 | + TensorIntrin.register(kernel_name, desc, impl, override=True) |
| 221 | + |
| 222 | + n_elems //= 2 |
| 223 | + |
| 224 | + return kernels_inventory |
| 225 | + |
| 226 | + |
| 227 | +def register_riscv_intrinsics(target: Target): |
| 228 | + """Register RISCV intrinsics |
| 229 | +
|
| 230 | + Args: |
| 231 | + target (Target): TVM target |
| 232 | + """ |
| 233 | + |
| 234 | + # RISCV `v` 1.0 extension templates |
| 235 | + _ = register_rvv_isa_intrinsics(target) |
| 236 | + logger.debug("Finished registering riscv intrinsics.") |
0 commit comments