4 changes: 2 additions & 2 deletions include/tvm/target/target_kind.h
@@ -130,10 +130,10 @@ class TargetKind : public ObjectRef {
*/
TVM_DLL static Optional<TargetKind> Get(const String& target_kind_name);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode);

private:
/*! \brief Mutable access to the container class */
TargetKindNode* operator->() { return static_cast<TargetKindNode*>(data_.get()); }

private:
TVM_DLL static const AttrRegistryMapContainerMap<TargetKind>& GetAttrMapContainer(
const String& attr_name);
friend class TargetKindRegEntry;
5 changes: 2 additions & 3 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -23,7 +23,7 @@
from tvm._ffi.base import TVMError
from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op

from ....target.x86 import target_has_sse42
from ....target.x86 import target_has_features
from ....topi.utils import is_target
from .. import op as reg

@@ -457,8 +457,7 @@ def _shift(data, zero_point, out_dtype):

def is_fast_int8_on_intel():
"""Checks whether the hardware has support for fast Int8 arithmetic operations."""
target = tvm.target.Target.current(allow_none=False)
return target_has_sse42(target.mcpu)
return target_has_features("sse4.2")


# Helper function to align up given value.
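A minimal usage sketch (not part of this diff) of the new context-based check: `target_has_features` resolves the target from the global scope when none is passed, so call sites like `is_fast_int8_on_intel` no longer need `target.mcpu`; an explicit `Target` can still be supplied as the second argument. The `-mcpu` value below is illustrative.

```python
import tvm
from tvm.target.x86 import target_has_features

# Inside a target scope, no explicit target argument is needed.
with tvm.target.Target("llvm -mcpu=skylake-avx512"):
    assert target_has_features("sse4.2")                 # single feature
    assert target_has_features(["avx512f", "avx512bw"])  # all must be present
```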
7 changes: 4 additions & 3 deletions python/tvm/relay/qnn/op/qnn.py
@@ -27,7 +27,7 @@
from tvm.runtime import Object
from tvm.target import Target
from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE
from tvm.target.x86 import target_has_sse41
from tvm.target.x86 import target_has_features

from . import _make, _requantize

@@ -54,8 +54,9 @@ def _get_node_default_rounding():
@staticmethod
def _get_node_default_compute_dtype():
target = Target.current(True)
if target and str(target.kind) == "llvm" and target_has_sse41(target.mcpu):
return "float32"
if target and str(target.kind) == "llvm":
if target_has_features("sse4.1", target):
return "float32"

return "int64"

32 changes: 32 additions & 0 deletions python/tvm/target/codegen.py
@@ -71,6 +71,38 @@ def llvm_get_intrinsic_name(intrin_id: int) -> str:
return _ffi_api.llvm_get_intrinsic_name(intrin_id)


def llvm_x86_get_archlist(only64bit=False):
"""Get X86 CPU name list.

Parameters
----------
only64bit : bool
    If True, return only 64-bit architectures.

Returns
-------
archlist : list[str]
    List of X86 architecture (CPU) names.
"""
return _ffi_api.llvm_x86_get_archlist(only64bit)


def llvm_x86_get_features(cpu_name):
"""Get X86 CPU features.

Parameters
----------
cpu_name : string
X86 CPU name (e.g. "skylake").

Returns
-------
features : list[str]
String list of X86 CPU features.
"""
return _ffi_api.llvm_x86_get_features(cpu_name)


def llvm_version_major(allow_none=False):
"""Get the major LLVM version.

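A short sketch (assumed usage, not part of this diff) combining the two new wrappers to enumerate LLVM's known x86 CPUs and probe their feature sets:

```python
from tvm.target.codegen import llvm_x86_get_archlist, llvm_x86_get_features

# List 64-bit x86 CPU names known to LLVM and flag those with AVX-512F.
for cpu in llvm_x86_get_archlist(only64bit=True):
    if "avx512f" in llvm_x86_get_features(cpu):
        print(cpu, "supports avx512f")
```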
161 changes: 41 additions & 120 deletions python/tvm/target/x86.py
@@ -16,127 +16,48 @@
# under the License.
"""Common x86 related utilities"""
from .._ffi import register_func
from .target import Target


@register_func("tvm.target.x86.target_has_sse41")
def target_has_sse41(target):
return (
target_has_sse42(target)
or target_has_avx(target)
or target_has_avx2(target)
or target_has_avx512(target)
or target_has_vnni(target)
or target
in {
"btver2",
"penryn",
}
)


@register_func("tvm.target.x86.target_has_sse42")
def target_has_sse42(target):
return (
target_has_avx(target)
or target_has_avx2(target)
or target_has_avx512(target)
or target_has_vnni(target)
or target
in {
"silvermont",
"slm",
"goldmont",
"goldmont-plus",
"tremont",
"nehalem",
"corei7",
"westmere",
"bdver1",
"bdver2",
"bdver3",
"x86-64-v2",
}
)


@register_func("tvm.target.x86.target_has_avx")
def target_has_avx(target):
return (
target_has_avx2(target)
or target_has_avx512(target)
or target_has_vnni(target)
or target in {"sandybridge", "corei7-avx", "ivybridge", "core-avx-i"}
)


@register_func("tvm.target.x86.target_has_avx2")
def target_has_avx2(target):
return (
target_has_avx512(target)
or target_has_vnni(target)
or target
in {
"haswell",
"core-avx2",
"broadwell",
"skylake",
"bdver4",
"znver1",
"znver2",
"znver3",
"x86-64-v3",
}
)


@register_func("tvm.target.x86.target_has_avx512")
def target_has_avx512(target):
return target in {
"skylake-avx512",
"skx",
"knl",
"knm",
"x86-64-v4",
"cannonlake",
# explicit enumeration of VNNI capable due to collision with alderlake
"cascadelake",
"icelake-client",
"icelake-server",
"rocketlake",
"tigerlake",
"cooperlake",
"sapphirerapids",
}


@register_func("tvm.target.x86.target_has_vnni")
def target_has_vnni(target):
return target in {
"cascadelake",
"icelake-client",
"icelake-server",
"rocketlake",
"tigerlake",
"cooperlake",
"sapphirerapids",
"alderlake",
}


@register_func("tvm.target.x86.target_has_amx")
def target_has_amx(target):
return target in {
"sapphirerapids",
}
from . import _ffi_api
from ..ir.container import Array


@register_func("tvm.target.x86.target_has_features")
def target_has_features(features, target=None):
"""Check X86 CPU features.
Parameters
----------
features : str or Array
Feature(s) to check.
target : Target
Optional TVM target, default `None` use the global context target.
Returns
-------
has_feats : bool
True if feature(s) are in the target arch.
"""
has_feats = True
assert isinstance(features, (Array, str))
features = [features] if isinstance(features, str) else features
for feat in features:
has_feats &= _ffi_api.llvm_x86_has_feature(feat, target)
return has_feats


@register_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
def get_simd_32bit_lanes():
mcpu = Target.current().mcpu
fp32_vec_len = 4
if target_has_avx512(mcpu):
fp32_vec_len = 16
elif target_has_avx2(mcpu):
fp32_vec_len = 8
return fp32_vec_len
"""X86 SIMD optimal vector length lookup.
Parameters
----------
Returns
-------
vec_len : int
The optimal vector length of CPU from the global context target.
"""
vec_len = 4
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#           + llvm.x86.avx512.pmaddw.d.512
if target_has_features(["avx512bw", "avx512f"]):
vec_len = 16
elif target_has_features("avx2"):
vec_len = 8
return vec_len
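A hedged sketch of how the rewritten lookup behaves; the `-mcpu` names and the feature sets they imply are assumptions for illustration:

```python
import tvm
from tvm.target.x86 import get_simd_32bit_lanes

with tvm.target.Target("llvm -mcpu=skylake-avx512"):
    assert get_simd_32bit_lanes() == 16  # avx512bw + avx512f present
with tvm.target.Target("llvm -mcpu=core-avx2"):
    assert get_simd_32bit_lanes() == 8   # avx2 only
```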
16 changes: 10 additions & 6 deletions python/tvm/topi/x86/batch_matmul.py
@@ -21,7 +21,7 @@
from tvm import autotvm, te
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas, mkl
from tvm.target.x86 import target_has_amx, target_has_avx512
from tvm.target.x86 import target_has_features

from .. import generic, nn
from ..transform import layout_transform
@@ -38,8 +38,10 @@ def batch_matmul_int8_compute(cfg, x, y, *_):
packed_y = layout_transform(y, "BNK", packed_y_layout)
_, n_o, _, n_i, _ = packed_y.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_avx512(mcpu):
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#           + llvm.x86.avx512.pmaddw.d.512
if target_has_features(["avx512bw", "avx512f"]):
attrs_info = {"schedule_rule": "batch_matmul_int8"}
else:
attrs_info = None
@@ -233,14 +235,16 @@ def _callback(op):
def schedule_batch_matmul_int8(cfg, outs):
"""Schedule for batch_matmul_int8"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu

def _callback(op):
if "batch_matmul_int8" in op.tag:
layout_trans = op.input_tensors[1]
if target_has_amx(mcpu):
if target_has_features("amx-int8"):
batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans)
elif target_has_avx512(mcpu):
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#           + llvm.x86.avx512.pmaddw.d.512
elif target_has_features(["avx512bw", "avx512f"]):
batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans)

traverse_inline(s, outs[0].op, _callback)
10 changes: 6 additions & 4 deletions python/tvm/topi/x86/conv2d_int8.py
@@ -20,7 +20,7 @@

import tvm
from tvm import autotvm, te
from tvm.target.x86 import target_has_sse42
from tvm.target.x86 import target_has_features

from .. import nn, tag
from ..generic import conv2d as conv2d_generic
@@ -49,7 +49,10 @@ def _get_default_config_int8(
"""
if is_depthwise:
# Fallback to FP32 default config until a VNNI schedule is defined.
wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
wkl = _get_depthwise_conv2d_workload(
data, kernel, strides, padding, dilation, out_dtype, layout
)

from .depthwise_conv2d import _fallback_schedule

_fallback_schedule(cfg, wkl)
@@ -81,8 +84,7 @@ def is_int8_hw_support(data_dtype, kernel_dtype):
is_llvm_support = llvm_version >= 8

# 3) Check target
mcpu = tvm.target.Target.current().mcpu
is_target_support = target_has_sse42(mcpu)
is_target_support = target_has_features("sse4.2")

return is_dtype_support and is_llvm_support and is_target_support

16 changes: 10 additions & 6 deletions python/tvm/topi/x86/dense.py
@@ -23,7 +23,7 @@
from tvm import autotvm, te
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas, dnnl, mkl
from tvm.target.x86 import get_simd_32bit_lanes, target_has_amx, target_has_avx512
from tvm.target.x86 import get_simd_32bit_lanes, target_has_features

from .. import generic, tag
from ..utils import get_const_tuple, traverse_inline
@@ -298,13 +298,15 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
def schedule_dense_int8(cfg, outs):
"""Create a schedule for dense__int8"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu

def _callback(op):
if "dense_int8" in op.tag:
if target_has_amx(mcpu):
if target_has_features("amx-int8"):
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
elif target_has_avx512(mcpu):
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#           + llvm.x86.avx512.pmaddw.d.512
elif target_has_features(["avx512bw", "avx512f"]):
dense_int8_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
@@ -316,8 +318,10 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
m, k = X.shape
n_o, _, n_i, _ = packed_w.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_avx512(mcpu):
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#           + llvm.x86.avx512.pmaddw.d.512
if target_has_features(["avx512bw", "avx512f"]):
target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"}
else:
target_attr = None
11 changes: 7 additions & 4 deletions python/tvm/topi/x86/dense_alter_op.py
@@ -19,7 +19,7 @@

import tvm
from tvm import autotvm, relay, te
from tvm.target.x86 import target_has_amx, target_has_avx512
from tvm.target.x86 import target_has_features

from .. import nn
from ..nn import dense_alter_layout
@@ -28,9 +28,12 @@


def check_int8_applicable(x, y, allow_padding=False):
mcpu = tvm.target.Target.current().mcpu
# TODO(vvchernov): may be also target_has_avx2 or lower?
simd_avai = target_has_avx512(mcpu) or target_has_amx(mcpu)
# avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added)
# avx512bw: llvm.x86.avx512.pmaddubs.w.512 (TVM required)
#           + llvm.x86.avx512.pmaddw.d.512
simd_avai = target_has_features(["avx512bw", "avx512f"])
simd_avai |= target_has_features("amx-int8")
# TODO(vvchernov): maybe also target_has_features("avx2") or lower?
return (
simd_avai
and "int8" in x.dtype
Expand Down
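The same gate recurs across the x86 int8 schedules above; a hypothetical helper (the name and factoring are illustrative, not part of the diff) that captures it:

```python
from tvm.target.x86 import target_has_features

def int8_simd_available():
    # The AVX-512 int8 path needs both avx512f and avx512bw intrinsics;
    # AMX int8 has its own dedicated schedule and is checked separately.
    return target_has_features(["avx512bw", "avx512f"]) or target_has_features(
        "amx-int8"
    )
```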