From d2908157275f983b082c7c7c500b12ddcf5a06bd Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 30 Dec 2024 22:19:16 +0530 Subject: [PATCH 01/31] Annotate Custom Scope layout pass for Adreno GPU Texture scope annotation is handled by - Layout conversion from 4D to 5D with convert_layout pass - Legalization of ops with Adreno specific legalization - FuseOps & FuseTIR - Now, over the fused TIR annotate the scopes by hint_on_device - RealizeVDevice will take care of injecting to_device as needed. - Also, introduced SpecializeTIRParams to update the fused TIR the prim function buffer var map with new scope information. Changes in FuseOps and FuseTIR are to forward op attr and op pattern info. This info is used for Texture specific scoping decisions. --- include/tvm/ir/global_info.h | 9 + include/tvm/relax/attrs/op.h | 2 + include/tvm/relax/transform.h | 18 + python/tvm/dlight/__init__.py | 1 + python/tvm/dlight/adreno/__init__.py | 20 + python/tvm/dlight/adreno/base.py | 41 + python/tvm/dlight/adreno/convolution.py | 230 ++++ python/tvm/relax/transform/__init__.py | 2 + .../relax/transform/legalize_ops/__init__.py | 3 + .../transform/legalize_ops/adreno/__init__.py | 18 + .../legalize_ops/adreno/convolution.py | 35 + .../tvm/relax/transform/optimize_batchnorm.py | 108 ++ python/tvm/relax/transform/transform.py | 25 + python/tvm/tir/analysis/analysis.py | 6 +- python/tvm/topi/nn/conv2d.py | 129 +++ src/relax/op/nn/convolution.cc | 8 +- src/relax/op/op.cc | 3 +- src/relax/op/tensor/binary.cc | 21 +- src/relax/op/tensor/manipulate.cc | 42 + .../transform/annotate_custom_storage.cc | 455 ++++++++ src/relax/transform/fuse_tir.cc | 16 + src/relax/transform/legalize_ops.cc | 28 + src/relax/transform/realize_vdevice.cc | 4 +- src/relax/transform/specialize_tir_params.cc | 170 +++ src/relax/transform/utils.h | 2 +- src/script/printer/relax/call.cc | 3 +- src/script/printer/relax/struct_info.cc | 5 +- src/script/printer/relax/utils.h | 2 +- src/tir/schedule/analysis/analysis.cc | 1 + .../test_transform_annotate_custom_scope1.py | 162 +++ .../test_transform_annotate_custom_scope.py | 1032 +++++++++++++++++ .../relax/test_transform_convert_layout.py | 415 ++++++- 32 files changed, 2994 insertions(+), 22 deletions(-) create mode 100644 python/tvm/dlight/adreno/__init__.py create mode 100644 python/tvm/dlight/adreno/base.py create mode 100644 python/tvm/dlight/adreno/convolution.py create mode 100644 python/tvm/relax/transform/legalize_ops/adreno/__init__.py create mode 100644 python/tvm/relax/transform/legalize_ops/adreno/convolution.py create mode 100644 python/tvm/relax/transform/optimize_batchnorm.py create mode 100644 src/relax/transform/annotate_custom_storage.cc create mode 100644 src/relax/transform/specialize_tir_params.cc create mode 100644 tests/python/relax/adreno/test_transform_annotate_custom_scope1.py create mode 100644 tests/python/relax/test_transform_annotate_custom_scope.py diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index 892bba4da694..b646057a009f 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -87,6 +87,15 @@ class VDeviceNode : public GlobalInfoNode { class VDevice : public GlobalInfo { public: TVM_DLL explicit VDevice(Target tgt, int dev_id, MemoryScope mem_scope); + /*! + * \brief Equal comparator. + * \param other The data type to compare against. + * \return The comparison result. + */ + bool operator==(const VDevice& other) const { + return this->get()->target == other->target && this->get()->vdevice_id == other->vdevice_id && + this->get()->memory_scope == other->memory_scope; + } TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VDevice, GlobalInfo, VDeviceNode); }; diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index 36356ba83e48..e504a36bf3c9 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -104,6 +104,7 @@ struct ToVDeviceAttrs : public AttrsNodeReflAdapter { struct HintOnDeviceAttrs : public AttrsNodeReflAdapter { int32_t device_type; int32_t index; + MemoryScope memory_scope; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -111,6 +112,7 @@ struct HintOnDeviceAttrs : public AttrsNodeReflAdapter { .def_ro("device_type", &HintOnDeviceAttrs::device_type, "The device type where the data is supposed to be executed.") .def_ro("index", &HintOnDeviceAttrs::index, "The device id."); + .def_ro("memory_scope", &HintOnDeviceAttrs::memory_scope, "The device memory scope."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.HintOnDeviceAttrs", HintOnDeviceAttrs, BaseAttrsNode); diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index a8ccc4076bb3..cf7989c1a14d 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -679,6 +679,24 @@ TVM_DLL Pass RewriteCUDAGraph(); */ TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark); +/*! + * \brief This pass is designed to annotate the memory scope information via VDevice attribute. + * This pass need operator attrbutes which in general vanish aftre legalization. + * FuseOps and FuseTIR are modified to pass on the operator specific attributes and also + * op_pattern details as part of the PrimFunc. This pass is Adreno specific and annotates each + * BindingVar with appropriate HintInDevice. RealizeVDevice pass followed by handles these hints. + * Followed by this pass we also invoke SpecializeTIRParams which updates the var_buffer_map + * based on this new VDevice information. + */ +TVM_DLL Pass AnnotateCustomMemoryScope(Target target); + +/*! + * \brief This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. + * Primarily used to update the VDevice information is any changes occured from the caller. + * This pass recreated the buffers and updates the map. + */ +TVM_DLL Pass SpecializeTIRParams(); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py index bd70acf00f90..3d42d1972dcc 100644 --- a/python/tvm/dlight/__init__.py +++ b/python/tvm/dlight/__init__.py @@ -16,6 +16,7 @@ # under the License. """DLight package provides efficient schedules out-of-box for deep learning workloads.""" from . import gpu +from . import adreno from . import cpu from .analysis import ( BlockInfo, diff --git a/python/tvm/dlight/adreno/__init__.py b/python/tvm/dlight/adreno/__init__.py new file mode 100644 index 000000000000..ea2781455989 --- /dev/null +++ b/python/tvm/dlight/adreno/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Adreno schedule rules. +""" +from .convolution import Conv2d diff --git a/python/tvm/dlight/adreno/base.py b/python/tvm/dlight/adreno/base.py new file mode 100644 index 000000000000..d043706c2fc5 --- /dev/null +++ b/python/tvm/dlight/adreno/base.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Base schedule rule for Adreno operators.""" + +from tvm.target import Target + +from ..base import ScheduleRule + + +class AdrenoScheduleRule(ScheduleRule): # pylint: disable=too-few-public-methods + """The Schedule Rule specific to Adreno targets, + will return None if the target is not Adreno.""" + + def is_target_available(self, target: Target) -> bool: + """Check whether the target is available for Adreno rule. + + Parameters + ---------- + target : Target + The compilation target to check. + + Returns + ------- + available : bool + Whether the target is available for this rule. + """ + return super().is_target_available(target) and "adreno" in target.keys diff --git a/python/tvm/dlight/adreno/convolution.py b/python/tvm/dlight/adreno/convolution.py new file mode 100644 index 000000000000..16075f814df2 --- /dev/null +++ b/python/tvm/dlight/adreno/convolution.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +"""A Conv2d schedule rule for Adreno GPU operators.""" +from dataclasses import dataclass +from typing import List, Optional + +from tvm import tir +from tvm.target import Target +from tvm.tir import IterVar +from tvm.tir.schedule.schedule import BlockRV + +from ..base import analysis, BlockInfo, IterInfo +from .base import AdrenoScheduleRule + + +def is_spatial_block(sch: tir.Schedule, block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.DataPar} + + +def is_reduction_block(sch: tir.Schedule, block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.CommReduce, IterVar.DataPar} + + +def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for producer in sch.get_producers(block): + result.append(producer) + result.extend(_collect_producers(sch, producer)) + return result + + +def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for consumer in sch.get_consumers(block): + result.append(consumer) + result.extend(_collect_consumers(sch, consumer)) + return result + + +def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo: + def _iter_kind(loop: tir.IterVar) -> str: + return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O") + + def _is_reduction_block(block: tir.schedule.BlockRV): + for iter_var in sch.get(block).iter_vars: + if _iter_kind(iter_var) == "R": + return True + return False + + return BlockInfo( + name=sch.get(block).name_hint, + iters=[ + IterInfo( + kind=_iter_kind(iter_var), + var=iter_var.var, + dom=iter_var.dom.extent, + loop_rv=loop_rv, + ) + for loop_rv, iter_var in zip(sch.get_loops(block), sch.get(block).iter_vars) + ], + block_rv=block, + reduction_block=_is_reduction_block(block), + ) + + +def get_reduction_blocks(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]) -> bool: + # NOTE: We assume there is only one reduction block in the function + # all blocks are required to be spatial or reduction + if not all( + [is_reduction_block(sch, block) or is_spatial_block(sch, block) for block in blocks] + ): + return None + + # There is only one reduction block + reduction_blocks = [block for block in blocks if is_reduction_block(sch, block)] + if len(reduction_blocks) != 1: + return None + + return reduction_blocks[0] + + +def is_convolution(sch: tir.Schedule, block: tir.schedule.BlockRV): + # TODO: Use buffer access patterns to discover convolution type kernels instead of using name. + return ( + sch.get(block).name_hint.count("conv2d_NCHWc_OIHWo") + and "".join([iter_type.kind for iter_type in get_block_info(sch, block).iters]) + == "SSSSSRRR" + ) + + +class Conv2d(AdrenoScheduleRule): + """The schedule rule for convolution computation""" + + @dataclass + class Config: + block_size_x: int = 8 + block_size_y: int = 8 + vector_size: int = 1 + unroll: int = 256 # 0 means no unroll + use_shared: bool = True + storage_align: bool = False + inner_x: bool = False + + def get_configs(self, target: Target) -> Config: + """Get the schedule config for the target""" + if target.kind.name == "cuda" or target.kind.name == "rocm": + return Conv2d.Config( + block_size_x=8, + block_size_y=16, + vector_size=2, + unroll=256, + use_shared=True, + storage_align=True, + inner_x=False, + ) + elif target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): + return Conv2d.Config( + block_size_x=32, + block_size_y=4, + vector_size=8, + unroll=16, + use_shared=False, + storage_align=False, + inner_x=True, + ) + else: + return Conv2d.Config() + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + + if isinstance(func, tir.PrimFunc): + sch = tir.Schedule(func) + + # config = self.get_configs(target) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + reduction_block = get_reduction_blocks(sch, blocks) + + if reduction_block is None: + return None + if not is_convolution(sch, reduction_block): + return None + + def schedule_data_pad(blk): + axes = sch.get_loops(blk) + axes, vec = axes[:-1], axes[-1] + axis = sch.fuse(*axes) + bx, ty, tx = sch.split(axis, [None, 16, 16]) + sch.bind(bx, "blockIdx.x") + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def schedule_conv2d(blk): + # TODO: Loop Pattern mayn't be reliable, need to perform better analysis. + n, oc, oh, ow, ob, ic, kh, kw = sch.get_loops(blk) + sch.reorder(n, oc, oh, ow, ic, kh, kw, ob) + main_lp = sch.fuse(n, oc, oh, ow) + bx, ty, tx = sch.split(main_lp, [None, 16, 16]) + sch.bind(tx, "threadIdx.x") + sch.bind(ty, "threadIdx.y") + sch.bind(bx, "blockIdx.x") + + ico, icv = sch.split(ic, [None, 4]) + sch.reorder(ico, kh, kw, icv, ob) + rblk = sch.cache_read(blk, 0, "local") + sch.compute_at(rblk, kw) + sch.vectorize(sch.get_loops(rblk)[-1]) + wblk = sch.cache_write(blk, 0, "local") + sch.reverse_compute_at(wblk, tx) + sch.vectorize(sch.get_loops(wblk)[-1]) + sch.vectorize(ob) + init_blk = sch.decompose_reduction(blk, ico) + sch.vectorize(sch.get_loops(init_blk)[-1]) + + def is_data_pad(block: tir.stmt.Block): + return is_spatial_block(sch, block) and tir.analysis.has_if_then_else(sch.get(block)) + + def schedule_conv2d_blocks(): + + # Do analysis to find block type + blocks = sch.get_child_blocks(root_block) + passed_reduction = False + for blk in blocks: + if is_reduction_block(sch, blk): + schedule_conv2d(blk) + passed_reduction = True + elif is_data_pad(blk): + schedule_data_pad(blk) + elif is_spatial_block(sch, blk): + try: + if not passed_reduction: + sch.compute_inline(blk) + else: + sch.reverse_compute_inline(blk) + except: # pylint: disable=W0702 + pass + else: + raise TypeError("Can't Schedule this Block", sch.get(blk)) + + schedule_conv2d_blocks() + return sch diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 724921e5fee7..5767fb235099 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -83,6 +83,8 @@ UpdateVDevice, VMBuiltinLower, VMShapeLower, + AnnotateCustomMemoryScope, + SpecializeTIRParams, dataflowblock_pass, function_pass, ) diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py index 5614d0229646..d4a681997b7a 100644 --- a/python/tvm/relax/transform/legalize_ops/__init__.py +++ b/python/tvm/relax/transform/legalize_ops/__init__.py @@ -32,3 +32,6 @@ from . import statistical from . import unary from . import vision + +# Device specific legalizations +from . import adreno diff --git a/python/tvm/relax/transform/legalize_ops/adreno/__init__.py b/python/tvm/relax/transform/legalize_ops/adreno/__init__.py new file mode 100644 index 000000000000..f2b3f4a781d2 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/adreno/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Legalize high-level operator calls in Relax functions to call_tir.""" +from .convolution import conv2d_NCHWc_OIHWo diff --git a/python/tvm/relax/transform/legalize_ops/adreno/convolution.py b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py new file mode 100644 index 000000000000..8d6a871e4d1c --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +"""A Convolution impl for Adreno GPU.""" + +from tvm import relax +from tvm import topi + +def conv2d_NCHWc_OIHWo(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: + return bb.call_te( + topi.nn.conv2d_NCHWc_OIHWo, + data=call.args[0], + kernel=call.args[1], + stride=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + layout=call.attrs.data_layout, + out_layout=call.attrs.out_layout, + #out_dtype=call.attrs.out_dtype, + primfunc_name_hint="conv2d_NCHWc_OIHWo", + ) diff --git a/python/tvm/relax/transform/optimize_batchnorm.py b/python/tvm/relax/transform/optimize_batchnorm.py new file mode 100644 index 000000000000..51b4b1e73319 --- /dev/null +++ b/python/tvm/relax/transform/optimize_batchnorm.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local +"""Relax Optimize Batchnorm to fold it into previous Conv pass.""" +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext +from tvm.relax import Expr +from tvm.relax.dpl import is_op, rewrite_call, wildcard, is_const, TupleGetItemPattern +from tvm import relax, tir + +from . import function_pass + + +@function_pass(opt_level=0) +class OptimizeBatchnorm: + """ + Fuse Batchnorm to its previous Conv2D + """ + + def __init__(self): + self.input = wildcard() + self.weight = is_const() + self.pattern_conv2d = is_op("relax.nn.conv2d")(self.input, self.weight) + self.bn_weight = is_const() + self.bias = is_const() + self.mean = is_const() + self.variance = is_const() + self.pattern_bn = is_op("relax.nn.batch_norm")( + self.pattern_conv2d, + self.bn_weight, + self.bias, + self.mean, + self.variance + ) + + self.pattern = TupleGetItemPattern(self.pattern_bn, 0) + + def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule: + """ + Tranformation function to pattern Conv2D+BatchNorm+TupleGetItem pattern + + Parameters + ---------- + func: Expr + The relax function to be optimized + + mod: IRModule + The ir module + + ctx: PassContext + Relax pass context + """ + + self.mod = mod + updated_call = func + + # Skip primitive functions + if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0: + return updated_call + + def rewriter(expr, matches): + conv_input = matches[self.input] + conv_weight = matches[self.weight] + bn_weight = matches[self.bn_weight] + bn_bias = matches[self.bias] + bn_mean = matches[self.mean] + bn_variance = matches[self.variance] + conv_op = matches[self.pattern_conv2d] + bn_op = matches[self.pattern_bn] + conv_attrs = conv_op.attrs + bn_attrs = bn_op.attrs + + bn_variance = relax.op.add( + bn_variance, + relax.PrimValue(tir.FloatImm("float32", bn_attrs['epsilon'])) + ) + dino = relax.op.sqrt(bn_variance) + wt = relax.op.divide(bn_weight, dino) + bs = relax.op.subtract(bn_bias, relax.op.multiply(bn_mean, wt)) + if conv_attrs["kernel_layout"] == "OIHW": + wt = relax.op.reshape(wt, shape=(bn_weight.struct_info.shape[0], 1, 1, 1)) + elif conv_attrs["kernel_layout"] == "IOHW": + wt = wt.reshape(1, bn_weight.struct_info.shape[0], 1, 1) + else: + return expr + wt_conv = relax.op.multiply(conv_weight, wt) + bs_args = relax.op.reshape(bs, shape=(1, bn_bias.struct_info.shape[0] , 1, 1)) + + conv_out = relax.Call(conv_op.op, (conv_input, wt_conv), conv_attrs) + return relax.op.add(conv_out, bs_args) + + updated_call = rewrite_call(self.pattern, rewriter, func) + + return updated_call diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index b3c4e7110157..cc170bd38c9b 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -30,6 +30,7 @@ from tvm.relax.dpl import DFPattern from tvm.runtime import Tensor, Object from tvm.tir import IndexMap, PrimFunc +from tvm.target import Target from . import _ffi_api from .legalize_ops.common import LegalizeFunc @@ -1605,6 +1606,30 @@ def AllocateWorkspace() -> tvm.ir.transform.Pass: return _ffi_api.AllocateWorkspace() # type: ignore +def AnnotateCustomMemoryScope(target: Optional[Target] = None) -> tvm.ir.transform.Pass: + """Allocate the memory scope information. This is Adreno specific pass to annotate + The memory scope information and realize the same with RealizeVDevice pass followed by + updating the Prim Function var_buffer mapping using SpecializeTIRParams. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.AnnotateCustomMemoryScope(target) # type: ignore + + +def SpecializeTIRParams() -> tvm.ir.transform.Pass: + """Map modified tir_call params to prim_func buffers. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.SpecializeTIRParams() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 915b7f765c10..ef2da05c347c 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -18,7 +18,7 @@ # pylint: disable=invalid-name from typing import Dict, List, Optional, Union -import tvm +from tvm import Object, _ffi from tvm.ir import IRModule from tvm.tir.expr import Var from tvm.tir.stmt import Block, BufferRegion, PrimExpr @@ -301,6 +301,10 @@ def find_anchor_block(mod: IRModule) -> Block: return _ffi_api.find_anchor_block(mod) # type: ignore # pylint: disable=no-member +def has_if_then_else(stmt: Stmt) -> bool: + return _ffi.get_global_func("tir.schedule.HasIfThenElse")(stmt) + + def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]: """Utility function to get the list of lowering passes to be applied to calculate the compacted VTCM allocation size diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 531c0a6c6663..ce14df8beddf 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -394,6 +394,135 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou ) +def conv2d_NCHWc_OIHWo( + data: te.Tensor, kernel, stride, padding, dilation, layout, out_layout, out_dtype="float32" +): + """Conv2D operator for nChw[x]c layout. + + Parameters + ---------- + data : tvm.te.Tensor + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + kernel : tvm.te.Tensor + 6-D with shape + [num_filter_chunk, in_channel_chunk, filter_height, filter_width, + num_filter_block] + + stride : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + layout : str + Input data layout + + out_layout : str + Output data layout + + out_dtype : str + output data type + + Returns + ------- + output : tvm.te.Tensor + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride) + dilation_h, dilation_w = ( + dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + ) + + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + kernel_shape = get_const_tuple(kernel.shape) + if len(kernel_shape) == 6: # OIHW4i4o + oc_chunk, ic_chunk_group, kernel_height, kernel_width, kernel_ic_bn, oc_bn = kernel_shape + groups = in_channel // (ic_chunk_group * kernel_ic_bn) + else: # OIHW4o + oc_chunk, ic, kernel_height, kernel_width, oc_bn = kernel_shape + groups = in_channel // ic + + num_filter = oc_chunk * oc_bn + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + HPAD = pad_top + pad_down + WPAD = pad_left + pad_right + + # output shape + out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1 + out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1 + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + pad_before = (0, 0, pad_top, pad_left, 0) + pad_after = (0, 0, pad_down, pad_right, 0) + + # DOPAD + DOPAD = HPAD != 0 or WPAD != 0 + if DOPAD: + data_pad = pad(data, pad_before, pad_after, name="conv2d_data_pad") + else: + data_pad = data + + kh = te.reduce_axis((0, kernel_height), name="kh") + kw = te.reduce_axis((0, kernel_width), name="kw") + + idxdiv = tvm.tir.indexdiv + idxmod = tvm.tir.indexmod + + def compute_conv2d(*args): + n, occ, oh, ow, ocb = args + ic = te.reduce_axis((0, in_channel // groups), name="ic") + if groups == 1: + data_pad_ = data_pad[ + n, + idxdiv(ic, ic_bn), + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + idxmod(ic, ic_bn), + ] + else: + data_pad_ = data_pad[ + n, + (occ // (oc_chunk // groups)) * (ic_chunk // groups) + idxdiv(ic, ic_bn), + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + idxmod(ic, ic_bn), + ] + if len(kernel_shape) == 5: + kernel_ = kernel[occ, ic, kh, kw, ocb] + else: + kernel_ = kernel[occ, idxdiv(ic, oc_bn), kh, kw, idxmod(ic, oc_bn), ocb] + + if out_dtype is not None: + data_pad_ = data_pad_.astype(out_dtype) + kernel_ = kernel_.astype(out_dtype) + + return te.sum( + data_pad_ * kernel_, + axis=[ic, kh, kw], + ) + + return te.compute( + oshape, + lambda *indices: compute_conv2d(*indices), # pylint: disable=W0108 + name="conv2d_NCHWc_OIHWo", + tag="conv2d_NCHWc_OIHWo", + ) + + def conv2d_NCHWc_int8( data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32", n_elems=4 ): diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 4f3c3382536c..49e92719ba15 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -319,6 +319,8 @@ InferLayoutOutput InferLayoutConv2d( ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout, weight_layout, output_layout; + data_layout = GetLayoutDecision(var_layout_map, call->args[0]); + weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); ObjectPtr new_attrs = ffi::make_object(*attrs); if (it != desired_layouts.end()) { @@ -366,14 +368,16 @@ InferLayoutOutput InferLayoutConv2d( new_attrs->kernel_layout = (*it).second[1]; new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); + } else { + data_layout = LayoutDecision(InitialLayout(4)); + weight_layout = LayoutDecision(InitialLayout(4)); } } } // We don't have a desired layout for conv2d or desired layouts not compatible. // We can just propagate the layout from the input. - data_layout = GetLayoutDecision(var_layout_map, call->args[0]); - weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); + output_layout = data_layout; new_attrs->data_layout = TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name(); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index d91c19b63fd2..d732250cf9ec 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1506,11 +1506,12 @@ TVM_REGISTER_OP("relax.hint_on_device") .set_attr("FInferStructInfo", InferHintOnDeviceStructInfo) .set_attr("FPurity", Bool(true)); -Expr MakeHintOnDevice(Expr data, Device device) { +Expr MakeHintOnDevice(Expr data, VDevice vdevice) { static const Op& op = Op::Get("relax.hint_on_device"); ObjectPtr attrs = ffi::make_object(); attrs->device_type = static_cast(device.device_type); attrs->index = device.device_id; + attrs->memory_scope = vdevice->memory_scope; return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index eeb4d552e787..7051d2b1b975 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -158,14 +158,19 @@ InferLayoutOutput InferLayoutBinaryEwise( ffi::Optional shape1 = ffi::GetRef(x1_sinfo->shape.as()); ffi::Optional shape2 = ffi::GetRef(x2_sinfo->shape.as()); // Lets handle sub indexing as long as primal dims are matching - if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) { - if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) { - if (CanProveLayoutTransform(layout2->layout, layout1->layout, shape2.value()->values)) { - return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs)); - } - } else if (shape1.defined()) { - if (CanProveLayoutTransform(layout1->layout, layout2->layout, shape1.value()->values)) { - return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs)); + if ((layout1->layout.ndim() != layout1->layout.ndim_primal()) || + (layout2->layout.ndim() != layout2->layout.ndim_primal())) { + if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) { + if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) { + if (CanProveLayoutTransform(InitialLayout(shape2.value()->values.size()), layout1->layout, + shape2.value()->values)) { + return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs)); + } + } else if (shape1.defined()) { + if (CanProveLayoutTransform(InitialLayout(shape1.value()->values.size()), layout2->layout, + shape1.value()->values)) { + return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs)); + } } } } diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 79c0687cada5..7c2c220dcf44 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -334,12 +334,54 @@ InferLayoutOutput InferLayoutConcat( const auto* attrs = call->attrs.as(); ICHECK(attrs != nullptr) << "Invalid Call"; + NLayout nlayout = GetNLayout(var_layout_map, call->args[0]); ICHECK(nlayout.IsNested()); ICHECK(nlayout.NestedArray()[0].IsLeaf()); int n_tensor = nlayout.NestedArray().size(); LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); + + // We may expect mix of sub indexed and regular layouts here + // Pick the first sub indexed layout and try to prove it for all tensors + // On any failre select first occuring regular layout for all + auto nlayout_array = nlayout.NestedArray(); + for (auto n_layout : nlayout_array) { + ICHECK(n_layout.IsLeaf()); + LayoutDecision in_layout = n_layout.LeafValue(); + if (in_layout->layout.ndim() != in_layout->layout.ndim_primal()) { + const auto* tuple_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tuple_sinfo != nullptr) + << " expects the input to be a Tuple of Tensors. However, the given input is " + << call->args[0]->struct_info_->GetTypeKey(); + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + StructInfo field_sinfo = tuple_sinfo->fields[i]; + const auto* field_tensor_sinfo = field_sinfo.as(); + ICHECK(field_tensor_sinfo != nullptr) + << call->op + << " expects the input to be a Tuple of Tensors. However, the given input is " + << call->args[0]->struct_info_; + auto t_sinfo = GetRef(field_tensor_sinfo); + Optional t_shape = GetRef(t_sinfo->shape.as()); + LayoutDecision curr_layout = nlayout_array[i].LeafValue(); + if (!CanProveLayoutTransform(curr_layout->layout, in_layout->layout, + t_shape.value()->values)) { + // Some tensor unhappy with sub indexed layout, lets pick first regular layout + for (auto pick_layout : nlayout_array) { + if (pick_layout.LeafValue()->layout.ndim() == + pick_layout.LeafValue()->layout.ndim_primal()) { + in_layout = pick_layout.LeafValue(); + break; + } + } + break; + } + } + layout = in_layout; + break; + } + } + ffi::Array input_layouts, output_layouts; for (int i = 0; i < n_tensor; ++i) { input_layouts.push_back(layout); diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc new file mode 100644 index 000000000000..0abea4636dcc --- /dev/null +++ b/src/relax/transform/annotate_custom_storage.cc @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/annotate_texture_storage.cc + * \brief Texture Storage Annotation Pass. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/tensor/manipulate.h" +#include "infer_layout_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using tvm::tir::Buffer; + +static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { + auto shape = tensor_sinfo->GetShape(); + ICHECK(shape.defined()); + return shape.value(); +} + +class CollectProduserScopeInfo : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + Map Collect(const IRModule& mod, Function func, + const Map>>& scope_info, + const Target& target) { + mod_ = mod; + scope_info_ = scope_info; + target_ = target; + VisitExpr(func->body); + + return producer_sinfo; + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { + ExprVisitor::VisitBinding_(binding, call); + + static const Op& call_tir_op = Op::Get("relax.call_tir"); + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + out_sinfo = call->sinfo_args[0]; + } else { + return; + } + + std::unordered_map scope_count; + + auto arg_var = binding->var.as(); + if (scope_info_.find(GetRef(arg_var)) != scope_info_.end()) { + for (const auto& val : scope_info_[GetRef(arg_var)]) { + auto call_node = Downcast(val.first); + if (scope_count.find(val.second[0]) == scope_count.end()) { + scope_count.insert({val.second[0], 1}); + } else { + auto curr_count = scope_count[val.second[0]]; + scope_count.emplace(val.second[0], curr_count + 1); + } + } + } + String final_scope = "global"; + int count = 0; + for (const auto& sval : scope_count) { + if (sval.second > count) { + final_scope = sval.first; + count = sval.second; + } + } + StructInfo updated_ret_sinfo = UpdateStructInfo(out_sinfo, {final_scope}); + producer_sinfo.Set(GetRef(call), updated_ret_sinfo); + } + + private: + StructInfo UpdateStructInfo(const StructInfo& out_sinfo, Array scope) { + if (out_sinfo->IsInstance()) { + auto tensor_sinfo = Downcast(out_sinfo); + auto shape_arr = GetShapeFromTensorStructInfo(tensor_sinfo); + return TensorStructInfo(ShapeExpr(shape_arr), tensor_sinfo->dtype, + VDevice(target_, 0, scope[0])); + } + + ICHECK(out_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << out_sinfo; + + const auto& tuple_sinfo = Downcast(out_sinfo); + Array sinfo_fields; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + auto shape_arr = GetShapeFromTensorStructInfo(sinfo); + sinfo_fields.push_back( + TensorStructInfo(ShapeExpr(shape_arr), sinfo->dtype, VDevice(target_, 0, scope[0]))); + } + return TupleStructInfo(sinfo_fields); + } + + Map>> scope_info_; + Map producer_sinfo; + IRModule mod_; + Target target_; +}; + +class CollectConsumerScopeInfo : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + std::pair>, Map>>> Collect( + const IRModule& mod, Function func, const Target& target) { + mod_ = mod; + target_ = target; + VisitExpr(func->body); + for (const auto& val : arg_to_binding) { + if (scope_info.find(val.first) != scope_info.end()) { + if (scope_info.find(val.second) == scope_info.end()) { + scope_info.Set(val.second, scope_info[val.first]); + } else { + auto ent = scope_info[val.second]; + for (auto ent_val : scope_info[val.first]) { + ent.Set(ent_val.first, ent_val.second); + } + scope_info.Set(val.second, ent); + } + } + } + + return std::make_pair(call_scope_info, scope_info); + } + + void VisitBinding_(const VarBindingNode* binding, + const TupleGetItemNode* tuple_get_item_node) final { + if (arg_to_binding.find(GetRef(binding->var.get())) == arg_to_binding.end()) { + arg_to_binding.Set(GetRef(binding->var.get()), + GetRef(tuple_get_item_node->tuple.get())); + } + } + + void VisitExpr_(const CallNode* call) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + GlobalVar gv; + Array op_attrs; + Optional op_pattern = Integer(static_cast(relay::kOpaque)); + Tuple func_args; + + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + gv = Downcast(call->args[0]); + tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + op_attrs = ExtractAttrs(pfunc); + op_pattern = ExtractPattern(pfunc); + out_sinfo = call->sinfo_args[0]; + func_args = Downcast(call->args[1]); + } else { + return; + } + + bool is_texture_supported = SupportsTexture(op_attrs, op_pattern.value()); + + Array arg_scope; + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto tensor_sinfo = sinfo.as()) { + auto scope = is_texture_supported + ? Scope(GetShapeFromTensorStructInfo(tensor_sinfo.value())) + : "global"; + Map> ent_call; + const VarNode* arg_var = arg.as(); + if (scope_info.find(GetRef(arg_var)) != scope_info.end()) { + ent_call = scope_info[GetRef(arg_var)]; + } + ent_call.Set(GetRef(call), {scope}); + scope_info.Set(GetRef(arg_var), ent_call); + arg_scope.push_back(scope); + } + } + call_scope_info.Set(GetRef(call), arg_scope); + } + + private: + template + Array ExtractAttrs(const T& func) { + Array op_attrs; + Optional attrs = func->template GetAttr("op_attrs"); + if (attrs) { + if (auto val = attrs.value().as()) { + op_attrs.push_back(val.value()); + } else if (auto val = attrs.value().as>()) { + op_attrs = val.value(); + } + } + return std::move(op_attrs); + } + + template + Optional ExtractPattern(const T& func) { + Optional op_pat = func->template GetAttr("op_pattern"); + return std::move(op_pat); + } + + bool SupportsTexture(const Array& op_attrs, Integer op_pattern) { + if (op_pattern.IntValue() < relay::kCommReduce) return true; + + for (auto attr : op_attrs) { + if (auto conv_attr = attr.as()) { + if (conv_attr->data_layout == "NCHW4c" && conv_attr->kernel_layout == "OIHW4o") { + return true; + } + } else if (auto pool_attrs = attr.as()) { + if (pool_attrs->layout == "NCHW4c") { + return true; + } + } else if (auto avg_attrs = attr.as()) { + if (avg_attrs->layout == "NCHW4c") { + return true; + } + } else if (attr.as()) { + return true; + } + } + + return false; + } + + std::string Scope(Array shape) { + // currently we support only textures been made from 5d tensors + // 5d requirement is not limitation of textures in general, it is limitation how + // we are representing memory scopes/layout and flattening of textures in tir + if (shape.size() == 5 && shape[4].as()->value == 4) { + std::map diffs; + int spatial_limit = + target_->GetAttr("texture_spatial_limit").value_or(Integer(16384))->value; + int depth_limit = + target_->GetAttr("texture_depth_limit").value_or(Integer(2048))->value; + int a0 = shape[0].as()->value; + int a1 = shape[1].as()->value; + int a2 = shape[2].as()->value; + int a3 = shape[3].as()->value; + + int d1r = a0 * a1; + int d2r = a2 * a3; + int d3r = a1 * a2 * a3; + std::string scope = "global"; + if (a0 < spatial_limit && d3r < spatial_limit) + scope += ".texture-weight"; + else if (a0 < depth_limit && a1 < spatial_limit && d2r < spatial_limit) + scope += ".texture-nhwc"; + else if (d1r < depth_limit && a2 < spatial_limit && a3 < spatial_limit) + scope += ".texture"; + return scope; + } + return "global"; + } + + Map>> scope_info; + Map> call_scope_info; + Map arg_to_binding; + IRModule mod_; + Target target_; +}; + +class DefineVDevice : ExprMutator { + public: + explicit DefineVDevice(const Target& target) : target_(target) {} + + IRModule Run(IRModule& mod) { + mod_ = mod; + for (const auto& [gv, func] : mod_->functions) { + if (func->IsInstance()) { + const auto& base_func = mod_->Lookup(gv); + // Only non primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + auto info = CollectConsumerScopeInfo().Collect(mod_, Downcast(func), target_); + call_scope_info_ = info.first; + scope_info_ = info.second; + producer_sinfo_ = CollectProduserScopeInfo().Collect(mod_, Downcast(func), + scope_info_, target_); + relax::Function update_func = Downcast(VisitExpr(func)); + updates_->Add(gv, update_func); + } + } + mod_.CopyOnWrite()->Update(updates_); + + Array global_vdevices_; + for (auto vdev : vdevices_) { + global_vdevices_.push_back(vdev.as().value()); + } + mod_.CopyOnWrite()->global_infos.Set("vdevice", global_vdevices_); + + mod_ = relax::transform::SpecializeTIRParams()(mod_); + mod_ = relax::transform::DeadCodeElimination()(mod_); + mod_ = relax::transform::RealizeVDevice()(mod_); + mod_ = relax::transform::SpecializeTIRParams()(mod_); + + return mod_; + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + + GlobalVar gv; + Tuple func_args; + + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + gv = Downcast(call->args[0]); + tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + out_sinfo = call->sinfo_args[0]; + func_args = Downcast(call->args[1]); + } else { + return call; + } + + Array new_args; + StructInfo updated_ret_sinfo = producer_sinfo_[GetRef(call_node)]; + + int arg_idx = 0; + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto tensor_sinfo = sinfo.as()) { + String scope = "global"; + if (call_scope_info_.find(GetRef(call_node)) != call_scope_info_.end()) { + scope = call_scope_info_[GetRef(call_node)][arg_idx]; + } + new_args.push_back(HintArg(arg, scope)); + arg_idx++; + } else { + new_args.push_back(arg); + } + } + + if (call->op == call_tir_op) { + auto updated_call = + Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo}); + return builder_->Normalize(updated_call); + } else { + auto updated_call = Call(call->op, new_args, call->attrs, {updated_ret_sinfo}); + return builder_->Normalize(updated_call); + } + } + + private: + void AppendToVDevices(VDevice vdev) { + int device_type = vdev->target->GetTargetDeviceType(); + for (auto vdevice : vdevices_) { + int dev_type = vdevice->target->GetTargetDeviceType(); + if (dev_type == device_type && vdevice->vdevice_id == vdev->vdevice_id && + vdevice->memory_scope == vdev->memory_scope) { + return; + } + } + vdevices_.push_back(vdev); + return; + } + + Expr HintArg(const Expr& arg, const String& scope) { + if (arg->IsInstance()) { + if (auto tsinfo = arg->struct_info_.as()) { + if (!tsinfo->vdevice.defined()) { + VDevice vdev = VDevice(target_, 0, scope); + CHECK(tsinfo->shape.defined()) << "Shape not defined for a constant tensor ..!"; + arg->struct_info_ = + TensorStructInfo(tsinfo->shape.value(), tsinfo->dtype, vdev, tsinfo->span); + return arg; + } + } + } + ObjectPtr attrs = make_object(); + attrs->dev_type = target_->GetTargetDeviceType(); + attrs->dev_id = 0; + attrs->memory_scope = scope; + + Expr new_arg = Call(hint_on_device_op_, {arg}, Attrs{std::move(attrs)}, {}); + + AppendToVDevices(VDevice(target_, 0, scope)); + return std::move(new_arg); + } + + Optional GetTarget(const StructInfo& sinfo) { + auto tinfo = sinfo.as(); + if (tinfo->vdevice.defined()) { + auto vdevice = tinfo->vdevice.value(); + if (vdevice->target.defined()) { + return vdevice->target; + } + } + return NullOpt; + } + + const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); + IRModule mod_; + IRModule updates_; + Target target_; + Array vdevices_; + Map>> scope_info_; + Map producer_sinfo_; + Map> call_scope_info_; +}; + +namespace transform { + +Pass AnnotateCustomMemoryScope(Target target) { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return relax::DefineVDevice(target).Run(mod); }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"AnnotateCustomMemoryScope", + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AnnotateCustomMemoryScope") + .set_body_typed(AnnotateCustomMemoryScope); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index ba4515faf390..3ed14fa85307 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -642,6 +642,14 @@ class FusedTIRConstructor : public ExprVisitor { // Step 1. Get Global var and PrimFunc GlobalVar gv = Downcast(call->args[0]); tir::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); + if (prim_func_->GetAttr("op_attrs")) { + func_info_.op_attrs.push_back(prim_func_->GetAttr("op_attrs").value()); + } + + if (prim_func_->GetAttr("op_pattern")) { + auto op_pattern = prim_func_->GetAttr("op_pattern").value(); + func_info_.op_pattern.push_back(static_cast(op_pattern.IntValue())); + } // Step 2. Renew all vars/buffer definitions and blocks to avoid duplication tir::PrimFunc prim_func = tir::RenewDefs(prim_func_); @@ -959,6 +967,12 @@ class FusedTIRConstructor : public ExprVisitor { tir::PrimFunc ConstructFunc() { ffi::Map attr_map; attr_map.Set(tir::attr::kNoAlias, true); + attr_map.Set("op_attrs", func_info_.op_attrs); + int op_pattern = relay::kOpaque; + if (!func_info_.op_pattern.empty()) { + op_pattern = *max_element(func_info_.op_pattern.begin(), func_info_.op_pattern.end()); + } + attr_map.Set("op_pattern", Integer(static_cast(op_pattern))); tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, func_info_.symbolic_var_remap); ICHECK(func_info_.global_name != "fused"); // Remove output buffers from func_info_.alloc_buffers @@ -1010,6 +1024,8 @@ class FusedTIRConstructor : public ExprVisitor { ffi::Array alloc_buffers; /*! \brief The bodies of the original funcs, which is also the body of the fused func. */ ffi::Array bodies; + ffi::Array op_attrs; + std::vector op_pattern; /*! \brief The params of the fused function*/ ffi::Array params; /*! diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 64ac5e86fb48..e8358798dddc 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -154,6 +154,32 @@ class LegalizeMutator : public ExprMutator { return std::nullopt; } + Expr AttributeOpAttrs(Expr expr, Attrs attrs) { + if (!expr->IsInstance()) { + return expr; + } + + auto call = Downcast(expr); + if (call->args.empty()) { + return expr; + } + + auto gvar = call->args[0].as(); + if (!gvar.defined()) { + return expr; + } + + auto base_func = builder_->GetContextIRModule()->Lookup(gvar.value()); + auto opt_prim_func = base_func.as(); + if (!opt_prim_func) { + return expr; + } + auto prim_func = opt_prim_func.value(); + auto new_prim_func = WithAttr(prim_func, "op_attrs", attrs); + builder_->UpdateFunction(gvar.value(), new_prim_func); + return call; + } + Expr BindTarget(Expr expr) { if (!expr->IsInstance()) { // FLegalize returned something other than a relax::Call. This @@ -344,6 +370,8 @@ class LegalizeMutator : public ExprMutator { } Expr legalized = legalization_func(builder_, visited_call); + legalized = AttributeOpAttrs(legalized, call->attrs); + // Append the target attribute to any PrimFunc generated in // legalization. legalized = BindTarget(legalized); diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index 79c1bf36b549..b9456052264a 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -56,6 +56,7 @@ class VDeviceLookup { ICHECK(attrs); int32_t device_type = attrs->device_type; int32_t device_id = attrs->index; + String memory_scope = attrs->memory_scope; CHECK(opt_vdevices_.defined()) << "ValueError: The target VDevice in the GlobalInfos was not found."; @@ -66,7 +67,8 @@ class VDeviceLookup { for (auto vdevice : vdevices) { int dev_type = vdevice->target->GetTargetDeviceType(); - if (dev_type == device_type && vdevice->vdevice_id == device_id) { + if (dev_type == device_type && vdevice->vdevice_id == device_id && + memory_scope == vdevice->memory_scope) { return vdevice; } } diff --git a/src/relax/transform/specialize_tir_params.cc b/src/relax/transform/specialize_tir_params.cc new file mode 100644 index 000000000000..7df1c1579576 --- /dev/null +++ b/src/relax/transform/specialize_tir_params.cc @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/specialize_tir_params.cc + * \brief Update PrimFunc buffers based on updated scope (or structure) info. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/tensor/manipulate.h" +#include "infer_layout_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using tvm::tir::Buffer; + +static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { + auto shape = tensor_sinfo->GetShape(); + ICHECK(shape.defined()); + return shape.value(); +} + +class SpecializeTIRCallArgs : ExprMutator { + public: + IRModule Run(IRModule mod) { + mod_ = mod; + for (const auto& [gv, func] : mod->functions) { + if (func->IsInstance()) { + const auto& base_func = mod->Lookup(gv); + // Only non primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + relax::Function update_func = Downcast(VisitExpr(func)); + updates_->Add(gv, update_func); + } + } + mod_.CopyOnWrite()->Update(updates_); + return mod_; + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (call->op == call_tir_op) { + return SpecializeTirPrimFunc(call); + } + return call; + } + + private: + Expr SpecializeTirPrimFunc(Call call) { + auto gv = Downcast(call->args[0]); + auto pfunc = Downcast(mod_->Lookup(gv)); + auto args = Downcast(call->args[1])->fields; + Map> param_map; + + for (size_t i = 0; i < args.size(); ++i) { + auto sinfo = GetStructInfo(args[i]); + CHECK(sinfo->IsInstance()) + << "Expected Tensor struct Info for call :" << call->op; + auto tensor_sinfo = Downcast(sinfo); + CHECK(tensor_sinfo->shape.defined()) << "Shape undefined for call:" << call->args[0]; + String scope = "global"; + if (tensor_sinfo->vdevice.defined()) { + scope = tensor_sinfo->vdevice.value()->memory_scope; + } + String name; + if (args[i]->IsInstance()) { + name = Downcast(args[i])->name_hint(); + } else { + name = std::string({static_cast('A' + i)}); + } + + const Buffer& buffer = tir::decl_buffer(GetShapeFromTensorStructInfo(tensor_sinfo), + tensor_sinfo->dtype, name, scope); + param_map.Set(pfunc->params[i], buffer); + } + String scope = "global"; + auto out_sinfo = call->sinfo_args[0]; + if (out_sinfo->IsInstance()) { + auto sinfo = Downcast(out_sinfo); + if (sinfo->vdevice.defined()) { + scope = sinfo->vdevice.value()->memory_scope; + } + const Buffer& buffer = + tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); + param_map.Set(pfunc->params[pfunc->params.size() - 1], buffer); + } else { + ICHECK(out_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << out_sinfo; + + const auto& tuple_sinfo = Downcast(out_sinfo); + Array sinfo_fields; + int index = 0; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + if (sinfo->vdevice.defined()) { + scope = sinfo->vdevice.value()->memory_scope; + } + const Buffer& buffer = + tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); + param_map.Set(pfunc->params[args.size() + index], buffer); + index++; + } + } + + auto new_pfunc = Specialize(pfunc, param_map); + for (const auto& [var, buffer] : new_pfunc->buffer_map) { + auto* ptr = buffer->data->type_annotation.as(); + ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + } + auto new_prim_func = WithAttr(new_pfunc, "scoped", Integer(1)); + updates_->Add(gv, new_prim_func); + return call; + } + IRModule mod_; + IRModule updates_; +}; + +namespace transform { + +Pass SpecializeTIRParams() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return relax::SpecializeTIRCallArgs().Run(mod); }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"SpecializeTIRParams", + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.SpecializeTIRParams").set_body_typed(SpecializeTIRParams); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index ff8596cd79e3..7ecbdcc5af5b 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -386,7 +386,7 @@ inline ffi::String GetCodegenName(const std::string& composite_name) { inline int GetDeviceIndex(const IRModule& mod, const VDevice& vdevice) { ffi::Array vdevices = mod->global_infos["vdevice"]; for (int i = 0; i < static_cast(vdevices.size()); ++i) { - if (vdevices[i] == vdevice) { + if (vdevices[i].as() == vdevice) { return i; } } diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 666b3839ea0e..62a6db1f03ed 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -217,7 +217,8 @@ ffi::Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_ int dev_index = FindVDeviceIndexByTargetKind(vdev, d); kwargs_keys.push_back("dst_vdevice"); kwargs_values.push_back( - LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index), n_p->Attr("dst_vdevice"))); + LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index) + ":" + vdev->memory_scope, + n_p->Attr("dst_vdevice"))); } return Relax(d, "to_vdevice")->Call(args, kwargs_keys, kwargs_values); } diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc index d6e2ac0f13f5..e597df64501d 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -126,8 +126,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kwargs_keys.push_back("vdevice"); std::string dev_kind = n->vdevice.value()->target->kind->name; int dev_index = FindVDeviceIndexByTargetKind(n->vdevice.value(), d); - kwargs_values.push_back( - LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index), n_p->Attr("vdevice"))); + kwargs_values.push_back(LiteralDoc::Str( + dev_kind + ":" + std::to_string(dev_index) + ":" + n->vdevice.value()->memory_scope, + n_p->Attr("vdevice"))); } if (args.empty() && kwargs_keys.empty()) { return Relax(d, "Tensor"); diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 7dddfaecbbe7..df8f6b2d470d 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -142,7 +142,7 @@ inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifie int kind_index = 0; for (size_t i = 0; i < vdevices.size(); ++i) { auto vdev = Downcast(vdevices[i]); - if (vdev.same_as(vdevice)) { + if (vdev == vdevice) { return kind_index; } if (vdev->target->kind->name == vdevice->target->kind->name) { diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index b0d712b5acc7..7cc15a407ead 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2164,6 +2164,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else { return "O"; } + return HasIfThenElse(stmt); }); } diff --git a/tests/python/relax/adreno/test_transform_annotate_custom_scope1.py b/tests/python/relax/adreno/test_transform_annotate_custom_scope1.py new file mode 100644 index 000000000000..8c629eecc8fc --- /dev/null +++ b/tests/python/relax/adreno/test_transform_annotate_custom_scope1.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import ir as I, relax as R, tir as T +from tvm.relax.transform.legalize_ops import adreno as legalize_adreno +from tvm.ir.module import IRModule +from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor + + +@visitor +class ValidateScope(PyExprVisitor): # pylint: disable=abstract-method + def __init__(self, scope_info: dict) -> None: + self.scope_info = scope_info + self.matched = True + + def visit(self, mod: IRModule) -> None: + """Entry point""" + for _, func in mod.functions_items(): + if isinstance(func, relax.Function): + self.visit_expr(func) + return self.matched + + def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed + if call.op.name == "relax.call_tir": + # if call.args[0].name_hint in self.scope_info: + for idx, arg in enumerate(call.args[1]): + arg_sinfo = arg.struct_info + assert isinstance( + arg_sinfo, relax.TensorStructInfo + ), f"Expected TensorStructInfo but git {type(arg_sinfo)}" + assert ( + arg_sinfo.vdevice.memory_scope + == self.scope_info[call.args[0].name_hint][0][idx] + ), f"Scope mismatched for argument {idx} in {call.args[0].name_hint}" + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + assert ( + call.sinfo_args[0].vdevice.memory_scope + == self.scope_info[call.args[0].name_hint][1][0] + ), f"Scope mismatched for return scope: {call.args[0].name_hint}" + else: + assert isinstance( + call.sinfo_args[0], relax.TupleStructInfo + ), f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" + for idx, sinfo in enumerate(call.sinfo_args[0].fields): + assert ( + sinfo.vdevice.memory_scope + == self.scope_info[call.args[0].name_hint][1][idx] + ), f"Scope mismatched for return scope for {idx} in {call.args[0].name_hint}" + + +def verify(mod, expected): + tgt = tvm.target.Target("opencl --device=adreno", host="llvm") + skip_ops = [ + "relax.nn.conv2d", + "relax.nn.max_pool2d", + "relax.nn.adaptive_avg_pool2d", + # "relax.nn.layer_norm", + ] + with tgt: + mod = tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) + mod = tvm.relax.transform.DecomposeOpsForInference()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + desired_layouts = {"relax.nn.conv2d": ["NCHW4c", "OIHW4o", "NCHW4c"]} + mod = tvm.relax.transform.ConvertLayout(desired_layouts)(mod) + mod = tvm.relax.transform.Normalize()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.LegalizeOps(skip_ops=skip_ops)(mod) + mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) + mod = tvm.relax.backend.adreno.transform.AnnotateCustomMemoryScope(tgt)(mod) + # There is a possibility of some skipped ops above might not use 5D layouts. + mod = tvm.relax.transform.LegalizeOps()(mod) + mod = tvm.relax.transform.LegalizeOps( + {"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo}, + )(mod) + # Lets get pattern info for newly legalized ops + mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.FuseOps()(mod) + mod = tvm.relax.transform.FuseTIR()(mod) + mod = tvm.relax.transform.DeadCodeElimination()(mod) + mod = tvm.relax.backend.adreno.transform.FoldVDeviceScopeChange()(mod) + mod = tvm.relax.transform.DeadCodeElimination()(mod) + mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) + mod = tvm.relax.transform.Normalize()(mod) + print(mod) + ValidateScope(expected).visit(mod) + +def test_conv2d_conv2d_fallback_to_buffer_conv2d(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) <- conv2d (2) emits texture, conv2d (3) emits buffer + \ / <- concat shouldn't support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((5, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform4": (["global"], ["global"]), + "conv2d_opencl": (["global", "global"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), + "concatenate": (["global", "global"], ["global"]), + } + verify(Input, Expected) + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py new file mode 100644 index 000000000000..af923b49c1b6 --- /dev/null +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -0,0 +1,1032 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +import tvm.testing +from tvm.relax.transform import ConvertLayout, Normalize +from tvm.script.parser import ir as I, relax as R, tir as T +from tvm.relax.transform.legalize_ops import adreno as legalize_adreno +from tvm.ir.module import IRModule +from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor + +@visitor +class ValidateScope(PyExprVisitor): # pylint: disable=abstract-method + def __init__(self, scope_info: dict) -> None: + self.scope_info = scope_info + self.matched = True + + def visit(self, mod: IRModule) -> None: + """Entry point""" + print("Mod:", mod["main"]) + for _, func in mod.functions_items(): + if isinstance(func, relax.Function): + self.visit_expr(func) + return self.matched + + + def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed + if (call.op.name == "relax.call_tir"): + #if call.args[0].name_hint in self.scope_info: + for idx, arg in enumerate(call.args[1]): + arg_sinfo = arg.struct_info + assert isinstance(arg_sinfo, relax.TensorStructInfo + ), f"Expected TensorStructInfo but git {type(arg_sinfo)}" + assert (arg_sinfo.vdevice.memory_scope == self.scope_info[call.args[0].name_hint][0][idx] + ), f"Scope mispatched for argument {idx} in {call.args[0].name_hint}" + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + assert(call.sinfo_args[0].vdevice.memory_scope == self.scope_info[call.args[0].name_hint][1][0] + ), f"Scope mispatched for return scope: {call.args[0].name_hint}" + else: + assert isinstance(call.sinfo_args[0], relax.TupleStructInfo + ), f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" + for idx, sinfo in enumerate(call.sinfo_args[0].fields): + assert (sinfo.vdevice.memory_scope == self.scope_info[call.args[0].name_hint][1][idx] + ), f"Scope mispatched for return scope for {idx} in {call.args[0].name_hint}" + + +def verify(mod, expected): + tgt = tvm.target.Target("opencl --device=adreno", host="llvm") + with tgt: + mod = tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) + mod = tvm.relax.transform.OptimizeBatchnorm()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.DecomposeOpsForInference()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + desired_layouts = {"relax.nn.conv2d": ["NCHW4c", "OIHW4o", "NCHW4c"]} + mod = tvm.relax.transform.ConvertLayout(desired_layouts)(mod) + mod = tvm.relax.transform.Normalize()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.LegalizeOps()(mod) + mod = tvm.relax.transform.LegalizeOps({"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo})(mod) + mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.FuseOps()(mod) + mod = tvm.relax.transform.FuseTIR()(mod) + mod = tvm.relax.transform.DeadCodeElimination()(mod) + mod = tvm.relax.transform.AnnotateCustomMemoryScope(tgt)(mod) + mod = Normalize()(mod) + + ValidateScope(expected).visit(mod) + +def test_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 64, 56, 56), "float32"), w: R.Tensor((32, 64, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 32, 54, 54), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + R.output(gv) + return gv + + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-nhwc"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-nhwc", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + + verify(Input, Expected) + +def test_conv2d_NCHW_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d( + x, + w, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + R.output(gv) + return gv + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_NHWC_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), "float32"), w: R.Tensor((4, 3, 3, 16), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 26, 26, 4), "float32") = R.nn.conv2d( + x, + w, + data_layout="NHWC", + kernel_layout="OHWI", + out_dtype="float32", + ) + R.output(gv) + return gv + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + + verify(Input, Expected) + +def _test_conv2d_symbolic_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4) + ) -> R.Tensor("float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + lv1 = R.match_cast(w, R.Tensor((Nw, Cw, Hw, Ww), "float32")) + gv: R.Tensor( + (N, T.int64(4), H + T.int64(1) - Hw, W + T.int64(1) - Ww), "float32" + ) = R.nn.conv2d(lv0, lv1, out_dtype="float32") + R.output(gv) + return gv + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo_relu': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_relu_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 16, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + Expected = { + 'relu': (["global"], ["global"]), + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo_relu1': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_relu_tanh_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2) + R.output(gv3) + return gv3 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo_relu_tir_tanh': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_add_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'te_layout_transform2': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo_add': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform3': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_fma_relu_conv2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'relu': (["global"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + 'te_layout_transform3': (["global"], ["global.texture-weight"]), + 'te_layout_transform4': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo1': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform5': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[2, 3]) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'sum': (["global"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + verify(Input, Expected) + +def test_conv2d_sum_keepdims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 1, 1), "float32") = R.sum(gv, axis=[2, 3], keepdims=True) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'sum': (["global"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_reduce_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 26), "float32") = R.sum(gv, axis=[1, 2]) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'sum': (["global"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_transpose_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv, axes=[3, 2, 1, 0]) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + 'transpose': (["global"], ["global"]), + } + verify(Input, Expected) + +def test_conv2d_expand_dims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=6): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1)) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + 'expand_dims': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_squeeze_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=3): + with R.dataflow(): + gv: R.Tensor((1, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((4, 26, 26), "float32") = R.squeeze(gv, axis=[0]) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + 'squeeze': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_strided_slice_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + gv, begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4], axes=[1, 2, 3] + ) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + 'strided_slice': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + R.output(gv3) + return gv3 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'fused_relu_concatenate': (["global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_split_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + R.output(gv4) + return gv4 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'fused_relu_concatenate_split': (["global.texture-weight"], ["global", "global"]), + 'te_layout_transform2': (["global"], ["global"]), + 'te_layout_transform3': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_split_transpose_concat_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + gv5: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv4[0], axes=[3, 2, 1, 0]) + gv6: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv4[1], axes=[3, 2, 1, 0]) + gv7: R.Tensor((26, 26, 8, 2), "float32") = R.concat((gv5, gv6), axis=2) + R.output(gv7) + return gv7 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'fused_relu_concatenate_split': (["global.texture-weight"], ["global", "global"]), + 'te_layout_transform2': (["global"], ["global"]), + 'te_layout_transform3': (["global"], ["global"]), + 'fused_transpose_transpose1_concatenate1': (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_maxpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0], + layout="NCHW", + out_layout="NCHW", + ) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'max_pool2d': (["global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_avgpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.adaptive_avg_pool2d(gv, output_size=[13, 13], layout="NCHW") + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'adaptive_avg_pool2d': (["global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_softmax_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.softmax(gv, axis=1) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + 'softmax': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_layernorm_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.layer_norm( + gv, gamma, beta, axes=[-2, -1] + ) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'layer_norm': (["global.texture-weight", "global", "global"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_binary_broadcast_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + 'add': (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_binary_ewise_scalar_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, R.const(1, "float32")) + R.output(gv2) + return gv2 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo_add': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform2': (["global"], ["global"]), + } + verify(Input, Expected) + +def test_residual_block(): + """ + - some kind of residual block followed by convolution to have texture after residual block + - scalar data type verification which should be mapped to global memory scope + layout_transform (NCHW->NCHW4c) + | <- buffer + conv2d (1) <- to get textures as output + / \ + conv2d (2) | + \ / + add <- add should be fused into conv2d (2) + multiply to scalar <- buffer to the input of multiply scalar value + relu + | <- texture in intermediate tensor + conv2d (3) + relu + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 2, 2), "float32"), + w2: R.Tensor((32, 32, 1, 1), "float32"), + w3: R.Tensor((32, 32, 2, 2), "float32"), + bias: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[1, 1], out_dtype="float32") + bias_1 = R.multiply(bias, R.const(0.15, "float32")) + gv4 = R.add(gv3, bias_1) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv5, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.nn.relu(gv6) + R.output(gv7) + return gv7 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'te_layout_transform2': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo_add_relu': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'te_layout_transform3': (["global"], ["global.texture-weight"]), + 'multiply': (["global"], ["global"]), + 'te_layout_transform4': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo1_add1_relu1': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'te_layout_transform5': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo2_relu2': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform6': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_conv2d_fallback_to_buffer_conv2d(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) <- conv2d (2) emits texture, conv2d (3) emits buffer + \ / <- concat shouldn't support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((5, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'te_layout_transform2': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo_add_relu': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform3': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo1': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform4': (["global"], ["global"]), + 'conv2d': (["global", "global"], ["global"]), + 'te_layout_transform5': (["global"], ["global"]), + 'concatenate': (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_conv2d_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) + \ / <- concat does support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((8, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'te_layout_transform2': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo_add_relu': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'te_layout_transform3': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo1': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'te_layout_transform4': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo2': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'concatenate': (["global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform5': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_pooling_branching_texture_params(): + """ + Verification of the pooling and many branches having textures + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (0) <- to get textures + | <- textures + pooling + / \ \ <- textures + conv2d (1) conv2d (2) conv2d (3) + \ / | + add | <- to have the only one output, will be fused + \ / + add <- to have the only one output, will be fused + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 1, 1), "float32"), + w2: R.Tensor((32, 32, 2, 2), "float32"), + w3: R.Tensor((32, 32, 1, 1), "float32"), + w4: R.Tensor((32, 32, 2, 2), "float32"), + bias1: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[1, 1], out_dtype="float32") + gv1 = R.nn.max_pool2d(gv, pool_size=[2, 2], strides=[2, 2]) + gv2 = R.nn.conv2d(gv1, w2, padding=[0, 0, 1, 1], strides=[1, 1], out_dtype="float32") + gv3 = R.add(gv2, bias1) + gv4 = R.nn.relu(gv3) + gv5 = R.nn.conv2d(gv1, w3, padding=[0, 0, 0, 0], strides=[1, 1], out_dtype="float32") + gv6 = R.nn.conv2d(gv1, w4, padding=[0, 1, 1, 0], strides=[1, 1], out_dtype="float32") + gv7 = R.nn.relu(gv6) + gv8 = R.add(gv2, gv5) + gv9 = R.add(gv8, gv6) + R.output(gv9) + return gv9 + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'max_pool2d': (["global.texture-weight"], ["global.texture-weight"]), + 'te_layout_transform2': (["global"], ["global.texture-weight"]), + 'te_layout_transform3': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo2': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo1_add': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'te_layout_transform4': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo3_add1': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform5': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_injective_inputs1(): + """ + Input + / \ + / | + | / + conv2d (1) / + | / + conv2d (2) mean / + / \ / + | | \ / + | | (3) add + | | | + | \ / + \ mul + \ / + add + + """ + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 4, 40, 40), "float32"), + w1: R.Tensor((4, 4, 3, 3), "float32"), + w2: R.Tensor((4, 4, 3, 3), "float32"), + w3: R.Tensor((4, 4, 3, 3), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + mean = R.mean(x, axis=1, keepdims=True) + conv1 = R.nn.conv2d(x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32") + conv2 = R.nn.conv2d(conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32") + ad3 = R.add(conv1, conv2) + ad1 = R.add(mean, conv1) + ad2 = R.multiply(ad1, conv1) + gv = R.add(ad3, ad2) + R.output(gv) + return gv + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'te_layout_transform3': (["global"], ["global"]), + 'fused_mean_add1': (["global", "global"], ["global"]), + 'te_layout_transform2': (["global"], ["global.texture-weight"]), + 'te_layout_transform4': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo1_add_multiply_add2': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform5': (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_injective_nwo_inputs2(): + """ + Input + / \ + | \ + conv2d \ + | / + conv2d mean / + / \ / + add | \ | + | | \ | + | | \ / + | | (3) add + | | | + | \ / + | \ / + \ mul + \ / + add + + """ + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 4, 40, 40), "float32"), + w1: R.Tensor((4, 4, 3, 3), "float32"), + w2: R.Tensor((4, 4, 3, 3), "float32"), + w3: R.Tensor((4, 4, 3, 3), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + mean = R.mean(x, axis=1, keepdims=True) + conv1 = R.nn.conv2d(x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32") + conv2 = R.nn.conv2d(conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32") + ad3 = R.add(conv1, conv2) + ad1 = R.add(mean, conv1) + ad2 = R.multiply(ad1, conv2) + gv = R.add(ad2, ad3) + R.output(gv) + return gv + + Expected = { + 'te_layout_transform': (["global"], ["global.texture-weight"]), + 'te_layout_transform1': (["global"], ["global.texture-weight"]), + 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), + 'te_layout_transform3': (["global"], ["global"]), + 'fused_mean_add1': (["global", "global"], ["global"]), + 'te_layout_transform2': (["global"], ["global.texture-weight"]), + 'te_layout_transform4': (["global"], ["global.texture-weight"]), + 'fused_conv2d_NCHWc_OIHWo1_add_multiply_add2': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"]), + 'te_layout_transform5': (["global"], ["global"]), + } + verify(Input, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index 262e37b91b1b..83b81a6898a7 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -206,10 +206,9 @@ def main( lv2: R.Tensor((N, H, W, C), dtype="float32") = R.match_cast( lv0, R.Tensor((N, H, W, C), dtype="float32") ) - lv3: R.Tensor((N, C, H, W), dtype="float32") = R.permute_dims( - lv2, axes=[0, 3, 1, 2] - ) - gv: R.Tensor(dtype="float32", ndim=4) = R.add(lv3, w) + lv3: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv4: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv3) + gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv4, axes=[0, 3, 1, 2]) R.output(gv) return gv @@ -4585,5 +4584,413 @@ def main( verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) +def test_conv2d_conv2d_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) + \ / <- concat does support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((8, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((96, 32, 2, 2), dtype="float32"), + w2: R.Tensor((32, 96, 2, 2), dtype="float32"), + w3: R.Tensor((8, 96, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 96, 1, 1), dtype="float32"), + bias2: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 40, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((24, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 24, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv1: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.relu(gv1) + lv3: R.Tensor((8, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv3, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv4: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv4: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.add(gv3, lv4) + gv5: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.relu(gv4) + lv5: R.Tensor((2, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w3, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 2, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv5, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv6: R.Tensor((2, 10, 10, 10, 4), dtype="float32") = R.concat((gv3, gv6), axis=1) + gv7: R.Tensor((2, 40, 10, 10), dtype="float32") = R.layout_transform( + lv6, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv7) + return gv7 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_conv2d_callback_to_buffer_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) <- conv2d (2) emits texture, conv2d (3) emits buffer + \ / <- concat shouldn't support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((5, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((96, 32, 2, 2), dtype="float32"), + w2: R.Tensor((32, 96, 2, 2), dtype="float32"), + w3: R.Tensor((5, 96, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 96, 1, 1), dtype="float32"), + bias2: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 37, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((24, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 24, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv1: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.relu(gv1) + lv3: R.Tensor((8, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv3, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv4: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv4: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.add(gv3, lv4) + gv5: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.relu(gv4) + lv5: R.Tensor((2, 96, 20, 20), dtype="float32") = R.layout_transform( + gv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 5, 10, 10), dtype="float32") = R.nn.conv2d( + lv5, + w3, + strides=[2, 2], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv6: R.Tensor((2, 32, 10, 10), dtype="float32") = R.layout_transform( + gv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv7: R.Tensor((2, 37, 10, 10), dtype="float32") = R.concat((lv6, gv6), axis=1) + R.output(gv7) + return gv7 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_pooling_branching_texture_params(): + """ + Verification of the pooling and many branches having textures + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (0) <- to get textures + | <- textures + pooling + / \ \ <- textures + conv2d (1) conv2d (2) conv2d (3) + \ / | + add | <- to have the only one output, will be fused + \ / + add <- to have the only one output, will be fused + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 1, 1), "float32"), + w2: R.Tensor((32, 32, 2, 2), "float32"), + w3: R.Tensor((32, 32, 1, 1), "float32"), + w4: R.Tensor((32, 32, 2, 2), "float32"), + bias1: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[1, 1], out_dtype="float32") + gv1 = R.nn.max_pool2d(gv, pool_size=[2, 2], strides=[2, 2]) + gv2 = R.nn.conv2d( + gv1, w2, padding=[0, 0, 1, 1], strides=[1, 1], out_dtype="float32" + ) + gv3 = R.add(gv2, bias1) + gv4 = R.nn.relu(gv3) + gv5 = R.nn.conv2d( + gv1, w3, padding=[0, 0, 0, 0], strides=[1, 1], out_dtype="float32" + ) + gv6 = R.nn.conv2d( + gv1, w4, padding=[0, 1, 1, 0], strides=[1, 1], out_dtype="float32" + ) + gv7 = R.nn.relu(gv6) + gv8 = R.add(gv2, gv5) + gv9 = R.add(gv8, gv6) + R.output(gv9) + return gv9 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((32, 32, 1, 1), dtype="float32"), + w2: R.Tensor((32, 32, 2, 2), dtype="float32"), + w3: R.Tensor((32, 32, 1, 1), dtype="float32"), + w4: R.Tensor((32, 32, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 32, 20, 20), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((8, 32, 1, 1, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv1: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.max_pool2d( + gv, pool_size=[2, 2], strides=[2, 2], layout="NCHW4c", out_layout="NCHW4c" + ) + lv2: R.Tensor((8, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv2, + padding=[0, 0, 1, 1], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv3: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv2, lv3) + gv4: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.relu(gv3) + lv4: R.Tensor((8, 32, 1, 1, 4), dtype="float32") = R.layout_transform( + w3, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv5: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv4, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv5: R.Tensor((8, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w4, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv5, + strides=[1, 1], + padding=[0, 1, 1, 0], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv7: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.relu(gv6) + gv8: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv2, gv5) + lv6: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv8, gv6) + gv9: R.Tensor((2, 32, 20, 20), dtype="float32") = R.layout_transform( + lv6, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv9) + return gv9 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + if __name__ == "__main__": tvm.testing.main() From f8e0aaec684bd491dbad3f8a39a822e2ba9b4453 Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 21 Jan 2025 22:08:16 +0530 Subject: [PATCH 02/31] lint --- tests/python/relax/test_transform_annotate_custom_scope.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index af923b49c1b6..5041a89bfc3f 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -929,7 +929,7 @@ def test_injective_inputs1(): conv2d (1) / | / conv2d (2) mean / - / \ / + / \ / | | \ / | | (3) add | | | @@ -976,12 +976,12 @@ def main( def test_injective_nwo_inputs2(): """ Input - / \ + / \ | \ conv2d \ | / conv2d mean / - / \ / + / \ / add | \ | | | \ | | | \ / From 2605777b2dd967cb2b450d90d8c3797626accc9a Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 22 Jan 2025 21:44:33 +0530 Subject: [PATCH 03/31] Optional attr addition in legalization --- include/tvm/relax/transform.h | 5 +- python/tvm/dlight/adreno/convolution.py | 2 +- python/tvm/relax/op/base.py | 9 +- .../legalize_ops/adreno/convolution.py | 3 +- .../tvm/relax/transform/optimize_batchnorm.py | 11 +- python/tvm/relax/transform/transform.py | 12 +- src/relax/op/op.cc | 14 +- src/relax/transform/fuse_tir.cc | 8 +- src/relax/transform/legalize_ops.cc | 19 +- src/script/printer/relax/call.cc | 4 + tests/python/relax/test_transform.py | 1 + .../test_transform_annotate_custom_scope.py | 539 +++++++++++------- tests/python/relax/test_transform_fuse_tir.py | 4 +- .../test_tvmscript_parser_op_manipulate.py | 15 + 14 files changed, 404 insertions(+), 242 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index cf7989c1a14d..0f6a4f25caa5 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -246,10 +246,11 @@ TVM_DLL Pass FoldConstant(); * will override the default one. * \param enable_warning A boolean value indicating if to print warnings for TIR functions not * showing up in the database. + * \param add_attributes A boolean value indicating adding of call attributes to TIR functions * \return The Pass. */ -TVM_DLL Pass LegalizeOps(ffi::Optional> cmap, - bool enable_warning = false); +TVM_DLL Pass LegalizeOps(ffi::Optional> cmap, bool enable_warning = false + bool add_attributes = false); /*! * \brief Propagate virtual device information. diff --git a/python/tvm/dlight/adreno/convolution.py b/python/tvm/dlight/adreno/convolution.py index 16075f814df2..f084885dad73 100644 --- a/python/tvm/dlight/adreno/convolution.py +++ b/python/tvm/dlight/adreno/convolution.py @@ -221,7 +221,7 @@ def schedule_conv2d_blocks(): sch.compute_inline(blk) else: sch.reverse_compute_inline(blk) - except: # pylint: disable=W0702 + except: # pylint: disable=W0702 pass else: raise TypeError("Can't Schedule this Block", sch.get(blk)) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index e205abde30b4..ffa19fbaa060 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -849,7 +849,7 @@ def to_vdevice(data, dst_vdevice) -> Expr: return _ffi_api.to_vdevice(data, dst_vdevice) # type: ignore -def hint_on_device(data, dst_vdevice) -> Expr: +def hint_on_device(data, dst_vdevice, memory_scope="global") -> Expr: """It provides a hint specifying the device on which the input data should be executed. This hint is utilized by RealizeVDevice to propagate the virtual device." @@ -858,12 +858,15 @@ def hint_on_device(data, dst_vdevice) -> Expr: data : Expr The tensor to be copied. - dst_device : VDevice + dst_device : Device The destination device where the data is supposed to be executed. + memory_scope: String + Memory scope of buffer on target device. + Returns ------- result : Expr The result. """ - return _ffi_api.hint_on_device(data, dst_vdevice) # type: ignore + return _ffi_api.hint_on_device(data, dst_vdevice, memory_scope) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/adreno/convolution.py b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py index 8d6a871e4d1c..eb0bf30cfbf2 100644 --- a/python/tvm/relax/transform/legalize_ops/adreno/convolution.py +++ b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py @@ -20,6 +20,7 @@ from tvm import relax from tvm import topi + def conv2d_NCHWc_OIHWo(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: return bb.call_te( topi.nn.conv2d_NCHWc_OIHWo, @@ -30,6 +31,6 @@ def conv2d_NCHWc_OIHWo(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: dilation=call.attrs.dilation, layout=call.attrs.data_layout, out_layout=call.attrs.out_layout, - #out_dtype=call.attrs.out_dtype, + # out_dtype=call.attrs.out_dtype, primfunc_name_hint="conv2d_NCHWc_OIHWo", ) diff --git a/python/tvm/relax/transform/optimize_batchnorm.py b/python/tvm/relax/transform/optimize_batchnorm.py index 51b4b1e73319..9eee18c0d032 100644 --- a/python/tvm/relax/transform/optimize_batchnorm.py +++ b/python/tvm/relax/transform/optimize_batchnorm.py @@ -40,11 +40,7 @@ def __init__(self): self.mean = is_const() self.variance = is_const() self.pattern_bn = is_op("relax.nn.batch_norm")( - self.pattern_conv2d, - self.bn_weight, - self.bias, - self.mean, - self.variance + self.pattern_conv2d, self.bn_weight, self.bias, self.mean, self.variance ) self.pattern = TupleGetItemPattern(self.pattern_bn, 0) @@ -85,8 +81,7 @@ def rewriter(expr, matches): bn_attrs = bn_op.attrs bn_variance = relax.op.add( - bn_variance, - relax.PrimValue(tir.FloatImm("float32", bn_attrs['epsilon'])) + bn_variance, relax.PrimValue(tir.FloatImm("float32", bn_attrs["epsilon"])) ) dino = relax.op.sqrt(bn_variance) wt = relax.op.divide(bn_weight, dino) @@ -98,7 +93,7 @@ def rewriter(expr, matches): else: return expr wt_conv = relax.op.multiply(conv_weight, wt) - bs_args = relax.op.reshape(bs, shape=(1, bn_bias.struct_info.shape[0] , 1, 1)) + bs_args = relax.op.reshape(bs, shape=(1, bn_bias.struct_info.shape[0], 1, 1)) conv_out = relax.Call(conv_op.op, (conv_input, wt_conv), conv_attrs) return relax.op.add(conv_out, bs_args) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index cc170bd38c9b..fcb010a7db07 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1063,7 +1063,9 @@ def BundleModelParams(param_tuple_name: Optional[str] = None) -> tvm.ir.transfor def LegalizeOps( - customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None, enable_warning: bool = False + customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None, + enable_warning: bool = False, + add_attributes: bool = False, ): """Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs. @@ -1094,6 +1096,10 @@ def LegalizeOps( legalization function is not registered. By default we don't print warnings. + add_attributes : bool + A boolean value indicating if we want legalize ops to add operator attributes to legalized + prim function attributes. By default it's false. + Returns ------- ret : tvm.transform.Pass @@ -1168,7 +1174,9 @@ def multiply( T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1] """ - return _ffi_api.LegalizeOps(customize_legalize_map, enable_warning) # type: ignore + return _ffi_api.LegalizeOps( + customize_legalize_map, enable_warning, add_attributes # type: ignore + ) def RealizeVDevice() -> tvm.ir.transform.Pass: diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index d732250cf9ec..345832da8b61 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1506,18 +1506,28 @@ TVM_REGISTER_OP("relax.hint_on_device") .set_attr("FInferStructInfo", InferHintOnDeviceStructInfo) .set_attr("FPurity", Bool(true)); -Expr MakeHintOnDevice(Expr data, VDevice vdevice) { +Expr MakeHintOnDevice(Expr data, Device device, String memory_scope = "global") { static const Op& op = Op::Get("relax.hint_on_device"); ObjectPtr attrs = ffi::make_object(); attrs->device_type = static_cast(device.device_type); attrs->index = device.device_id; attrs->memory_scope = vdevice->memory_scope; + attrs->memory_scope = memory_scope; return Call(op, {data}, Attrs(attrs), {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.op.hint_on_device", MakeHintOnDevice); + refl::GlobalDef().def_packed("relax.op.hint_on_device", [](ffi::PackedArgs args, ffi::Any* rv), { + Expr data = args[0].cast(); + Device device = args[1].cast(); + if (args.size() == 3) { + String scope = args[2].cast(); + *rv = MakeHintOnDevice(data, device, scope); + } else { + *rv = MakeHintOnDevice(data, device); + } + }); } } // namespace relax diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 3ed14fa85307..63671d077d60 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -967,12 +967,14 @@ class FusedTIRConstructor : public ExprVisitor { tir::PrimFunc ConstructFunc() { ffi::Map attr_map; attr_map.Set(tir::attr::kNoAlias, true); - attr_map.Set("op_attrs", func_info_.op_attrs); - int op_pattern = relay::kOpaque; + if (!func_info_.op_attrs.empty()) { + attr_map.Set("op_attrs", func_info_.op_attrs); + } if (!func_info_.op_pattern.empty()) { + int op_pattern = relay::kOpaque; op_pattern = *max_element(func_info_.op_pattern.begin(), func_info_.op_pattern.end()); + attr_map.Set("op_pattern", Integer(static_cast(op_pattern))); } - attr_map.Set("op_pattern", Integer(static_cast(op_pattern))); tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, func_info_.symbolic_var_remap); ICHECK(func_info_.global_name != "fused"); // Remove output buffers from func_info_.alloc_buffers diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index e8358798dddc..5986228ad06b 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -62,8 +62,11 @@ class LegalizeMutator : public ExprMutator { public: explicit LegalizeMutator(const IRModule& mod, const ffi::Optional>& cmap, - bool enable_warning) - : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) { + bool enable_warning, bool add_attributes) + : ExprMutator(mod), + mod_(std::move(mod)), + enable_warning_(enable_warning), + add_attributes_(add_attributes) { if (cmap) { cmap_ = cmap.value(); } @@ -370,7 +373,9 @@ class LegalizeMutator : public ExprMutator { } Expr legalized = legalization_func(builder_, visited_call); - legalized = AttributeOpAttrs(legalized, call->attrs); + if (call->attrs.as() && add_attributes_) { + legalized = AttributeOpAttrs(legalized, call->attrs); + } // Append the target attribute to any PrimFunc generated in // legalization. @@ -415,16 +420,20 @@ class LegalizeMutator : public ExprMutator { * legalization function is not registered. */ bool enable_warning_; + /*! + * \brief Boolean indicating this pass to add operator attributes to prim function attr + */ + bool add_attributes_; }; namespace transform { -Pass LegalizeOps(ffi::Optional> cmap, bool enable_warning) { +Pass LegalizeOps(ffi::Optional> cmap, bool enable_warning, bool add_attributes) { auto pass_func = [=](IRModule mod, PassContext pc) { bool apply_legalize_ops = pc->GetConfig("relax.transform.apply_legalize_ops").value_or(Bool(true))->value; if (apply_legalize_ops) { - mod = LegalizeMutator(mod, cmap, enable_warning).Transform(); + mod = LegalizeMutator(mod, cmap, enable_warning, add_attributes).Transform(); } return mod; }; diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 62a6db1f03ed..6d96327e2db4 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -194,7 +194,11 @@ ffi::Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& ICHECK(n->attrs.defined()); if (n->attrs.as()) { AttrPrinter(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values)(n->attrs); + ExprDoc scope_val = kwargs_values.back(); + kwargs_keys.pop_back(); + kwargs_values.pop_back(); args.push_back(Relax(d, "device")->Call({}, kwargs_keys, kwargs_values)); + args.push_back(scope_val); } return Relax(d, "hint_on_device")->Call(args); } diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index e3274aea886a..b0bec5e858af 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -17,6 +17,7 @@ import pytest import tvm +import tvm.testing from tvm import relax import tvm.script diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index 5041a89bfc3f..1f67b1eaab49 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -24,6 +24,7 @@ from tvm.ir.module import IRModule from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor + @visitor class ValidateScope(PyExprVisitor): # pylint: disable=abstract-method def __init__(self, scope_info: dict) -> None: @@ -38,25 +39,32 @@ def visit(self, mod: IRModule) -> None: self.visit_expr(func) return self.matched - def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed - if (call.op.name == "relax.call_tir"): - #if call.args[0].name_hint in self.scope_info: - for idx, arg in enumerate(call.args[1]): - arg_sinfo = arg.struct_info - assert isinstance(arg_sinfo, relax.TensorStructInfo - ), f"Expected TensorStructInfo but git {type(arg_sinfo)}" - assert (arg_sinfo.vdevice.memory_scope == self.scope_info[call.args[0].name_hint][0][idx] - ), f"Scope mispatched for argument {idx} in {call.args[0].name_hint}" - if isinstance(call.sinfo_args[0], relax.TensorStructInfo): - assert(call.sinfo_args[0].vdevice.memory_scope == self.scope_info[call.args[0].name_hint][1][0] - ), f"Scope mispatched for return scope: {call.args[0].name_hint}" - else: - assert isinstance(call.sinfo_args[0], relax.TupleStructInfo - ), f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" - for idx, sinfo in enumerate(call.sinfo_args[0].fields): - assert (sinfo.vdevice.memory_scope == self.scope_info[call.args[0].name_hint][1][idx] - ), f"Scope mispatched for return scope for {idx} in {call.args[0].name_hint}" + if call.op.name == "relax.call_tir": + # if call.args[0].name_hint in self.scope_info: + for idx, arg in enumerate(call.args[1]): + arg_sinfo = arg.struct_info + assert isinstance( + arg_sinfo, relax.TensorStructInfo + ), f"Expected TensorStructInfo but git {type(arg_sinfo)}" + assert ( + arg_sinfo.vdevice.memory_scope + == self.scope_info[call.args[0].name_hint][0][idx] + ), f"Scope mispatched for argument {idx} in {call.args[0].name_hint}" + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + assert ( + call.sinfo_args[0].vdevice.memory_scope + == self.scope_info[call.args[0].name_hint][1][0] + ), f"Scope mispatched for return scope: {call.args[0].name_hint}" + else: + assert isinstance( + call.sinfo_args[0], relax.TupleStructInfo + ), f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" + for idx, sinfo in enumerate(call.sinfo_args[0].fields): + assert ( + sinfo.vdevice.memory_scope + == self.scope_info[call.args[0].name_hint][1][idx] + ), f"Scope mispatched for return scope for {idx} in {call.args[0].name_hint}" def verify(mod, expected): @@ -71,8 +79,11 @@ def verify(mod, expected): mod = tvm.relax.transform.ConvertLayout(desired_layouts)(mod) mod = tvm.relax.transform.Normalize()(mod) mod = tvm.relax.transform.FoldConstant()(mod) - mod = tvm.relax.transform.LegalizeOps()(mod) - mod = tvm.relax.transform.LegalizeOps({"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo})(mod) + mod = tvm.relax.transform.LegalizeOps(add_attributes=True)(mod) + mod = tvm.relax.transform.LegalizeOps( + {"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo}, + add_attributes=True, + )(mod) mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) mod = tvm.relax.transform.FoldConstant()(mod) mod = tvm.relax.transform.FuseOps()(mod) @@ -83,6 +94,7 @@ def verify(mod, expected): ValidateScope(expected).visit(mod) + def test_conv2d(): @I.ir_module class Input: @@ -95,16 +107,16 @@ def main( R.output(gv) return gv - Expected = { - 'te_layout_transform': (["global"], ["global.texture-nhwc"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-nhwc", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-nhwc"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-nhwc", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) + def test_conv2d_NCHW_sub_indexed(): @I.ir_module class Input: @@ -124,10 +136,10 @@ def main( return gv Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -152,14 +164,15 @@ def main( return gv Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) + def _test_conv2d_symbolic_sub_indexed(): @I.ir_module class Input: @@ -179,10 +192,10 @@ def main( return gv Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -202,10 +215,13 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo_relu': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_relu": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -226,11 +242,14 @@ def main( return gv2 Expected = { - 'relu': (["global"], ["global"]), - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo_relu1': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "relu": (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_relu1": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -251,10 +270,13 @@ def main( return gv3 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo_relu_tir_tanh': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_relu_tir_tanh": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -275,11 +297,14 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'te_layout_transform2': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo_add': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform3': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global"]), } verify(Input, Expected) @@ -303,15 +328,15 @@ def main( return gv4 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'relu': (["global"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), - 'te_layout_transform3': (["global"], ["global.texture-weight"]), - 'te_layout_transform4': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo1': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform5': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "relu": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), } verify(Input, Expected) @@ -330,14 +355,15 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'sum': (["global"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) + def test_conv2d_sum_keepdims_sub_indexed(): @I.ir_module class Input: @@ -352,11 +378,11 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'sum': (["global"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -375,11 +401,11 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'sum': (["global"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -398,14 +424,15 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), - 'transpose': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "transpose": (["global"], ["global"]), } verify(Input, Expected) + def test_conv2d_expand_dims_sub_indexed(): @I.ir_module class Input: @@ -420,11 +447,11 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), - 'expand_dims': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "expand_dims": (["global"], ["global"]), } verify(Input, Expected) @@ -443,11 +470,11 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), - 'squeeze': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "squeeze": (["global"], ["global"]), } verify(Input, Expected) @@ -468,11 +495,11 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), - 'strided_slice': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "strided_slice": (["global"], ["global"]), } verify(Input, Expected) @@ -492,11 +519,14 @@ def main( return gv3 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'fused_relu_concatenate': (["global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -515,12 +545,15 @@ def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "fl return gv4 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'fused_relu_concatenate_split': (["global.texture-weight"], ["global", "global"]), - 'te_layout_transform2': (["global"], ["global"]), - 'te_layout_transform3': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate_split": (["global.texture-weight"], ["global", "global"]), + "te_layout_transform2": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global"]), } verify(Input, Expected) @@ -542,13 +575,16 @@ def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "fl return gv7 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'fused_relu_concatenate_split': (["global.texture-weight"], ["global", "global"]), - 'te_layout_transform2': (["global"], ["global"]), - 'te_layout_transform3': (["global"], ["global"]), - 'fused_transpose_transpose1_concatenate1': (["global", "global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate_split": (["global.texture-weight"], ["global", "global"]), + "te_layout_transform2": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global"]), + "fused_transpose_transpose1_concatenate1": (["global", "global"], ["global"]), } verify(Input, Expected) @@ -574,11 +610,14 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'max_pool2d': (["global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "max_pool2d": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -597,11 +636,14 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'adaptive_avg_pool2d': (["global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "adaptive_avg_pool2d": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -620,11 +662,11 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), - 'softmax': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "softmax": (["global"], ["global"]), } verify(Input, Expected) @@ -648,11 +690,14 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'layer_norm': (["global.texture-weight", "global", "global"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "layer_norm": (["global.texture-weight", "global", "global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -673,11 +718,11 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), - 'add': (["global", "global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "add": (["global", "global"], ["global"]), } verify(Input, Expected) @@ -696,13 +741,17 @@ def main( return gv2 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo_add': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform2': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_add": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) + def test_residual_block(): """ - some kind of residual block followed by convolution to have texture after residual block @@ -747,17 +796,26 @@ def main( return gv7 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'te_layout_transform2': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo_add_relu': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'te_layout_transform3': (["global"], ["global.texture-weight"]), - 'multiply': (["global"], ["global"]), - 'te_layout_transform4': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo1_add1_relu1': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'te_layout_transform5': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo2_relu2': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform6': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "multiply": (["global"], ["global"]), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo1_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform5": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo2_relu1": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform6": (["global"], ["global"]), } verify(Input, Expected) @@ -774,6 +832,7 @@ def test_conv2d_conv2d_fallback_to_buffer_conv2d(): | <- buffer layout_transform (NCHW4c->NCHW) """ + @I.ir_module class Input: @R.function @@ -798,16 +857,19 @@ def main( return gv7 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'te_layout_transform2': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo_add_relu': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform3': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo1': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform4': (["global"], ["global"]), - 'conv2d': (["global", "global"], ["global"]), - 'te_layout_transform5': (["global"], ["global"]), - 'concatenate': (["global", "global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform4": (["global"], ["global"]), + "conv2d": (["global", "global"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), + "concatenate": (["global", "global"], ["global"]), } verify(Input, Expected) @@ -824,6 +886,7 @@ def test_conv2d_conv2d_conv2d_concat(): | <- buffer layout_transform (NCHW4c->NCHW) """ + @I.ir_module class Input: @R.function @@ -848,16 +911,25 @@ def main( return gv7 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'te_layout_transform2': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo_add_relu': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'te_layout_transform3': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo1': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'te_layout_transform4': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo2': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'concatenate': (["global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform5': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo2": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "concatenate": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), } verify(Input, Expected) @@ -879,6 +951,7 @@ def test_pooling_branching_texture_params(): | <- buffer layout_transform (NCHW4c->NCHW) """ + @I.ir_module class Input: @R.function @@ -893,11 +966,17 @@ def main( with R.dataflow(): gv = R.nn.conv2d(x, w1, strides=[1, 1], out_dtype="float32") gv1 = R.nn.max_pool2d(gv, pool_size=[2, 2], strides=[2, 2]) - gv2 = R.nn.conv2d(gv1, w2, padding=[0, 0, 1, 1], strides=[1, 1], out_dtype="float32") + gv2 = R.nn.conv2d( + gv1, w2, padding=[0, 0, 1, 1], strides=[1, 1], out_dtype="float32" + ) gv3 = R.add(gv2, bias1) gv4 = R.nn.relu(gv3) - gv5 = R.nn.conv2d(gv1, w3, padding=[0, 0, 0, 0], strides=[1, 1], out_dtype="float32") - gv6 = R.nn.conv2d(gv1, w4, padding=[0, 1, 1, 0], strides=[1, 1], out_dtype="float32") + gv5 = R.nn.conv2d( + gv1, w3, padding=[0, 0, 0, 0], strides=[1, 1], out_dtype="float32" + ) + gv6 = R.nn.conv2d( + gv1, w4, padding=[0, 1, 1, 0], strides=[1, 1], out_dtype="float32" + ) gv7 = R.nn.relu(gv6) gv8 = R.add(gv2, gv5) gv9 = R.add(gv8, gv6) @@ -905,17 +984,29 @@ def main( return gv9 Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'max_pool2d': (["global.texture-weight"], ["global.texture-weight"]), - 'te_layout_transform2': (["global"], ["global.texture-weight"]), - 'te_layout_transform3': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo2': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo1_add': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'te_layout_transform4': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo3_add1': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform5': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "max_pool2d": (["global.texture-weight"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo2": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_conv2d_NCHWc_OIHWo1_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo3_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform5": (["global"], ["global"]), } verify(Input, Expected) @@ -939,6 +1030,7 @@ def test_injective_inputs1(): add """ + @I.ir_module class Input: @R.function @@ -950,8 +1042,12 @@ def main( ) -> R.Tensor(None, "float32", ndim=4): with R.dataflow(): mean = R.mean(x, axis=1, keepdims=True) - conv1 = R.nn.conv2d(x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32") - conv2 = R.nn.conv2d(conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32") + conv1 = R.nn.conv2d( + x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + conv2 = R.nn.conv2d( + conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) ad3 = R.add(conv1, conv2) ad1 = R.add(mean, conv1) ad2 = R.multiply(ad1, conv1) @@ -960,15 +1056,21 @@ def main( return gv Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'te_layout_transform3': (["global"], ["global"]), - 'fused_mean_add1': (["global", "global"], ["global"]), - 'te_layout_transform2': (["global"], ["global.texture-weight"]), - 'te_layout_transform4': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo1_add_multiply_add2': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform5': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global"]), + "fused_mean_add1": (["global", "global"], ["global"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo1_add_multiply_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform5": (["global"], ["global"]), } verify(Input, Expected) @@ -994,6 +1096,7 @@ def test_injective_nwo_inputs2(): add """ + @I.ir_module class Input: @R.function @@ -1005,8 +1108,12 @@ def main( ) -> R.Tensor(None, "float32", ndim=4): with R.dataflow(): mean = R.mean(x, axis=1, keepdims=True) - conv1 = R.nn.conv2d(x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32") - conv2 = R.nn.conv2d(conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32") + conv1 = R.nn.conv2d( + x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + conv2 = R.nn.conv2d( + conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) ad3 = R.add(conv1, conv2) ad1 = R.add(mean, conv1) ad2 = R.multiply(ad1, conv2) @@ -1015,15 +1122,21 @@ def main( return gv Expected = { - 'te_layout_transform': (["global"], ["global.texture-weight"]), - 'te_layout_transform1': (["global"], ["global.texture-weight"]), - 'conv2d_NCHWc_OIHWo': (["global.texture-weight", "global.texture-weight"], ["global.texture-weight"]), - 'te_layout_transform3': (["global"], ["global"]), - 'fused_mean_add1': (["global", "global"], ["global"]), - 'te_layout_transform2': (["global"], ["global.texture-weight"]), - 'te_layout_transform4': (["global"], ["global.texture-weight"]), - 'fused_conv2d_NCHWc_OIHWo1_add_multiply_add2': (["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"]), - 'te_layout_transform5': (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global"]), + "fused_mean_add1": (["global", "global"], ["global"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo1_add_multiply_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform5": (["global"], ["global"]), } verify(Input, Expected) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 8e583b3dd4cc..8b93e22e0752 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -1119,7 +1119,7 @@ def fused_concatenate_transpose2( (T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"op_pattern": 2, "tir.noalias": True}) T_concat_handle_intermediate = T.alloc_buffer( (T.int64(2), T.int64(4), T.int64(64), T.int64(64)) ) @@ -1309,7 +1309,7 @@ def fused_reshape( (T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"op_pattern": 2, "tir.noalias": True}) # with T.block("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)): with T.block("T_reshape"): diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py index 694e7a688cf7..c0ff78ca4c6b 100644 --- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py +++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py @@ -439,5 +439,20 @@ def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): _check(foo, bb.get()["foo"]) +def test_hint_on_device_scoped(): + @R.function + def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + r = R.hint_on_device(x, R.device(4, 2), "global.texture") + return r + + x = relax.Var("x", R.Tensor((), "int32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + tensor = bb.emit(relax.op.hint_on_device(x, R.opencl(2), "global.texture")) + bb.emit_func_output(tensor) + + _check(foo, bb.get()["foo"]) + + if __name__ == "__main__": tvm.testing.main() From 1d2c9af119646d18eec740ff43316dd64f93f5f5 Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 27 Jan 2025 12:27:31 +0530 Subject: [PATCH 04/31] VDevice ptr equality. --- .../transform/annotate_custom_storage.cc | 75 +++++++++++++------ src/script/printer/relax/utils.h | 2 +- .../test_transform_annotate_custom_scope.py | 4 +- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc index 0abea4636dcc..70e7db5d91df 100644 --- a/src/relax/transform/annotate_custom_storage.cc +++ b/src/relax/transform/annotate_custom_storage.cc @@ -353,6 +353,45 @@ class DefineVDevice : ExprMutator { Array new_args; StructInfo updated_ret_sinfo = producer_sinfo_[GetRef(call_node)]; + if (updated_ret_sinfo->IsInstance()) { + auto tensor_sinfo = Downcast(updated_ret_sinfo); + auto shape = tensor_sinfo->shape.value(); + auto dtype = tensor_sinfo->dtype; + if (tensor_sinfo->vdevice.defined()) { + auto vdev = tensor_sinfo->vdevice.value(); + const VDevice& vdev_global = MakeGlobalVDevice(vdev); + updated_ret_sinfo = TensorStructInfo(shape, dtype, vdev_global); + } + } else { + ICHECK(updated_ret_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << updated_ret_sinfo; + + const auto& tuple_sinfo = Downcast(updated_ret_sinfo); + Array sinfo_fields; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + + auto shape_arr = GetShapeFromTensorStructInfo(sinfo); + + auto shape = sinfo->shape.value(); + auto dtype = sinfo->dtype; + if (sinfo->vdevice.defined()) { + auto vdev = sinfo->vdevice.value(); + const VDevice& vdev_global = MakeGlobalVDevice(vdev); + sinfo_fields.push_back(TensorStructInfo(shape, dtype, vdev_global)); + } else { + sinfo_fields.push_back(sinfo); + } + } + updated_ret_sinfo = TupleStructInfo(sinfo_fields); + } + int arg_idx = 0; for (auto arg : func_args->fields) { auto sinfo = GetStructInfo(arg); @@ -368,35 +407,29 @@ class DefineVDevice : ExprMutator { } } - if (call->op == call_tir_op) { - auto updated_call = - Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo}); - return builder_->Normalize(updated_call); - } else { - auto updated_call = Call(call->op, new_args, call->attrs, {updated_ret_sinfo}); - return builder_->Normalize(updated_call); - } + auto updated_call = Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo}); + return builder_->Normalize(updated_call); } private: - void AppendToVDevices(VDevice vdev) { + VDevice MakeGlobalVDevice(VDevice vdev) { int device_type = vdev->target->GetTargetDeviceType(); - for (auto vdevice : vdevices_) { - int dev_type = vdevice->target->GetTargetDeviceType(); - if (dev_type == device_type && vdevice->vdevice_id == vdev->vdevice_id && - vdevice->memory_scope == vdev->memory_scope) { - return; + for (size_t i = 0; i < vdevices_.size(); ++i) { + int dev_type = vdevices_[i]->target->GetTargetDeviceType(); + if (dev_type == device_type && vdevices_[i]->vdevice_id == vdev->vdevice_id && + vdevices_[i]->memory_scope == vdev->memory_scope) { + return vdevices_[i]; } } vdevices_.push_back(vdev); - return; + return (vdevices_.back()); } - Expr HintArg(const Expr& arg, const String& scope) { + Expr HintArg(const Expr& arg, String scope) { if (arg->IsInstance()) { if (auto tsinfo = arg->struct_info_.as()) { if (!tsinfo->vdevice.defined()) { - VDevice vdev = VDevice(target_, 0, scope); + const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); CHECK(tsinfo->shape.defined()) << "Shape not defined for a constant tensor ..!"; arg->struct_info_ = TensorStructInfo(tsinfo->shape.value(), tsinfo->dtype, vdev, tsinfo->span); @@ -405,13 +438,13 @@ class DefineVDevice : ExprMutator { } } ObjectPtr attrs = make_object(); - attrs->dev_type = target_->GetTargetDeviceType(); - attrs->dev_id = 0; - attrs->memory_scope = scope; + const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); + attrs->dev_type = vdev->target->GetTargetDeviceType(); + attrs->dev_id = vdev->vdevice_id; + attrs->memory_scope = vdev->memory_scope; Expr new_arg = Call(hint_on_device_op_, {arg}, Attrs{std::move(attrs)}, {}); - AppendToVDevices(VDevice(target_, 0, scope)); return std::move(new_arg); } diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index df8f6b2d470d..7dddfaecbbe7 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -142,7 +142,7 @@ inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifie int kind_index = 0; for (size_t i = 0; i < vdevices.size(); ++i) { auto vdev = Downcast(vdevices[i]); - if (vdev == vdevice) { + if (vdev.same_as(vdevice)) { return kind_index; } if (vdev->target->kind->name == vdevice->target->kind->name) { diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index 1f67b1eaab49..9f9c5c0aa78f 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -18,7 +18,6 @@ import tvm from tvm import relax import tvm.testing -from tvm.relax.transform import ConvertLayout, Normalize from tvm.script.parser import ir as I, relax as R, tir as T from tvm.relax.transform.legalize_ops import adreno as legalize_adreno from tvm.ir.module import IRModule @@ -33,7 +32,6 @@ def __init__(self, scope_info: dict) -> None: def visit(self, mod: IRModule) -> None: """Entry point""" - print("Mod:", mod["main"]) for _, func in mod.functions_items(): if isinstance(func, relax.Function): self.visit_expr(func) @@ -90,7 +88,7 @@ def verify(mod, expected): mod = tvm.relax.transform.FuseTIR()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) mod = tvm.relax.transform.AnnotateCustomMemoryScope(tgt)(mod) - mod = Normalize()(mod) + mod = tvm.relax.transform.Normalize()(mod) ValidateScope(expected).visit(mod) From 58285c9c1b8c6dc7b9ceafe6afc0fd39778732cf Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 27 Jan 2025 23:16:17 +0530 Subject: [PATCH 05/31] Rename pass - SpecializeTIRParams --- include/tvm/relax/transform.h | 10 +++++----- python/tvm/relax/transform/__init__.py | 2 +- python/tvm/relax/transform/transform.py | 10 ++++++---- src/relax/transform/annotate_custom_storage.cc | 4 ++-- ...ams.cc => specialize_primfunc_based_on_callsite.cc} | 7 ++++--- 5 files changed, 18 insertions(+), 15 deletions(-) rename src/relax/transform/{specialize_tir_params.cc => specialize_primfunc_based_on_callsite.cc} (96%) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 0f6a4f25caa5..ad4fb38b3a79 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -686,17 +686,17 @@ TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark); * FuseOps and FuseTIR are modified to pass on the operator specific attributes and also * op_pattern details as part of the PrimFunc. This pass is Adreno specific and annotates each * BindingVar with appropriate HintInDevice. RealizeVDevice pass followed by handles these hints. - * Followed by this pass we also invoke SpecializeTIRParams which updates the var_buffer_map - * based on this new VDevice information. + * Followed by this pass we also invoke SpecializePrimFuncBasedOnCallSite which updates the + * var_buffer_map based on this new VDevice information. */ TVM_DLL Pass AnnotateCustomMemoryScope(Target target); /*! * \brief This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. - * Primarily used to update the VDevice information is any changes occured from the caller. - * This pass recreated the buffers and updates the map. + * Primarily used to update the VDevice information if any changes occured from the caller. + * This pass recreates the buffers and updates the map. */ -TVM_DLL Pass SpecializeTIRParams(); +TVM_DLL Pass SpecializePrimFuncBasedOnCallSite(); } // namespace transform } // namespace relax diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 5767fb235099..e3c652f553e6 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -84,7 +84,7 @@ VMBuiltinLower, VMShapeLower, AnnotateCustomMemoryScope, - SpecializeTIRParams, + SpecializePrimFuncBasedOnCallSite, dataflowblock_pass, function_pass, ) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index fcb010a7db07..fa20cc196837 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1617,7 +1617,7 @@ def AllocateWorkspace() -> tvm.ir.transform.Pass: def AnnotateCustomMemoryScope(target: Optional[Target] = None) -> tvm.ir.transform.Pass: """Allocate the memory scope information. This is Adreno specific pass to annotate The memory scope information and realize the same with RealizeVDevice pass followed by - updating the Prim Function var_buffer mapping using SpecializeTIRParams. + updating the Prim Function var_buffer mapping using SpecializePrimFuncBasedOnCallSite. Returns ------- @@ -1627,15 +1627,17 @@ def AnnotateCustomMemoryScope(target: Optional[Target] = None) -> tvm.ir.transfo return _ffi_api.AnnotateCustomMemoryScope(target) # type: ignore -def SpecializeTIRParams() -> tvm.ir.transform.Pass: - """Map modified tir_call params to prim_func buffers. +def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass: + """This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. + Primarily used to update the VDevice information if any changes occured from the caller. + This pass recreates the buffers and updates the map. Returns ------- ret: tvm.ir.transform.Pass The registered pass for allocating workspace. """ - return _ffi_api.SpecializeTIRParams() # type: ignore + return _ffi_api.SpecializePrimFuncBasedOnCallSite() # type: ignore def _wrap_class_function_pass(pass_cls, pass_info): diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc index 70e7db5d91df..9c00edb0a9ab 100644 --- a/src/relax/transform/annotate_custom_storage.cc +++ b/src/relax/transform/annotate_custom_storage.cc @@ -322,10 +322,10 @@ class DefineVDevice : ExprMutator { } mod_.CopyOnWrite()->global_infos.Set("vdevice", global_vdevices_); - mod_ = relax::transform::SpecializeTIRParams()(mod_); + mod_ = relax::transform::SpecializePrimFuncBasedOnCallSite()(mod_); mod_ = relax::transform::DeadCodeElimination()(mod_); mod_ = relax::transform::RealizeVDevice()(mod_); - mod_ = relax::transform::SpecializeTIRParams()(mod_); + mod_ = relax::transform::SpecializePrimFuncBasedOnCallSite()(mod_); return mod_; } diff --git a/src/relax/transform/specialize_tir_params.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc similarity index 96% rename from src/relax/transform/specialize_tir_params.cc rename to src/relax/transform/specialize_primfunc_based_on_callsite.cc index 7df1c1579576..fe2cc9329860 100644 --- a/src/relax/transform/specialize_tir_params.cc +++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc @@ -154,16 +154,17 @@ class SpecializeTIRCallArgs : ExprMutator { namespace transform { -Pass SpecializeTIRParams() { +Pass SpecializePrimFuncBasedOnCallSite() { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { return relax::SpecializeTIRCallArgs().Run(mod); }; return CreateModulePass(/*pass_function=*/pass_func, /*opt_level=*/0, - /*pass_name=*/"SpecializeTIRParams", + /*pass_name=*/"SpecializePrimFuncBasedOnCallSite", /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.SpecializeTIRParams").set_body_typed(SpecializeTIRParams); +TVM_REGISTER_GLOBAL("relax.transform.SpecializePrimFuncBasedOnCallSite") + .set_body_typed(SpecializePrimFuncBasedOnCallSite); } // namespace transform } // namespace relax From 35adccab8dd83b280b90089b178676cf9b08cafe Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 28 Jan 2025 10:08:41 +0530 Subject: [PATCH 06/31] Remote OptimizeBatchnorm pass. Redundant with DecomposeOpsForInference. --- .../tvm/relax/transform/optimize_batchnorm.py | 103 ------------------ .../test_transform_annotate_custom_scope.py | 2 - 2 files changed, 105 deletions(-) delete mode 100644 python/tvm/relax/transform/optimize_batchnorm.py diff --git a/python/tvm/relax/transform/optimize_batchnorm.py b/python/tvm/relax/transform/optimize_batchnorm.py deleted file mode 100644 index 9eee18c0d032..000000000000 --- a/python/tvm/relax/transform/optimize_batchnorm.py +++ /dev/null @@ -1,103 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local -"""Relax Optimize Batchnorm to fold it into previous Conv pass.""" -from tvm.ir.module import IRModule -from tvm.ir.transform import PassContext -from tvm.relax import Expr -from tvm.relax.dpl import is_op, rewrite_call, wildcard, is_const, TupleGetItemPattern -from tvm import relax, tir - -from . import function_pass - - -@function_pass(opt_level=0) -class OptimizeBatchnorm: - """ - Fuse Batchnorm to its previous Conv2D - """ - - def __init__(self): - self.input = wildcard() - self.weight = is_const() - self.pattern_conv2d = is_op("relax.nn.conv2d")(self.input, self.weight) - self.bn_weight = is_const() - self.bias = is_const() - self.mean = is_const() - self.variance = is_const() - self.pattern_bn = is_op("relax.nn.batch_norm")( - self.pattern_conv2d, self.bn_weight, self.bias, self.mean, self.variance - ) - - self.pattern = TupleGetItemPattern(self.pattern_bn, 0) - - def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule: - """ - Tranformation function to pattern Conv2D+BatchNorm+TupleGetItem pattern - - Parameters - ---------- - func: Expr - The relax function to be optimized - - mod: IRModule - The ir module - - ctx: PassContext - Relax pass context - """ - - self.mod = mod - updated_call = func - - # Skip primitive functions - if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0: - return updated_call - - def rewriter(expr, matches): - conv_input = matches[self.input] - conv_weight = matches[self.weight] - bn_weight = matches[self.bn_weight] - bn_bias = matches[self.bias] - bn_mean = matches[self.mean] - bn_variance = matches[self.variance] - conv_op = matches[self.pattern_conv2d] - bn_op = matches[self.pattern_bn] - conv_attrs = conv_op.attrs - bn_attrs = bn_op.attrs - - bn_variance = relax.op.add( - bn_variance, relax.PrimValue(tir.FloatImm("float32", bn_attrs["epsilon"])) - ) - dino = relax.op.sqrt(bn_variance) - wt = relax.op.divide(bn_weight, dino) - bs = relax.op.subtract(bn_bias, relax.op.multiply(bn_mean, wt)) - if conv_attrs["kernel_layout"] == "OIHW": - wt = relax.op.reshape(wt, shape=(bn_weight.struct_info.shape[0], 1, 1, 1)) - elif conv_attrs["kernel_layout"] == "IOHW": - wt = wt.reshape(1, bn_weight.struct_info.shape[0], 1, 1) - else: - return expr - wt_conv = relax.op.multiply(conv_weight, wt) - bs_args = relax.op.reshape(bs, shape=(1, bn_bias.struct_info.shape[0], 1, 1)) - - conv_out = relax.Call(conv_op.op, (conv_input, wt_conv), conv_attrs) - return relax.op.add(conv_out, bs_args) - - updated_call = rewrite_call(self.pattern, rewriter, func) - - return updated_call diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index 9f9c5c0aa78f..303966c88f7b 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -69,8 +69,6 @@ def verify(mod, expected): tgt = tvm.target.Target("opencl --device=adreno", host="llvm") with tgt: mod = tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) - mod = tvm.relax.transform.OptimizeBatchnorm()(mod) - mod = tvm.relax.transform.FoldConstant()(mod) mod = tvm.relax.transform.DecomposeOpsForInference()(mod) mod = tvm.relax.transform.FoldConstant()(mod) desired_layouts = {"relax.nn.conv2d": ["NCHW4c", "OIHW4o", "NCHW4c"]} From e1bd1638e8b5372d65c6c178dae4fd33903c9eab Mon Sep 17 00:00:00 2001 From: Siva Date: Fri, 7 Feb 2025 11:21:49 +0530 Subject: [PATCH 07/31] review --- include/tvm/ir/global_info.h | 9 --------- src/relax/transform/utils.h | 2 +- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index b646057a009f..892bba4da694 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -87,15 +87,6 @@ class VDeviceNode : public GlobalInfoNode { class VDevice : public GlobalInfo { public: TVM_DLL explicit VDevice(Target tgt, int dev_id, MemoryScope mem_scope); - /*! - * \brief Equal comparator. - * \param other The data type to compare against. - * \return The comparison result. - */ - bool operator==(const VDevice& other) const { - return this->get()->target == other->target && this->get()->vdevice_id == other->vdevice_id && - this->get()->memory_scope == other->memory_scope; - } TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(VDevice, GlobalInfo, VDeviceNode); }; diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 7ecbdcc5af5b..07f485ee4813 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -386,7 +386,7 @@ inline ffi::String GetCodegenName(const std::string& composite_name) { inline int GetDeviceIndex(const IRModule& mod, const VDevice& vdevice) { ffi::Array vdevices = mod->global_infos["vdevice"]; for (int i = 0; i < static_cast(vdevices.size()); ++i) { - if (vdevices[i].as() == vdevice) { + if (vdevices[i].same_as(vdevice)) { return i; } } From 1d0aec48297e977323f765336560eeb81aeb7ff5 Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 25 Feb 2025 19:46:41 +0530 Subject: [PATCH 08/31] About to go for InferStructInfo query to set the scope --- include/tvm/relax/transform.h | 9 +- python/tvm/relax/transform/__init__.py | 2 + python/tvm/relax/transform/transform.py | 22 +- src/relax/op/op_common.h | 5 + .../transform/annotate_custom_storage.cc | 219 ++++++++++++++++-- src/relax/transform/fuse_tir.cc | 6 +- src/relax/transform/legalize_ops.cc | 56 +++-- .../test_transform_annotate_custom_scope.py | 151 ++++++------ 8 files changed, 346 insertions(+), 124 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index ad4fb38b3a79..9c412245e223 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -249,8 +249,8 @@ TVM_DLL Pass FoldConstant(); * \param add_attributes A boolean value indicating adding of call attributes to TIR functions * \return The Pass. */ -TVM_DLL Pass LegalizeOps(ffi::Optional> cmap, bool enable_warning = false - bool add_attributes = false); +TVM_DLL Pass LegalizeOps(ffi::Optional> cmap, ffi::Optional> skip_ops, + bool enable_warning = false, bool add_attributes = false); /*! * \brief Propagate virtual device information. @@ -698,6 +698,11 @@ TVM_DLL Pass AnnotateCustomMemoryScope(Target target); */ TVM_DLL Pass SpecializePrimFuncBasedOnCallSite(); + +TVM_DLL Pass RemoveRedundantAssignments(); + +TVM_DLL Pass RemoveToDeviceForScopeChange(); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index e3c652f553e6..559d4a9380e1 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -85,6 +85,7 @@ VMShapeLower, AnnotateCustomMemoryScope, SpecializePrimFuncBasedOnCallSite, + RemoveToDeviceForScopeChange, dataflowblock_pass, function_pass, ) @@ -98,6 +99,7 @@ from .optimize_layout_transform import OptimizeLayoutTransform from .fold_batch_norm_to_conv2d_for_inference import FoldBatchnormToConv2D from .remove_redundant_reshape import RemoveRedundantReshape +#from .remove_to_device_for_scope_change import RemoveToDeviceForScopeChange # Import to register the legalization functions. from . import legalize_ops diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index fa20cc196837..aad6372b1658 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1064,8 +1064,8 @@ def BundleModelParams(param_tuple_name: Optional[str] = None) -> tvm.ir.transfor def LegalizeOps( customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None, + skip_ops: Optional[List[str]] = None, enable_warning: bool = False, - add_attributes: bool = False, ): """Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs. @@ -1091,15 +1091,14 @@ def LegalizeOps( The customized operator legalization function map. The customized function will override the default one. + skip_ops : Optional,List[str]] + List of ops that need to be skipped from legalization + enable_warning : bool A boolean value indicating if to print warnings for CallNode whose op's legalization function is not registered. By default we don't print warnings. - add_attributes : bool - A boolean value indicating if we want legalize ops to add operator attributes to legalized - prim function attributes. By default it's false. - Returns ------- ret : tvm.transform.Pass @@ -1175,7 +1174,7 @@ def multiply( """ return _ffi_api.LegalizeOps( - customize_legalize_map, enable_warning, add_attributes # type: ignore + customize_legalize_map, skip_ops, enable_warning # type: ignore ) @@ -1640,6 +1639,17 @@ def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass: return _ffi_api.SpecializePrimFuncBasedOnCallSite() # type: ignore +def RemoveToDeviceForScopeChange() -> tvm.ir.transform.Pass: + """This pass + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.RemoveToDeviceForScopeChange() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 0d4d594222e2..0e0c4254e34b 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -351,6 +351,10 @@ inline ffi::Optional InferBinaryArithOpOutVDevice(const Call& call, } }; + if (call->sinfo_args.size() > 0) { + return get_vdevice(call->sinfo_args[0]); + } + auto lhs_vdevice = get_vdevice(lhs_sinfo); auto rhs_vdevice = get_vdevice(rhs_sinfo); @@ -360,6 +364,7 @@ inline ffi::Optional InferBinaryArithOpOutVDevice(const Call& call, if (!rhs_vdevice.defined() || !rhs_vdevice.value()->target.defined()) { return lhs_vdevice; } + if (lhs_vdevice.value() != rhs_vdevice.value()) { ctx->ReportFatal(Diagnostic::Error(call) << "TypeErorr: " diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc index 9c00edb0a9ab..80a8d36ae6c3 100644 --- a/src/relax/transform/annotate_custom_storage.cc +++ b/src/relax/transform/annotate_custom_storage.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -46,16 +47,143 @@ static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tens return shape.value(); } +namespace { +std::tuple)>> CreatePatterns(Map>> scope_info) { + auto pat_gv = WildcardPattern(); + + auto pat_inp = WildcardPattern(); + auto pat_call_tir = IsOp("relax.call_tir")(pat_gv, pat_inp); + auto pattern_out = IsOp("relax.to_vdevice")(pat_call_tir); + + auto rewriter = [=](Expr expr, Map matches) -> Expr { + const auto* call_tir = matches[pat_call_tir].as(); + ICHECK(call_tir) << "InternalError: " + << "Match of relax.call_tir operator should produce Call, " + << "but instead produces " << matches[pat_call_tir] << " with type " + << matches[pat_call_tir]->GetTypeKey(); + + const auto* out = matches[pattern_out].as(); + ICHECK(out) << "InternalError: " + << "Match of relax.to_vdevice operator should produce Call, " + << "but instead produces " << matches[pattern_out] << " with type " + << matches[pattern_out]->GetTypeKey(); + + const auto* vdev_attrs = out->attrs.as(); + ICHECK(vdev_attrs) << "InternalError: " + << "Attributes for relax.to_vdevice operator should be ToVDeviceAttrs, " + << "but were instead " << out->attrs << " with type " + << out->GetTypeKey(); + + const auto* tir_out_sinfo = call_tir->sinfo_args[0].as(); + if (!tir_out_sinfo) return expr; + + if (!tir_out_sinfo->vdevice.defined()) return expr; + + const VarNode* arg_var = out->args[0].as(); + if (scope_info.find(GetRef(arg_var)) != scope_info.end()) { + if (scope_info[GetRef(arg_var)].size() > 1) { + /* Don't do to_device optimization as we are not the only consumer */ + return expr; + } + } + + if ((std::string(tir_out_sinfo->vdevice.value()->memory_scope).find("texture") != std::string::npos) && + (vdev_attrs->dst_vdevice->memory_scope == "global")) { + LOG(WARNING) << "Can be optimized"; + auto shape_arr = tir_out_sinfo->GetShape().value(); + auto new_sinfo = TensorStructInfo(ShapeExpr(shape_arr), tir_out_sinfo->dtype, + vdev_attrs->dst_vdevice); + + return Call(call_tir->op, call_tir->args, call_tir->attrs, {new_sinfo}); + } + return expr; + }; + + return {pattern_out, rewriter}; +} + +} // namespace + +class RemoveRedundantAssignments : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + IRModule Run(IRModule& mod) { + mod_ = mod; + for (const auto& [gv, func] : mod_->functions) { + if (func->IsInstance()) { + const auto& base_func = mod_->Lookup(gv); + // Only non primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + relax::Function update_func = Downcast(VisitExpr(func)); + updates_->Add(gv, update_func); + } + } + mod_.CopyOnWrite()->Update(updates_); + return mod_; + } + + void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final { + LOG(WARNING) << "VisitBinding_:" << binding->var.get()->name_hint() << " : " << var->name_hint(); + redundant_map.Set(GetRef(binding->var.get()), GetRef(var)); + } + + void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) override { + LOG(WARNING) << "VisitBinding_ - DataFlow:" << binding->var.get()->name_hint() << " : " << static_cast(val)->name_hint(); + redundant_map.Set(GetRef(binding->var.get()), GetRef(val)); + } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + LOG(WARNING) << "VisitExpr_:" << call->op << " Args:" << call->args; + + Tuple args; + + if (call->op == call_tir_op) { + args = Downcast(call->args[1]); + } else { + args = Tuple(call->args); + } + Array new_args; + + for (auto& arg : args->fields) { + if (redundant_map.find(arg) != redundant_map.end()) { + new_args.push_back(redundant_map[arg]); + } else { + new_args.push_back(arg); + } + } + if (call->op == call_tir_op) { + return Call(call_tir_op, {call->args[0], Tuple(new_args)}, call->attrs, {call->sinfo_args[0]}); + } else { + if (call->sinfo_args.size() > 0) { + return Call(call->op, new_args, call->attrs, {call->sinfo_args[0]}); + } else { + return Call(call->op, new_args, call->attrs); + } + } + } + +private: + Map redundant_map; + IRModule updates_; + IRModule mod_; +}; + class CollectProduserScopeInfo : public ExprVisitor { public: using ExprVisitor::VisitExpr_; Map Collect(const IRModule& mod, Function func, const Map>>& scope_info, - const Target& target) { + const Target& target, const BlockBuilder& builder) { mod_ = mod; scope_info_ = scope_info; target_ = target; + builder_ = builder; VisitExpr(func->body); return producer_sinfo; @@ -70,11 +198,20 @@ class CollectProduserScopeInfo : public ExprVisitor { if (call->op == call_tir_op) { out_sinfo = call->sinfo_args[0]; } else { - return; + tvm::OpAttrMap op_map_infer_struct_info_ = + Op::GetAttrMap("FInferStructInfo"); + + auto* op_ptr = call->op.as(); + Op op = GetRef(op_ptr); + ICHECK(op_map_infer_struct_info_.count(op)) + << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; + out_sinfo = op_map_infer_struct_info_[op](GetRef(call), builder_); + LOG(WARNING) << "Got Struct Info for normal op:" << out_sinfo.as(); } std::unordered_map scope_count; + // Decide the final scope based on the max consumer demand. Rest will use to_device. auto arg_var = binding->var.as(); if (scope_info_.find(GetRef(arg_var)) != scope_info_.end()) { for (const auto& val : scope_info_[GetRef(arg_var)]) { @@ -95,6 +232,7 @@ class CollectProduserScopeInfo : public ExprVisitor { count = sval.second; } } + // Applying same scope for outputs StructInfo updated_ret_sinfo = UpdateStructInfo(out_sinfo, {final_scope}); producer_sinfo.Set(GetRef(call), updated_ret_sinfo); } @@ -132,6 +270,7 @@ class CollectProduserScopeInfo : public ExprVisitor { Map producer_sinfo; IRModule mod_; Target target_; + BlockBuilder builder_; }; class CollectConsumerScopeInfo : public ExprVisitor { @@ -143,6 +282,8 @@ class CollectConsumerScopeInfo : public ExprVisitor { mod_ = mod; target_ = target; VisitExpr(func->body); + LOG(WARNING) << "Visit completed"; + // Extend the scope for tuple items for (const auto& val : arg_to_binding) { if (scope_info.find(val.first) != scope_info.end()) { if (scope_info.find(val.second) == scope_info.end()) { @@ -175,28 +316,34 @@ class CollectConsumerScopeInfo : public ExprVisitor { Optional op_pattern = Integer(static_cast(relay::kOpaque)); Tuple func_args; - StructInfo out_sinfo; - if (call->op == call_tir_op) { gv = Downcast(call->args[0]); tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); op_attrs = ExtractAttrs(pfunc); op_pattern = ExtractPattern(pfunc); - out_sinfo = call->sinfo_args[0]; func_args = Downcast(call->args[1]); } else { - return; + LOG(WARNING) << "About to access attrs"; + op_attrs = {call->attrs}; + op_pattern = Integer(static_cast(relay::kOpaque)); + func_args = Tuple(call->args); } + LOG(WARNING) << "About to call texture scope"; bool is_texture_supported = SupportsTexture(op_attrs, op_pattern.value()); + LOG(WARNING) << "returned texture scope"; Array arg_scope; for (auto arg : func_args->fields) { + LOG(WARNING) << "B1"; auto sinfo = GetStructInfo(arg); + LOG(WARNING) << "B2"; if (auto tensor_sinfo = sinfo.as()) { + LOG(WARNING) << "B3:" << tensor_sinfo; auto scope = is_texture_supported ? Scope(GetShapeFromTensorStructInfo(tensor_sinfo.value())) : "global"; + LOG(WARNING) << "B4"; Map> ent_call; const VarNode* arg_var = arg.as(); if (scope_info.find(GetRef(arg_var)) != scope_info.end()) { @@ -205,9 +352,11 @@ class CollectConsumerScopeInfo : public ExprVisitor { ent_call.Set(GetRef(call), {scope}); scope_info.Set(GetRef(arg_var), ent_call); arg_scope.push_back(scope); + LOG(WARNING) << "B5"; } } call_scope_info.Set(GetRef(call), arg_scope); + LOG(WARNING) << "B6"; } private: @@ -260,6 +409,12 @@ class CollectConsumerScopeInfo : public ExprVisitor { // 5d requirement is not limitation of textures in general, it is limitation how // we are representing memory scopes/layout and flattening of textures in tir if (shape.size() == 5 && shape[4].as()->value == 4) { + for (auto ind: shape) { + if (! ind.as()) { + // Dynamic tensors + return "global.texture-nchw"; + } + } std::map diffs; int spatial_limit = target_->GetAttr("texture_spatial_limit").value_or(Integer(16384))->value; @@ -285,7 +440,9 @@ class CollectConsumerScopeInfo : public ExprVisitor { return "global"; } + /* Map of each Var consumption by a call node and its scope */ Map>> scope_info; + /* A map of call node and cope infor for each argument it consunes */ Map> call_scope_info; Map arg_to_binding; IRModule mod_; @@ -305,11 +462,13 @@ class DefineVDevice : ExprMutator { if (base_func->HasNonzeroAttr(attr::kPrimitive)) { continue; } + LOG(WARNING) << "Ccall CollectConsumerScopeInfo"; auto info = CollectConsumerScopeInfo().Collect(mod_, Downcast(func), target_); call_scope_info_ = info.first; scope_info_ = info.second; + LOG(WARNING) << "Ccall CollectProducerScopeInfo"; producer_sinfo_ = CollectProduserScopeInfo().Collect(mod_, Downcast(func), - scope_info_, target_); + scope_info_, target_, builder_); relax::Function update_func = Downcast(VisitExpr(func)); updates_->Add(gv, update_func); } @@ -322,10 +481,13 @@ class DefineVDevice : ExprMutator { } mod_.CopyOnWrite()->global_infos.Set("vdevice", global_vdevices_); - mod_ = relax::transform::SpecializePrimFuncBasedOnCallSite()(mod_); + LOG(WARNING) << "MOd After scope:" << mod_; mod_ = relax::transform::DeadCodeElimination()(mod_); mod_ = relax::transform::RealizeVDevice()(mod_); - mod_ = relax::transform::SpecializePrimFuncBasedOnCallSite()(mod_); + LOG(WARNING) << "Realized:" << mod_; + mod_ = relax::transform::RemoveRedundantAssignments()(mod_); + LOG(WARNING) << "Redundant Assin:" << mod_; + //mod_ = relax::transform::SpecializePrimFuncBasedOnCallSite()(mod_); return mod_; } @@ -343,11 +505,12 @@ class DefineVDevice : ExprMutator { if (call->op == call_tir_op) { gv = Downcast(call->args[0]); - tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); - out_sinfo = call->sinfo_args[0]; + //tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + //out_sinfo = call->sinfo_args[0]; func_args = Downcast(call->args[1]); } else { - return call; + func_args = Tuple(call->args); + //return call; } Array new_args; @@ -407,8 +570,11 @@ class DefineVDevice : ExprMutator { } } - auto updated_call = Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo}); - return builder_->Normalize(updated_call); + if (call->op == call_tir_op) { + return builder_->Normalize(Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo})); + } else { + return builder_->Normalize(Call(call->op, new_args, call->attrs, {updated_ret_sinfo})); + } } private: @@ -471,6 +637,31 @@ class DefineVDevice : ExprMutator { namespace transform { +Pass RemoveToDeviceForScopeChange() { + auto pass_func = [=](Function func, IRModule mod, PassContext pc) { + /* here Target doesn't matter as the scope_info we use only to find multiple consumers */ + auto info = CollectConsumerScopeInfo().Collect(mod, Downcast(func), Target("opencl")); + auto scope_info = info.second; + auto [pattern, rewriter] = CreatePatterns(scope_info); + return RewriteCall(pattern, rewriter, func); + }; + return CreateFunctionPass(pass_func, 1, "RemoveToDeviceForScopeChange", {}); +} +TVM_REGISTER_GLOBAL("relax.transform.RemoveToDeviceForScopeChange") + .set_body_typed(RemoveToDeviceForScopeChange); + +Pass RemoveRedundantAssignments() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return relax::RemoveRedundantAssignments().Run(mod); }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"RemoveRedundantAssignments", + /*required=*/{}); +} +TVM_REGISTER_GLOBAL("relax.transform.RemoveRedundantAssignments") + .set_body_typed(RemoveRedundantAssignments); + + Pass AnnotateCustomMemoryScope(Target target) { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { return relax::DefineVDevice(target).Run(mod); }; diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 63671d077d60..cd960b13f7e9 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -405,6 +405,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { private: void VisitBinding_(const VarBindingNode* binding) final { + LOG(WARNING) << "Var binding:" << binding->var; current_var_ = binding->var; ExprVisitor::VisitBinding_(binding); } @@ -420,6 +421,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { } void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place) { + LOG(WARNING) << "CollectVarMapping:" << call->args[0]; GlobalVar gv = Downcast(call->args[0]); tir::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); const auto& buffer_map = prim_func_->buffer_map; @@ -454,8 +456,8 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) { if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) { ICHECK(StructuralEqual()((*it).second, new_buf)) - << "Inconsistent buffers " << (*it).second << " and " << new_buf - << " mapped to the same relax var: " << expr; + << "Inconsistent buffers " << (*it).second << " and " << new_buf.scope() + << " mapped to the same relax var: " << (*it).second.scope(); } }; for (size_t i = 0; i < tir_args.size(); ++i) { diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 5986228ad06b..05265a9aaa72 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -60,13 +60,12 @@ bool KnowAllShapeValues(const StructInfo& sinfo) { class LegalizeMutator : public ExprMutator { public: - explicit LegalizeMutator(const IRModule& mod, - const ffi::Optional>& cmap, - bool enable_warning, bool add_attributes) + explicit LegalizeMutator(const IRModule& mod, const ffi::Optional>& cmap, + const ffi::Optional> skip_ops, bool enable_warning) : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning), - add_attributes_(add_attributes) { + skip_ops_(skip_ops) { if (cmap) { cmap_ = cmap.value(); } @@ -157,30 +156,32 @@ class LegalizeMutator : public ExprMutator { return std::nullopt; } - Expr AttributeOpAttrs(Expr expr, Attrs attrs) { - if (!expr->IsInstance()) { + Expr UpdateOutStructInfo(Expr expr, Call& visited_call) { + static const auto& infer_struct_info_map = Op::GetAttrMap("FInferStructInfo"); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + auto* op_node = visited_call->op.as(); + + // Not an OpNode + if (op_node == nullptr) { return expr; } + auto op = GetRef(op_node); - auto call = Downcast(expr); - if (call->args.empty()) { + if (!infer_struct_info_map.count(op)) { return expr; } - auto gvar = call->args[0].as(); - if (!gvar.defined()) { + if (!expr->IsInstance()) { return expr; } - auto base_func = builder_->GetContextIRModule()->Lookup(gvar.value()); - auto opt_prim_func = base_func.as(); - if (!opt_prim_func) { + auto call = Downcast(expr); + if (call->op != call_tir_op) { return expr; } - auto prim_func = opt_prim_func.value(); - auto new_prim_func = WithAttr(prim_func, "op_attrs", attrs); - builder_->UpdateFunction(gvar.value(), new_prim_func); - return call; + + StructInfo updated_ret_sinfo = infer_struct_info_map[op](visited_call, builder_); + return Call(call_tir_op, call->args, call->attrs, {updated_ret_sinfo}); } Expr BindTarget(Expr expr) { @@ -268,6 +269,14 @@ class LegalizeMutator : public ExprMutator { } auto op = ffi::GetRef(op_node); + if (skip_ops_.defined()) { + for (const auto name: skip_ops_.value()) { + if (name == op->name) { + return visited_call; + } + } + } + bool shapes_are_known_if_required = [&]() -> bool { bool requires_arg_shapes = requires_arg_shapes_map.get(op, Bool(true))->value; if (!requires_arg_shapes) { @@ -373,9 +382,7 @@ class LegalizeMutator : public ExprMutator { } Expr legalized = legalization_func(builder_, visited_call); - if (call->attrs.as() && add_attributes_) { - legalized = AttributeOpAttrs(legalized, call->attrs); - } + legalized = UpdateOutStructInfo(legalized, visited_call); // Append the target attribute to any PrimFunc generated in // legalization. @@ -421,19 +428,20 @@ class LegalizeMutator : public ExprMutator { */ bool enable_warning_; /*! - * \brief Boolean indicating this pass to add operator attributes to prim function attr + * \brief List of ops to be skipped from legalization */ - bool add_attributes_; + Optional> skip_ops_; }; namespace transform { -Pass LegalizeOps(ffi::Optional> cmap, bool enable_warning, bool add_attributes) { +Pass LegalizeOps(ffi::Optional> cmap, ffi::Optional> skip_ops, + bool enable_warning) { auto pass_func = [=](IRModule mod, PassContext pc) { bool apply_legalize_ops = pc->GetConfig("relax.transform.apply_legalize_ops").value_or(Bool(true))->value; if (apply_legalize_ops) { - mod = LegalizeMutator(mod, cmap, enable_warning, add_attributes).Transform(); + mod = LegalizeMutator(mod, cmap, skip_ops, enable_warning).Transform(); } return mod; }; diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index 303966c88f7b..b06578e2a0c8 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -48,12 +48,12 @@ def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-re assert ( arg_sinfo.vdevice.memory_scope == self.scope_info[call.args[0].name_hint][0][idx] - ), f"Scope mispatched for argument {idx} in {call.args[0].name_hint}" + ), f"Scope mismatched for argument {idx} in {call.args[0].name_hint}" if isinstance(call.sinfo_args[0], relax.TensorStructInfo): assert ( call.sinfo_args[0].vdevice.memory_scope == self.scope_info[call.args[0].name_hint][1][0] - ), f"Scope mispatched for return scope: {call.args[0].name_hint}" + ), f"Scope mismatched for return scope: {call.args[0].name_hint}" else: assert isinstance( call.sinfo_args[0], relax.TupleStructInfo @@ -62,11 +62,17 @@ def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-re assert ( sinfo.vdevice.memory_scope == self.scope_info[call.args[0].name_hint][1][idx] - ), f"Scope mispatched for return scope for {idx} in {call.args[0].name_hint}" + ), f"Scope mismatched for return scope for {idx} in {call.args[0].name_hint}" def verify(mod, expected): tgt = tvm.target.Target("opencl --device=adreno", host="llvm") + skip_ops = [ + "relax.nn.conv2d", + "relax.nn.max_pool2d", + "relax.nn.adaptive_avg_pool2d", + #"relax.nn.layer_norm", + ] with tgt: mod = tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) mod = tvm.relax.transform.DecomposeOpsForInference()(mod) @@ -75,17 +81,21 @@ def verify(mod, expected): mod = tvm.relax.transform.ConvertLayout(desired_layouts)(mod) mod = tvm.relax.transform.Normalize()(mod) mod = tvm.relax.transform.FoldConstant()(mod) - mod = tvm.relax.transform.LegalizeOps(add_attributes=True)(mod) + mod = tvm.relax.transform.LegalizeOps(skip_ops=skip_ops)(mod) + mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) + mod = tvm.relax.transform.AnnotateCustomMemoryScope(tgt)(mod) + mod = tvm.relax.transform.LegalizeOps()(mod) # To handle any fallback ops mod = tvm.relax.transform.LegalizeOps( {"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo}, - add_attributes=True, )(mod) mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) mod = tvm.relax.transform.FoldConstant()(mod) mod = tvm.relax.transform.FuseOps()(mod) mod = tvm.relax.transform.FuseTIR()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) - mod = tvm.relax.transform.AnnotateCustomMemoryScope(tgt)(mod) + mod = tvm.relax.transform.RemoveToDeviceForScopeChange()(mod) + mod = tvm.relax.transform.DeadCodeElimination()(mod) + mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) mod = tvm.relax.transform.Normalize()(mod) ValidateScope(expected).visit(mod) @@ -106,7 +116,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-nhwc"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-nhwc", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-nhwc", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } @@ -134,7 +144,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } @@ -162,7 +172,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } @@ -190,7 +200,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } @@ -213,7 +223,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo_relu": ( + "fused_conv2d_NCHWc_OIHWo_opencl_relu": ( ["global.texture-weight", "global.texture-weight"], ["global"], ), @@ -241,7 +251,7 @@ def main( "relu": (["global"], ["global"]), "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo_relu1": ( + "fused_conv2d_NCHWc_OIHWo_opencl_relu1": ( ["global.texture-weight", "global.texture-weight"], ["global"], ), @@ -268,7 +278,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo_relu_tir_tanh": ( + "fused_conv2d_NCHWc_OIHWo_opencl_relu_tir_tanh": ( ["global.texture-weight", "global.texture-weight"], ["global"], ), @@ -296,7 +306,7 @@ def main( "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), "te_layout_transform2": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo_add": ( + "fused_conv2d_NCHWc_OIHWo_opencl_add": ( ["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"], ), @@ -326,13 +336,12 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), - "relu": (["global"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), + "relu": (["global"], ["global"]), "te_layout_transform3": (["global"], ["global.texture-weight"]), - "te_layout_transform4": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo1": (["global.texture-weight", "global.texture-weight"], ["global"]), - "te_layout_transform5": (["global"], ["global"]), + "conv2d_NCHWc_OIHWo1_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform4": (["global"], ["global"]), } verify(Input, Expected) @@ -353,7 +362,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "sum": (["global"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } @@ -376,7 +385,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "sum": (["global"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } @@ -399,7 +408,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "sum": (["global"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } @@ -422,7 +431,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), "transpose": (["global"], ["global"]), } @@ -445,7 +454,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), "expand_dims": (["global"], ["global"]), } @@ -468,7 +477,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), "squeeze": (["global"], ["global"]), } @@ -493,7 +502,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), "strided_slice": (["global"], ["global"]), } @@ -517,7 +526,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": ( + "conv2d_NCHWc_OIHWo_opencl": ( ["global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), @@ -543,7 +552,7 @@ def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "fl Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": ( + "conv2d_NCHWc_OIHWo_opencl": ( ["global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), @@ -573,14 +582,14 @@ def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "fl Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": ( + "conv2d_NCHWc_OIHWo_opencl": ( ["global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), "fused_relu_concatenate_split": (["global.texture-weight"], ["global", "global"]), "te_layout_transform2": (["global"], ["global"]), "te_layout_transform3": (["global"], ["global"]), - "fused_transpose_transpose1_concatenate1": (["global", "global"], ["global"]), + "fused_transpose_transpose_concatenate1": (["global", "global"], ["global"]), } verify(Input, Expected) @@ -608,11 +617,11 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": ( + "conv2d_NCHWc_OIHWo_opencl": ( ["global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), - "max_pool2d": (["global.texture-weight"], ["global"]), + "max_pool2d_opencl": (["global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -634,11 +643,11 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": ( + "conv2d_NCHWc_OIHWo_opencl": ( ["global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), - "adaptive_avg_pool2d": (["global.texture-weight"], ["global"]), + "adaptive_avg_pool2d_opencl": (["global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -660,7 +669,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), "softmax": (["global"], ["global"]), } @@ -688,11 +697,11 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": ( + "conv2d_NCHWc_OIHWo_opencl": ( ["global.texture-weight", "global.texture-weight"], - ["global.texture-weight"], + ["global"], ), - "layer_norm": (["global.texture-weight", "global", "global"], ["global"]), + "layer_norm": (["global", "global", "global"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } verify(Input, Expected) @@ -716,7 +725,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform2": (["global"], ["global"]), "add": (["global", "global"], ["global"]), } @@ -739,7 +748,7 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo_add": ( + "fused_conv2d_NCHWc_OIHWo_opencl_add": ( ["global.texture-weight", "global.texture-weight"], ["global"], ), @@ -795,23 +804,21 @@ def main( "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), "te_layout_transform2": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo_add_relu": ( + "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( ["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), "te_layout_transform3": (["global"], ["global.texture-weight"]), "multiply": (["global"], ["global"]), - "te_layout_transform4": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo1_add_relu": ( + "fused_conv2d_NCHWc_OIHWo1_opencl_add_relu": ( ["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), - "te_layout_transform5": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo2_relu1": ( + "fused_conv2d_NCHWc_OIHWo2_opencl_relu1": ( ["global.texture-weight", "global.texture-weight"], ["global"], ), - "te_layout_transform6": (["global"], ["global"]), + "te_layout_transform4": (["global"], ["global"]), } verify(Input, Expected) @@ -856,14 +863,14 @@ def main( "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), "te_layout_transform2": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo_add_relu": ( + "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( ["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"], ), "te_layout_transform3": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo1": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo1_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), "te_layout_transform4": (["global"], ["global"]), - "conv2d": (["global", "global"], ["global"]), + "conv2d_opencl": (["global", "global"], ["global"]), "te_layout_transform5": (["global"], ["global"]), "concatenate": (["global", "global"], ["global"]), } @@ -910,17 +917,17 @@ def main( "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), "te_layout_transform2": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo_add_relu": ( + "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( ["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), "te_layout_transform3": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo1": ( + "conv2d_NCHWc_OIHWo1_opencl": ( ["global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), "te_layout_transform4": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo2": ( + "conv2d_NCHWc_OIHWo2_opencl": ( ["global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), @@ -982,27 +989,25 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": ( + "conv2d_NCHWc_OIHWo_opencl": ( ["global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), - "max_pool2d": (["global.texture-weight"], ["global.texture-weight"]), + "max_pool2d_opencl": (["global.texture-weight"], ["global.texture-weight"]), "te_layout_transform2": (["global"], ["global.texture-weight"]), - "te_layout_transform3": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo2": ( + "conv2d_NCHWc_OIHWo2_opencl": ( ["global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), - "fused_conv2d_NCHWc_OIHWo1_add": ( + "fused_conv2d_NCHWc_OIHWo1_opencl_add": ( ["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global.texture-weight"], ), - "te_layout_transform4": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo3_add": ( + "fused_conv2d_NCHWc_OIHWo3_opencl_add": ( ["global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"], ), - "te_layout_transform5": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global"]), } verify(Input, Expected) @@ -1054,19 +1059,16 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": ( + "conv2d_NCHWc_OIHWo_opencl": ( ["global.texture-weight", "global.texture-weight"], - ["global.texture-weight"], + ["global"], ), - "te_layout_transform3": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), "fused_mean_add1": (["global", "global"], ["global"]), - "te_layout_transform2": (["global"], ["global.texture-weight"]), - "te_layout_transform4": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo1_add_multiply_add": ( - ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + "fused_conv2d_NCHWc_OIHWo_opencl_add_multiply_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"], ), - "te_layout_transform5": (["global"], ["global"]), } verify(Input, Expected) @@ -1120,19 +1122,16 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo": ( + "conv2d_NCHWc_OIHWo_opencl": ( ["global.texture-weight", "global.texture-weight"], - ["global.texture-weight"], + ["global"], ), - "te_layout_transform3": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), "fused_mean_add1": (["global", "global"], ["global"]), - "te_layout_transform2": (["global"], ["global.texture-weight"]), - "te_layout_transform4": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo1_add_multiply_add": ( - ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + "fused_conv2d_NCHWc_OIHWo_opencl_add_multiply_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight", "global.texture-weight"], ["global"], ), - "te_layout_transform5": (["global"], ["global"]), } verify(Input, Expected) From 3528df783973aff13084c6b86806049f6f0029a2 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 26 Feb 2025 14:28:19 +0530 Subject: [PATCH 09/31] lint --- python/tvm/dlight/adreno/convolution.py | 2 +- python/tvm/relax/transform/__init__.py | 3 +- python/tvm/relax/transform/transform.py | 6 +- python/tvm/tir/analysis/analysis.py | 1 + .../transform/annotate_custom_storage.cc | 67 +++++-------- src/relax/transform/fuse_tir.cc | 18 ---- .../test_transform_annotate_custom_scope.py | 96 +++++++++++++++---- tests/scripts/task_build_adreno_bins.sh | 4 +- 8 files changed, 109 insertions(+), 88 deletions(-) diff --git a/python/tvm/dlight/adreno/convolution.py b/python/tvm/dlight/adreno/convolution.py index f084885dad73..fc2cc449a1c6 100644 --- a/python/tvm/dlight/adreno/convolution.py +++ b/python/tvm/dlight/adreno/convolution.py @@ -24,7 +24,7 @@ from tvm.tir import IterVar from tvm.tir.schedule.schedule import BlockRV -from ..base import analysis, BlockInfo, IterInfo +from ..analysis import BlockInfo, IterInfo from .base import AdrenoScheduleRule diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 559d4a9380e1..8c4e4acebdab 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -99,7 +99,8 @@ from .optimize_layout_transform import OptimizeLayoutTransform from .fold_batch_norm_to_conv2d_for_inference import FoldBatchnormToConv2D from .remove_redundant_reshape import RemoveRedundantReshape -#from .remove_to_device_for_scope_change import RemoveToDeviceForScopeChange + +# from .remove_to_device_for_scope_change import RemoveToDeviceForScopeChange # Import to register the legalization functions. from . import legalize_ops diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index aad6372b1658..339112714edb 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1173,9 +1173,7 @@ def multiply( T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1] """ - return _ffi_api.LegalizeOps( - customize_legalize_map, skip_ops, enable_warning # type: ignore - ) + return _ffi_api.LegalizeOps(customize_legalize_map, skip_ops, enable_warning) # type: ignore def RealizeVDevice() -> tvm.ir.transform.Pass: @@ -1640,7 +1638,7 @@ def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass: def RemoveToDeviceForScopeChange() -> tvm.ir.transform.Pass: - """This pass + """This pass Returns ------- diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index ef2da05c347c..44b7e4b8b18d 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -18,6 +18,7 @@ # pylint: disable=invalid-name from typing import Dict, List, Optional, Union +import tvm from tvm import Object, _ffi from tvm.ir import IRModule from tvm.tir.expr import Var diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc index 80a8d36ae6c3..5f51697e83b3 100644 --- a/src/relax/transform/annotate_custom_storage.cc +++ b/src/relax/transform/annotate_custom_storage.cc @@ -23,11 +23,11 @@ #include #include +#include #include #include #include #include -#include #include #include @@ -48,7 +48,8 @@ static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tens } namespace { -std::tuple)>> CreatePatterns(Map>> scope_info) { +std::tuple)>> CreatePatterns( + Map>> scope_info) { auto pat_gv = WildcardPattern(); auto pat_inp = WildcardPattern(); @@ -56,7 +57,7 @@ std::tuple)>> CreateP auto pattern_out = IsOp("relax.to_vdevice")(pat_call_tir); auto rewriter = [=](Expr expr, Map matches) -> Expr { - const auto* call_tir = matches[pat_call_tir].as(); + const auto* call_tir = matches[pat_call_tir].as(); ICHECK(call_tir) << "InternalError: " << "Match of relax.call_tir operator should produce Call, " << "but instead produces " << matches[pat_call_tir] << " with type " @@ -71,8 +72,7 @@ std::tuple)>> CreateP const auto* vdev_attrs = out->attrs.as(); ICHECK(vdev_attrs) << "InternalError: " << "Attributes for relax.to_vdevice operator should be ToVDeviceAttrs, " - << "but were instead " << out->attrs << " with type " - << out->GetTypeKey(); + << "but were instead " << out->attrs << " with type " << out->GetTypeKey(); const auto* tir_out_sinfo = call_tir->sinfo_args[0].as(); if (!tir_out_sinfo) return expr; @@ -87,12 +87,12 @@ std::tuple)>> CreateP } } - if ((std::string(tir_out_sinfo->vdevice.value()->memory_scope).find("texture") != std::string::npos) && + if ((std::string(tir_out_sinfo->vdevice.value()->memory_scope).find("texture") != + std::string::npos) && (vdev_attrs->dst_vdevice->memory_scope == "global")) { - LOG(WARNING) << "Can be optimized"; auto shape_arr = tir_out_sinfo->GetShape().value(); - auto new_sinfo = TensorStructInfo(ShapeExpr(shape_arr), tir_out_sinfo->dtype, - vdev_attrs->dst_vdevice); + auto new_sinfo = + TensorStructInfo(ShapeExpr(shape_arr), tir_out_sinfo->dtype, vdev_attrs->dst_vdevice); return Call(call_tir->op, call_tir->args, call_tir->attrs, {new_sinfo}); } @@ -126,19 +126,16 @@ class RemoveRedundantAssignments : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final { - LOG(WARNING) << "VisitBinding_:" << binding->var.get()->name_hint() << " : " << var->name_hint(); redundant_map.Set(GetRef(binding->var.get()), GetRef(var)); } void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) override { - LOG(WARNING) << "VisitBinding_ - DataFlow:" << binding->var.get()->name_hint() << " : " << static_cast(val)->name_hint(); redundant_map.Set(GetRef(binding->var.get()), GetRef(val)); } Expr VisitExpr_(const CallNode* call_node) final { auto call = Downcast(ExprMutator::VisitExpr_(call_node)); static const Op& call_tir_op = Op::Get("relax.call_tir"); - LOG(WARNING) << "VisitExpr_:" << call->op << " Args:" << call->args; Tuple args; @@ -157,7 +154,8 @@ class RemoveRedundantAssignments : public ExprMutator { } } if (call->op == call_tir_op) { - return Call(call_tir_op, {call->args[0], Tuple(new_args)}, call->attrs, {call->sinfo_args[0]}); + return Call(call_tir_op, {call->args[0], Tuple(new_args)}, call->attrs, + {call->sinfo_args[0]}); } else { if (call->sinfo_args.size() > 0) { return Call(call->op, new_args, call->attrs, {call->sinfo_args[0]}); @@ -167,7 +165,7 @@ class RemoveRedundantAssignments : public ExprMutator { } } -private: + private: Map redundant_map; IRModule updates_; IRModule mod_; @@ -206,7 +204,6 @@ class CollectProduserScopeInfo : public ExprVisitor { ICHECK(op_map_infer_struct_info_.count(op)) << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; out_sinfo = op_map_infer_struct_info_[op](GetRef(call), builder_); - LOG(WARNING) << "Got Struct Info for normal op:" << out_sinfo.as(); } std::unordered_map scope_count; @@ -282,7 +279,6 @@ class CollectConsumerScopeInfo : public ExprVisitor { mod_ = mod; target_ = target; VisitExpr(func->body); - LOG(WARNING) << "Visit completed"; // Extend the scope for tuple items for (const auto& val : arg_to_binding) { if (scope_info.find(val.first) != scope_info.end()) { @@ -313,7 +309,7 @@ class CollectConsumerScopeInfo : public ExprVisitor { static const Op& call_tir_op = Op::Get("relax.call_tir"); GlobalVar gv; Array op_attrs; - Optional op_pattern = Integer(static_cast(relay::kOpaque)); + Optional op_pattern = Integer(static_cast(OpPatternKind::kOpaque)); Tuple func_args; if (call->op == call_tir_op) { @@ -323,27 +319,20 @@ class CollectConsumerScopeInfo : public ExprVisitor { op_pattern = ExtractPattern(pfunc); func_args = Downcast(call->args[1]); } else { - LOG(WARNING) << "About to access attrs"; op_attrs = {call->attrs}; - op_pattern = Integer(static_cast(relay::kOpaque)); + op_pattern = Integer(static_cast(OpPatternKind::kOpaque)); func_args = Tuple(call->args); } - LOG(WARNING) << "About to call texture scope"; bool is_texture_supported = SupportsTexture(op_attrs, op_pattern.value()); - LOG(WARNING) << "returned texture scope"; Array arg_scope; for (auto arg : func_args->fields) { - LOG(WARNING) << "B1"; auto sinfo = GetStructInfo(arg); - LOG(WARNING) << "B2"; if (auto tensor_sinfo = sinfo.as()) { - LOG(WARNING) << "B3:" << tensor_sinfo; auto scope = is_texture_supported ? Scope(GetShapeFromTensorStructInfo(tensor_sinfo.value())) : "global"; - LOG(WARNING) << "B4"; Map> ent_call; const VarNode* arg_var = arg.as(); if (scope_info.find(GetRef(arg_var)) != scope_info.end()) { @@ -352,11 +341,9 @@ class CollectConsumerScopeInfo : public ExprVisitor { ent_call.Set(GetRef(call), {scope}); scope_info.Set(GetRef(arg_var), ent_call); arg_scope.push_back(scope); - LOG(WARNING) << "B5"; } } call_scope_info.Set(GetRef(call), arg_scope); - LOG(WARNING) << "B6"; } private: @@ -381,7 +368,7 @@ class CollectConsumerScopeInfo : public ExprVisitor { } bool SupportsTexture(const Array& op_attrs, Integer op_pattern) { - if (op_pattern.IntValue() < relay::kCommReduce) return true; + if (op_pattern.IntValue() < OpPatternKind::kCommReduce) return true; for (auto attr : op_attrs) { if (auto conv_attr = attr.as()) { @@ -409,8 +396,8 @@ class CollectConsumerScopeInfo : public ExprVisitor { // 5d requirement is not limitation of textures in general, it is limitation how // we are representing memory scopes/layout and flattening of textures in tir if (shape.size() == 5 && shape[4].as()->value == 4) { - for (auto ind: shape) { - if (! ind.as()) { + for (auto ind : shape) { + if (!ind.as()) { // Dynamic tensors return "global.texture-nchw"; } @@ -462,11 +449,9 @@ class DefineVDevice : ExprMutator { if (base_func->HasNonzeroAttr(attr::kPrimitive)) { continue; } - LOG(WARNING) << "Ccall CollectConsumerScopeInfo"; auto info = CollectConsumerScopeInfo().Collect(mod_, Downcast(func), target_); call_scope_info_ = info.first; scope_info_ = info.second; - LOG(WARNING) << "Ccall CollectProducerScopeInfo"; producer_sinfo_ = CollectProduserScopeInfo().Collect(mod_, Downcast(func), scope_info_, target_, builder_); relax::Function update_func = Downcast(VisitExpr(func)); @@ -481,13 +466,9 @@ class DefineVDevice : ExprMutator { } mod_.CopyOnWrite()->global_infos.Set("vdevice", global_vdevices_); - LOG(WARNING) << "MOd After scope:" << mod_; mod_ = relax::transform::DeadCodeElimination()(mod_); mod_ = relax::transform::RealizeVDevice()(mod_); - LOG(WARNING) << "Realized:" << mod_; mod_ = relax::transform::RemoveRedundantAssignments()(mod_); - LOG(WARNING) << "Redundant Assin:" << mod_; - //mod_ = relax::transform::SpecializePrimFuncBasedOnCallSite()(mod_); return mod_; } @@ -505,12 +486,12 @@ class DefineVDevice : ExprMutator { if (call->op == call_tir_op) { gv = Downcast(call->args[0]); - //tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); - //out_sinfo = call->sinfo_args[0]; + // tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + // out_sinfo = call->sinfo_args[0]; func_args = Downcast(call->args[1]); } else { func_args = Tuple(call->args); - //return call; + // return call; } Array new_args; @@ -570,10 +551,11 @@ class DefineVDevice : ExprMutator { } } - if (call->op == call_tir_op) { - return builder_->Normalize(Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo})); + if (call->op == call_tir_op) { + return builder_->Normalize( + Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo})); } else { - return builder_->Normalize(Call(call->op, new_args, call->attrs, {updated_ret_sinfo})); + return builder_->Normalize(Call(call->op, new_args, call->attrs, {updated_ret_sinfo})); } } @@ -661,7 +643,6 @@ Pass RemoveRedundantAssignments() { TVM_REGISTER_GLOBAL("relax.transform.RemoveRedundantAssignments") .set_body_typed(RemoveRedundantAssignments); - Pass AnnotateCustomMemoryScope(Target target) { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { return relax::DefineVDevice(target).Run(mod); }; diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index cd960b13f7e9..92b667047271 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -644,14 +644,6 @@ class FusedTIRConstructor : public ExprVisitor { // Step 1. Get Global var and PrimFunc GlobalVar gv = Downcast(call->args[0]); tir::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); - if (prim_func_->GetAttr("op_attrs")) { - func_info_.op_attrs.push_back(prim_func_->GetAttr("op_attrs").value()); - } - - if (prim_func_->GetAttr("op_pattern")) { - auto op_pattern = prim_func_->GetAttr("op_pattern").value(); - func_info_.op_pattern.push_back(static_cast(op_pattern.IntValue())); - } // Step 2. Renew all vars/buffer definitions and blocks to avoid duplication tir::PrimFunc prim_func = tir::RenewDefs(prim_func_); @@ -969,14 +961,6 @@ class FusedTIRConstructor : public ExprVisitor { tir::PrimFunc ConstructFunc() { ffi::Map attr_map; attr_map.Set(tir::attr::kNoAlias, true); - if (!func_info_.op_attrs.empty()) { - attr_map.Set("op_attrs", func_info_.op_attrs); - } - if (!func_info_.op_pattern.empty()) { - int op_pattern = relay::kOpaque; - op_pattern = *max_element(func_info_.op_pattern.begin(), func_info_.op_pattern.end()); - attr_map.Set("op_pattern", Integer(static_cast(op_pattern))); - } tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, func_info_.symbolic_var_remap); ICHECK(func_info_.global_name != "fused"); // Remove output buffers from func_info_.alloc_buffers @@ -1028,8 +1012,6 @@ class FusedTIRConstructor : public ExprVisitor { ffi::Array alloc_buffers; /*! \brief The bodies of the original funcs, which is also the body of the fused func. */ ffi::Array bodies; - ffi::Array op_attrs; - std::vector op_pattern; /*! \brief The params of the fused function*/ ffi::Array params; /*! diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index b06578e2a0c8..576d020db060 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -71,7 +71,7 @@ def verify(mod, expected): "relax.nn.conv2d", "relax.nn.max_pool2d", "relax.nn.adaptive_avg_pool2d", - #"relax.nn.layer_norm", + # "relax.nn.layer_norm", ] with tgt: mod = tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) @@ -84,7 +84,7 @@ def verify(mod, expected): mod = tvm.relax.transform.LegalizeOps(skip_ops=skip_ops)(mod) mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) mod = tvm.relax.transform.AnnotateCustomMemoryScope(tgt)(mod) - mod = tvm.relax.transform.LegalizeOps()(mod) # To handle any fallback ops + mod = tvm.relax.transform.LegalizeOps()(mod) # To handle any fallback ops mod = tvm.relax.transform.LegalizeOps( {"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo}, )(mod) @@ -93,6 +93,7 @@ def verify(mod, expected): mod = tvm.relax.transform.FuseOps()(mod) mod = tvm.relax.transform.FuseTIR()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) + mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) mod = tvm.relax.transform.RemoveToDeviceForScopeChange()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) @@ -144,7 +145,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform2": (["global"], ["global"]), } @@ -172,7 +176,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform2": (["global"], ["global"]), } @@ -200,7 +207,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform2": (["global"], ["global"]), } @@ -336,11 +346,17 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform2": (["global"], ["global"]), "relu": (["global"], ["global"]), "te_layout_transform3": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo1_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo1_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform4": (["global"], ["global"]), } verify(Input, Expected) @@ -362,7 +378,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "sum": (["global"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } @@ -385,7 +404,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "sum": (["global"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } @@ -408,7 +430,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "sum": (["global"], ["global"]), "te_layout_transform2": (["global"], ["global"]), } @@ -431,7 +456,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform2": (["global"], ["global"]), "transpose": (["global"], ["global"]), } @@ -454,7 +482,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform2": (["global"], ["global"]), "expand_dims": (["global"], ["global"]), } @@ -477,7 +508,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform2": (["global"], ["global"]), "squeeze": (["global"], ["global"]), } @@ -502,7 +536,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform2": (["global"], ["global"]), "strided_slice": (["global"], ["global"]), } @@ -669,7 +706,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform2": (["global"], ["global"]), "softmax": (["global"], ["global"]), } @@ -725,7 +765,10 @@ def main( Expected = { "te_layout_transform": (["global"], ["global.texture-weight"]), "te_layout_transform1": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform2": (["global"], ["global"]), "add": (["global", "global"], ["global"]), } @@ -868,7 +911,10 @@ def main( ["global"], ), "te_layout_transform3": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo1_opencl": (["global.texture-weight", "global.texture-weight"], ["global"]), + "conv2d_NCHWc_OIHWo1_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), "te_layout_transform4": (["global"], ["global"]), "conv2d_opencl": (["global", "global"], ["global"]), "te_layout_transform5": (["global"], ["global"]), @@ -1066,7 +1112,13 @@ def main( "te_layout_transform2": (["global"], ["global"]), "fused_mean_add1": (["global", "global"], ["global"]), "fused_conv2d_NCHWc_OIHWo_opencl_add_multiply_add": ( - ["global.texture-weight", "global.texture-weight", "global.texture-weight", "global.texture-weight", "global.texture-weight"], + [ + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + ], ["global"], ), } @@ -1129,7 +1181,12 @@ def main( "te_layout_transform2": (["global"], ["global"]), "fused_mean_add1": (["global", "global"], ["global"]), "fused_conv2d_NCHWc_OIHWo_opencl_add_multiply_add": ( - ["global.texture-weight", "global.texture-weight", "global.texture-weight", "global.texture-weight"], + [ + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + ], ["global"], ), } @@ -1138,3 +1195,4 @@ def main( if __name__ == "__main__": tvm.testing.main() + # test_conv2d_relu_sub_indexed() diff --git a/tests/scripts/task_build_adreno_bins.sh b/tests/scripts/task_build_adreno_bins.sh index e5775c10ec34..91886281806d 100755 --- a/tests/scripts/task_build_adreno_bins.sh +++ b/tests/scripts/task_build_adreno_bins.sh @@ -39,7 +39,7 @@ echo set\(USE_OPENCL ON\) >> config.cmake fi echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_CPP_RPC ON\) >> config.cmake -echo set\(USE_CPP_RTVM ON\) >> config.cmake +#echo set\(USE_CPP_RTVM ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake echo set\(USE_KALLOC_ALIGNMENT 32\) >> config.cmake @@ -62,4 +62,4 @@ cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain. -DCMAKE_C_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" \ -DMACHINE_NAME="aarch64-linux-gnu" .. -make -j$(nproc) tvm_rpc rtvm opencl-cpptest +make -j$(nproc) tvm_rpc opencl-cpptest From 007654a570752c33057bfce03c35e4c5cf871fca Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 26 Feb 2025 14:41:32 +0530 Subject: [PATCH 10/31] docs --- include/tvm/relax/transform.h | 13 ++++++++++-- python/tvm/relax/transform/__init__.py | 5 ++--- python/tvm/relax/transform/transform.py | 21 ++++++++++++++++--- .../transform/annotate_custom_storage.cc | 8 +++---- src/relax/transform/fuse_tir.cc | 6 ++---- .../test_transform_annotate_custom_scope.py | 2 +- 6 files changed, 38 insertions(+), 17 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 9c412245e223..9a7afef5b325 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -698,10 +698,19 @@ TVM_DLL Pass AnnotateCustomMemoryScope(Target target); */ TVM_DLL Pass SpecializePrimFuncBasedOnCallSite(); - +/*! + *\brief This pass removes redundant assignment statements. These stmts are result of other pass + * like hint_on_device processed by RealizeVDevice may leave them. The subsequent pass like + * fuse_ops fail to fuse in this case. + */ TVM_DLL Pass RemoveRedundantAssignments(); -TVM_DLL Pass RemoveToDeviceForScopeChange(); +/* + * \brief This is a texture specific pass that can optimize unnecessary to_device copies. + * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly + * store into global scope avoiding unnecessary device copy. + */ +TVM_DLL Pass OptimizeToDeviceForScopeChange(); } // namespace transform } // namespace relax diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 8c4e4acebdab..248a0c426533 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -85,7 +85,8 @@ VMShapeLower, AnnotateCustomMemoryScope, SpecializePrimFuncBasedOnCallSite, - RemoveToDeviceForScopeChange, + OptimizeToDeviceForScopeChange, + RemoveRedundantAssignments, dataflowblock_pass, function_pass, ) @@ -100,7 +101,5 @@ from .fold_batch_norm_to_conv2d_for_inference import FoldBatchnormToConv2D from .remove_redundant_reshape import RemoveRedundantReshape -# from .remove_to_device_for_scope_change import RemoveToDeviceForScopeChange - # Import to register the legalization functions. from . import legalize_ops diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 339112714edb..d89544746ca9 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1637,15 +1637,30 @@ def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass: return _ffi_api.SpecializePrimFuncBasedOnCallSite() # type: ignore -def RemoveToDeviceForScopeChange() -> tvm.ir.transform.Pass: - """This pass +def OptimizeToDeviceForScopeChange() -> tvm.ir.transform.Pass: + """This pass is a texture specific pass that can optimize unnecessary to_device copies. + Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly + store into global scope avoiding unnecessary device copy. Returns ------- ret: tvm.ir.transform.Pass The registered pass for allocating workspace. """ - return _ffi_api.RemoveToDeviceForScopeChange() # type: ignore + return _ffi_api.OptimizeToDeviceForScopeChange() # type: ignore + + +def RemoveRedundantAssignments() -> tvm.ir.transform.Pass: + """ This pass removes redundant assignment statements. These stmts are result of other pass + like hint_on_device processed by RealizeVDevice may leave them. The subsequent pass like + fuse_ops fail to fuse in this case + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.RemoveRedundantAssignments() # type: ignore def _wrap_class_function_pass(pass_cls, pass_info): diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc index 5f51697e83b3..40ddc42a271e 100644 --- a/src/relax/transform/annotate_custom_storage.cc +++ b/src/relax/transform/annotate_custom_storage.cc @@ -619,7 +619,7 @@ class DefineVDevice : ExprMutator { namespace transform { -Pass RemoveToDeviceForScopeChange() { +Pass OptimizeToDeviceForScopeChange() { auto pass_func = [=](Function func, IRModule mod, PassContext pc) { /* here Target doesn't matter as the scope_info we use only to find multiple consumers */ auto info = CollectConsumerScopeInfo().Collect(mod, Downcast(func), Target("opencl")); @@ -627,10 +627,10 @@ Pass RemoveToDeviceForScopeChange() { auto [pattern, rewriter] = CreatePatterns(scope_info); return RewriteCall(pattern, rewriter, func); }; - return CreateFunctionPass(pass_func, 1, "RemoveToDeviceForScopeChange", {}); + return CreateFunctionPass(pass_func, 1, "OptimizeToDeviceForScopeChange", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RemoveToDeviceForScopeChange") - .set_body_typed(RemoveToDeviceForScopeChange); +TVM_REGISTER_GLOBAL("relax.transform.OptimizeToDeviceForScopeChange") + .set_body_typed(OptimizeToDeviceForScopeChange); Pass RemoveRedundantAssignments() { runtime::TypedPackedFunc pass_func = diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 92b667047271..ba4515faf390 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -405,7 +405,6 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { private: void VisitBinding_(const VarBindingNode* binding) final { - LOG(WARNING) << "Var binding:" << binding->var; current_var_ = binding->var; ExprVisitor::VisitBinding_(binding); } @@ -421,7 +420,6 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { } void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place) { - LOG(WARNING) << "CollectVarMapping:" << call->args[0]; GlobalVar gv = Downcast(call->args[0]); tir::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); const auto& buffer_map = prim_func_->buffer_map; @@ -456,8 +454,8 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) { if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) { ICHECK(StructuralEqual()((*it).second, new_buf)) - << "Inconsistent buffers " << (*it).second << " and " << new_buf.scope() - << " mapped to the same relax var: " << (*it).second.scope(); + << "Inconsistent buffers " << (*it).second << " and " << new_buf + << " mapped to the same relax var: " << expr; } }; for (size_t i = 0; i < tir_args.size(); ++i) { diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index 576d020db060..d353ba1655f4 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -94,7 +94,7 @@ def verify(mod, expected): mod = tvm.relax.transform.FuseTIR()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) - mod = tvm.relax.transform.RemoveToDeviceForScopeChange()(mod) + mod = tvm.relax.transform.OptimizeToDeviceForScopeChange()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) mod = tvm.relax.transform.Normalize()(mod) From 15adfcde8aa726ee784ad354cf48451580c44fd2 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 26 Feb 2025 15:40:51 +0530 Subject: [PATCH 11/31] Organize passes --- include/tvm/relax/transform.h | 4 +- python/tvm/relax/transform/__init__.py | 2 +- python/tvm/relax/transform/transform.py | 4 +- .../transform/annotate_custom_storage.cc | 148 ------------------ .../test_transform_annotate_custom_scope.py | 2 +- 5 files changed, 6 insertions(+), 154 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 9a7afef5b325..d6fe526ea528 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -699,7 +699,7 @@ TVM_DLL Pass AnnotateCustomMemoryScope(Target target); TVM_DLL Pass SpecializePrimFuncBasedOnCallSite(); /*! - *\brief This pass removes redundant assignment statements. These stmts are result of other pass + * \brief This pass removes redundant assignment statements. These stmts are result of other pass * like hint_on_device processed by RealizeVDevice may leave them. The subsequent pass like * fuse_ops fail to fuse in this case. */ @@ -710,7 +710,7 @@ TVM_DLL Pass RemoveRedundantAssignments(); * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly * store into global scope avoiding unnecessary device copy. */ -TVM_DLL Pass OptimizeToDeviceForScopeChange(); +TVM_DLL Pass OptimizeToVDeviceForScopeChange(); } // namespace transform } // namespace relax diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 248a0c426533..f256dda4b176 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -85,7 +85,7 @@ VMShapeLower, AnnotateCustomMemoryScope, SpecializePrimFuncBasedOnCallSite, - OptimizeToDeviceForScopeChange, + OptimizeToVDeviceForScopeChange, RemoveRedundantAssignments, dataflowblock_pass, function_pass, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index d89544746ca9..71718e862211 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1637,7 +1637,7 @@ def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass: return _ffi_api.SpecializePrimFuncBasedOnCallSite() # type: ignore -def OptimizeToDeviceForScopeChange() -> tvm.ir.transform.Pass: +def OptimizeToVDeviceForScopeChange() -> tvm.ir.transform.Pass: """This pass is a texture specific pass that can optimize unnecessary to_device copies. Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly store into global scope avoiding unnecessary device copy. @@ -1647,7 +1647,7 @@ def OptimizeToDeviceForScopeChange() -> tvm.ir.transform.Pass: ret: tvm.ir.transform.Pass The registered pass for allocating workspace. """ - return _ffi_api.OptimizeToDeviceForScopeChange() # type: ignore + return _ffi_api.OptimizeToVDeviceForScopeChange() # type: ignore def RemoveRedundantAssignments() -> tvm.ir.transform.Pass: diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc index 40ddc42a271e..2d6d5777daea 100644 --- a/src/relax/transform/annotate_custom_storage.cc +++ b/src/relax/transform/annotate_custom_storage.cc @@ -47,130 +47,6 @@ static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tens return shape.value(); } -namespace { -std::tuple)>> CreatePatterns( - Map>> scope_info) { - auto pat_gv = WildcardPattern(); - - auto pat_inp = WildcardPattern(); - auto pat_call_tir = IsOp("relax.call_tir")(pat_gv, pat_inp); - auto pattern_out = IsOp("relax.to_vdevice")(pat_call_tir); - - auto rewriter = [=](Expr expr, Map matches) -> Expr { - const auto* call_tir = matches[pat_call_tir].as(); - ICHECK(call_tir) << "InternalError: " - << "Match of relax.call_tir operator should produce Call, " - << "but instead produces " << matches[pat_call_tir] << " with type " - << matches[pat_call_tir]->GetTypeKey(); - - const auto* out = matches[pattern_out].as(); - ICHECK(out) << "InternalError: " - << "Match of relax.to_vdevice operator should produce Call, " - << "but instead produces " << matches[pattern_out] << " with type " - << matches[pattern_out]->GetTypeKey(); - - const auto* vdev_attrs = out->attrs.as(); - ICHECK(vdev_attrs) << "InternalError: " - << "Attributes for relax.to_vdevice operator should be ToVDeviceAttrs, " - << "but were instead " << out->attrs << " with type " << out->GetTypeKey(); - - const auto* tir_out_sinfo = call_tir->sinfo_args[0].as(); - if (!tir_out_sinfo) return expr; - - if (!tir_out_sinfo->vdevice.defined()) return expr; - - const VarNode* arg_var = out->args[0].as(); - if (scope_info.find(GetRef(arg_var)) != scope_info.end()) { - if (scope_info[GetRef(arg_var)].size() > 1) { - /* Don't do to_device optimization as we are not the only consumer */ - return expr; - } - } - - if ((std::string(tir_out_sinfo->vdevice.value()->memory_scope).find("texture") != - std::string::npos) && - (vdev_attrs->dst_vdevice->memory_scope == "global")) { - auto shape_arr = tir_out_sinfo->GetShape().value(); - auto new_sinfo = - TensorStructInfo(ShapeExpr(shape_arr), tir_out_sinfo->dtype, vdev_attrs->dst_vdevice); - - return Call(call_tir->op, call_tir->args, call_tir->attrs, {new_sinfo}); - } - return expr; - }; - - return {pattern_out, rewriter}; -} - -} // namespace - -class RemoveRedundantAssignments : public ExprMutator { - public: - using ExprMutator::VisitExpr_; - - IRModule Run(IRModule& mod) { - mod_ = mod; - for (const auto& [gv, func] : mod_->functions) { - if (func->IsInstance()) { - const auto& base_func = mod_->Lookup(gv); - // Only non primitive relax functions - if (base_func->HasNonzeroAttr(attr::kPrimitive)) { - continue; - } - relax::Function update_func = Downcast(VisitExpr(func)); - updates_->Add(gv, update_func); - } - } - mod_.CopyOnWrite()->Update(updates_); - return mod_; - } - - void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final { - redundant_map.Set(GetRef(binding->var.get()), GetRef(var)); - } - - void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) override { - redundant_map.Set(GetRef(binding->var.get()), GetRef(val)); - } - - Expr VisitExpr_(const CallNode* call_node) final { - auto call = Downcast(ExprMutator::VisitExpr_(call_node)); - static const Op& call_tir_op = Op::Get("relax.call_tir"); - - Tuple args; - - if (call->op == call_tir_op) { - args = Downcast(call->args[1]); - } else { - args = Tuple(call->args); - } - Array new_args; - - for (auto& arg : args->fields) { - if (redundant_map.find(arg) != redundant_map.end()) { - new_args.push_back(redundant_map[arg]); - } else { - new_args.push_back(arg); - } - } - if (call->op == call_tir_op) { - return Call(call_tir_op, {call->args[0], Tuple(new_args)}, call->attrs, - {call->sinfo_args[0]}); - } else { - if (call->sinfo_args.size() > 0) { - return Call(call->op, new_args, call->attrs, {call->sinfo_args[0]}); - } else { - return Call(call->op, new_args, call->attrs); - } - } - } - - private: - Map redundant_map; - IRModule updates_; - IRModule mod_; -}; - class CollectProduserScopeInfo : public ExprVisitor { public: using ExprVisitor::VisitExpr_; @@ -619,30 +495,6 @@ class DefineVDevice : ExprMutator { namespace transform { -Pass OptimizeToDeviceForScopeChange() { - auto pass_func = [=](Function func, IRModule mod, PassContext pc) { - /* here Target doesn't matter as the scope_info we use only to find multiple consumers */ - auto info = CollectConsumerScopeInfo().Collect(mod, Downcast(func), Target("opencl")); - auto scope_info = info.second; - auto [pattern, rewriter] = CreatePatterns(scope_info); - return RewriteCall(pattern, rewriter, func); - }; - return CreateFunctionPass(pass_func, 1, "OptimizeToDeviceForScopeChange", {}); -} -TVM_REGISTER_GLOBAL("relax.transform.OptimizeToDeviceForScopeChange") - .set_body_typed(OptimizeToDeviceForScopeChange); - -Pass RemoveRedundantAssignments() { - runtime::TypedPackedFunc pass_func = - [=](IRModule mod, PassContext pc) { return relax::RemoveRedundantAssignments().Run(mod); }; - return CreateModulePass(/*pass_function=*/pass_func, - /*opt_level=*/0, - /*pass_name=*/"RemoveRedundantAssignments", - /*required=*/{}); -} -TVM_REGISTER_GLOBAL("relax.transform.RemoveRedundantAssignments") - .set_body_typed(RemoveRedundantAssignments); - Pass AnnotateCustomMemoryScope(Target target) { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { return relax::DefineVDevice(target).Run(mod); }; diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index d353ba1655f4..62e4b7c979f2 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -94,7 +94,7 @@ def verify(mod, expected): mod = tvm.relax.transform.FuseTIR()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) - mod = tvm.relax.transform.OptimizeToDeviceForScopeChange()(mod) + mod = tvm.relax.transform.OptimizeToVDeviceForScopeChange()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) mod = tvm.relax.transform.Normalize()(mod) From 6634b80269c5821720429533c13ec19e9fca38e9 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 26 Feb 2025 15:44:52 +0530 Subject: [PATCH 12/31] tests --- tests/python/relax/test_transform_annotate_custom_scope.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index 62e4b7c979f2..9b9b3f6926d2 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -84,16 +84,17 @@ def verify(mod, expected): mod = tvm.relax.transform.LegalizeOps(skip_ops=skip_ops)(mod) mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) mod = tvm.relax.transform.AnnotateCustomMemoryScope(tgt)(mod) - mod = tvm.relax.transform.LegalizeOps()(mod) # To handle any fallback ops + # There is apossibility of some skipped ops above might not use 5D layouts. + mod = tvm.relax.transform.LegalizeOps()(mod) mod = tvm.relax.transform.LegalizeOps( {"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo}, )(mod) + # Lets get pattern info for newly legalized ops mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) mod = tvm.relax.transform.FoldConstant()(mod) mod = tvm.relax.transform.FuseOps()(mod) mod = tvm.relax.transform.FuseTIR()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) - mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) mod = tvm.relax.transform.OptimizeToVDeviceForScopeChange()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) From e6c0275d06159c854c9b291642e7b838171734cf Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 26 Feb 2025 16:00:08 +0530 Subject: [PATCH 13/31] new passes --- .../optimize_to_vdevice_for_scope_change.cc | 185 ++++++++++++++++++ .../transform/remove_redundant_assignments.cc | 124 ++++++++++++ 2 files changed, 309 insertions(+) create mode 100644 src/relax/transform/optimize_to_vdevice_for_scope_change.cc create mode 100644 src/relax/transform/remove_redundant_assignments.cc diff --git a/src/relax/transform/optimize_to_vdevice_for_scope_change.cc b/src/relax/transform/optimize_to_vdevice_for_scope_change.cc new file mode 100644 index 000000000000..f2b4b1fd371d --- /dev/null +++ b/src/relax/transform/optimize_to_vdevice_for_scope_change.cc @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/optimize_to_vdevice_for_scope_change.cc + * \brief This is a texture specific pass that can optimize unnecessary to_device copies. + * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly + * store into global scope avoiding unnecessary device copy. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/tensor/manipulate.h" +#include "infer_layout_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +namespace { +std::tuple)>> CreatePatterns( + Map> consumers) { + auto pat_gv = WildcardPattern(); + + auto pat_inp = WildcardPattern(); + auto pat_call_tir = IsOp("relax.call_tir")(pat_gv, pat_inp); + auto pattern_out = IsOp("relax.to_vdevice")(pat_call_tir); + + auto rewriter = [=](Expr expr, Map matches) -> Expr { + const auto* call_tir = matches[pat_call_tir].as(); + ICHECK(call_tir) << "InternalError: " + << "Match of relax.call_tir operator should produce Call, " + << "but instead produces " << matches[pat_call_tir] << " with type " + << matches[pat_call_tir]->GetTypeKey(); + + const auto* out = matches[pattern_out].as(); + ICHECK(out) << "InternalError: " + << "Match of relax.to_vdevice operator should produce Call, " + << "but instead produces " << matches[pattern_out] << " with type " + << matches[pattern_out]->GetTypeKey(); + + const auto* vdev_attrs = out->attrs.as(); + ICHECK(vdev_attrs) << "InternalError: " + << "Attributes for relax.to_vdevice operator should be ToVDeviceAttrs, " + << "but were instead " << out->attrs << " with type " << out->GetTypeKey(); + + const auto* tir_out_sinfo = call_tir->sinfo_args[0].as(); + if (!tir_out_sinfo) return expr; + + if (!tir_out_sinfo->vdevice.defined()) return expr; + + const VarNode* arg_var = out->args[0].as(); + if (consumers.find(GetRef(arg_var)) != consumers.end()) { + if (consumers[GetRef(arg_var)].size() > 1) { + /* Don't do to_device optimization as we are not the only consumer */ + return expr; + } + } + + if ((std::string(tir_out_sinfo->vdevice.value()->memory_scope).find("texture") != + std::string::npos) && + (vdev_attrs->dst_vdevice->memory_scope == "global")) { + auto shape_arr = tir_out_sinfo->GetShape().value(); + auto new_sinfo = + TensorStructInfo(ShapeExpr(shape_arr), tir_out_sinfo->dtype, vdev_attrs->dst_vdevice); + + return Call(call_tir->op, call_tir->args, call_tir->attrs, {new_sinfo}); + } + return expr; + }; + + return {pattern_out, rewriter}; +} + +} // namespace + +class CollectConsumerDetails : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + Map> Collect( + const IRModule& mod, Function func, const Target& target) { + mod_ = mod; + target_ = target; + VisitExpr(func->body); + // Extend the consumer details for tuple items + for (const auto& val : arg_to_binding) { + if (consumers.find(val.first) != consumers.end()) { + if (consumers.find(val.second) == consumers.end()) { + consumers.Set(val.second, consumers[val.first]); + } else { + auto ent = consumers[val.second]; + for (auto ent_val : consumers[val.first]) { + ent.push_back(ent_val); + } + consumers.Set(val.second, ent); + } + } + } + return consumers; + } + + void VisitBinding_(const VarBindingNode* binding, + const TupleGetItemNode* tuple_get_item_node) final { + if (arg_to_binding.find(GetRef(binding->var.get())) == arg_to_binding.end()) { + arg_to_binding.Set(GetRef(binding->var.get()), + GetRef(tuple_get_item_node->tuple.get())); + } + } + + void VisitExpr_(const CallNode* call) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + Tuple func_args; + + if (call->op == call_tir_op) { + func_args = Downcast(call->args[1]); + } else { + func_args = Tuple(call->args); + } + + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto tensor_sinfo = sinfo.as()) { + Array call_list; + + const VarNode* arg_var = arg.as(); + + if (consumers.find(GetRef(arg_var)) != consumers.end()) { + call_list = consumers[GetRef(arg_var)]; + } + call_list.push_back(GetRef(call)); + consumers.Set(GetRef(arg_var), call_list); + } + } + } + + private: + + /* Map of each Var consumption by a call node */ + Map> consumers; + Map arg_to_binding; + IRModule mod_; + Target target_; +}; + +namespace transform { + +Pass OptimizeToVDeviceForScopeChange() { + auto pass_func = [=](Function func, IRModule mod, PassContext pc) { + /* here Target doesn't matter as the consumers we use only to find multiple consumers */ + auto consumers = CollectConsumerDetails().Collect(mod, Downcast(func), Target("opencl")); + auto [pattern, rewriter] = CreatePatterns(consumers); + return RewriteCall(pattern, rewriter, func); + }; + return CreateFunctionPass(pass_func, 1, "OptimizeToVDeviceForScopeChange", {}); +} +TVM_REGISTER_GLOBAL("relax.transform.OptimizeToVDeviceForScopeChange") + .set_body_typed(OptimizeToVDeviceForScopeChange); +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/remove_redundant_assignments.cc b/src/relax/transform/remove_redundant_assignments.cc new file mode 100644 index 000000000000..a33feba668e6 --- /dev/null +++ b/src/relax/transform/remove_redundant_assignments.cc @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/remove_redundant_assignments.cc + * \brief This pass removes redundant assignment statements. These stmts are result of other pass + * like hint_on_device processed by RealizeVDevice may leave them. The subsequent pass like + * fuse_ops fail to fuse in this case. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/tensor/manipulate.h" +#include "infer_layout_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +class RemoveRedundantAssignments : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + IRModule Run(IRModule& mod) { + mod_ = mod; + for (const auto& [gv, func] : mod_->functions) { + if (func->IsInstance()) { + const auto& base_func = mod_->Lookup(gv); + // Only non primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + relax::Function update_func = Downcast(VisitExpr(func)); + updates_->Add(gv, update_func); + } + } + mod_.CopyOnWrite()->Update(updates_); + return mod_; + } + + void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final { + redundant_map.Set(GetRef(binding->var.get()), GetRef(var)); + } + + void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) override { + redundant_map.Set(GetRef(binding->var.get()), GetRef(val)); + } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + + Tuple args; + + if (call->op == call_tir_op) { + args = Downcast(call->args[1]); + } else { + args = Tuple(call->args); + } + Array new_args; + + for (auto& arg : args->fields) { + if (redundant_map.find(arg) != redundant_map.end()) { + new_args.push_back(redundant_map[arg]); + } else { + new_args.push_back(arg); + } + } + if (call->op == call_tir_op) { + return Call(call_tir_op, {call->args[0], Tuple(new_args)}, call->attrs, + {call->sinfo_args[0]}); + } else { + if (call->sinfo_args.size() > 0) { + return Call(call->op, new_args, call->attrs, {call->sinfo_args[0]}); + } else { + return Call(call->op, new_args, call->attrs); + } + } + } + + private: + Map redundant_map; + IRModule updates_; + IRModule mod_; +}; + +namespace transform { + +Pass RemoveRedundantAssignments() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return relax::RemoveRedundantAssignments().Run(mod); }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"RemoveRedundantAssignments", + /*required=*/{}); +} +TVM_REGISTER_GLOBAL("relax.transform.RemoveRedundantAssignments") + .set_body_typed(RemoveRedundantAssignments); +} // namespace transform +} // namespace relax +} // namespace tvm From 44eea5eeb388d52340c841cc51b9689c2d7269e4 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 26 Feb 2025 16:10:27 +0530 Subject: [PATCH 14/31] Lint --- python/tvm/relax/transform/transform.py | 2 +- .../transform/annotate_custom_storage.cc | 223 ++++++++++-------- .../test_transform_annotate_custom_scope.py | 2 +- 3 files changed, 124 insertions(+), 103 deletions(-) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 71718e862211..a07501f567ec 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1651,7 +1651,7 @@ def OptimizeToVDeviceForScopeChange() -> tvm.ir.transform.Pass: def RemoveRedundantAssignments() -> tvm.ir.transform.Pass: - """ This pass removes redundant assignment statements. These stmts are result of other pass + """This pass removes redundant assignment statements. These stmts are result of other pass like hint_on_device processed by RealizeVDevice may leave them. The subsequent pass like fuse_ops fail to fuse in this case diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc index 2d6d5777daea..3a36962f91e0 100644 --- a/src/relax/transform/annotate_custom_storage.cc +++ b/src/relax/transform/annotate_custom_storage.cc @@ -18,7 +18,9 @@ */ /*! * \file src/relax/transform/annotate_texture_storage.cc - * \brief Texture Storage Annotation Pass. + * \brief Texture Storage Annotation Pass for Adreno GPU targets. + * + * Texture scope annotation and realization for Adreno GPU targets goes by */ #include @@ -47,105 +49,12 @@ static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tens return shape.value(); } -class CollectProduserScopeInfo : public ExprVisitor { - public: - using ExprVisitor::VisitExpr_; - - Map Collect(const IRModule& mod, Function func, - const Map>>& scope_info, - const Target& target, const BlockBuilder& builder) { - mod_ = mod; - scope_info_ = scope_info; - target_ = target; - builder_ = builder; - VisitExpr(func->body); - - return producer_sinfo; - } - - void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { - ExprVisitor::VisitBinding_(binding, call); - - static const Op& call_tir_op = Op::Get("relax.call_tir"); - StructInfo out_sinfo; - - if (call->op == call_tir_op) { - out_sinfo = call->sinfo_args[0]; - } else { - tvm::OpAttrMap op_map_infer_struct_info_ = - Op::GetAttrMap("FInferStructInfo"); - - auto* op_ptr = call->op.as(); - Op op = GetRef(op_ptr); - ICHECK(op_map_infer_struct_info_.count(op)) - << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; - out_sinfo = op_map_infer_struct_info_[op](GetRef(call), builder_); - } - - std::unordered_map scope_count; - - // Decide the final scope based on the max consumer demand. Rest will use to_device. - auto arg_var = binding->var.as(); - if (scope_info_.find(GetRef(arg_var)) != scope_info_.end()) { - for (const auto& val : scope_info_[GetRef(arg_var)]) { - auto call_node = Downcast(val.first); - if (scope_count.find(val.second[0]) == scope_count.end()) { - scope_count.insert({val.second[0], 1}); - } else { - auto curr_count = scope_count[val.second[0]]; - scope_count.emplace(val.second[0], curr_count + 1); - } - } - } - String final_scope = "global"; - int count = 0; - for (const auto& sval : scope_count) { - if (sval.second > count) { - final_scope = sval.first; - count = sval.second; - } - } - // Applying same scope for outputs - StructInfo updated_ret_sinfo = UpdateStructInfo(out_sinfo, {final_scope}); - producer_sinfo.Set(GetRef(call), updated_ret_sinfo); - } - - private: - StructInfo UpdateStructInfo(const StructInfo& out_sinfo, Array scope) { - if (out_sinfo->IsInstance()) { - auto tensor_sinfo = Downcast(out_sinfo); - auto shape_arr = GetShapeFromTensorStructInfo(tensor_sinfo); - return TensorStructInfo(ShapeExpr(shape_arr), tensor_sinfo->dtype, - VDevice(target_, 0, scope[0])); - } - - ICHECK(out_sinfo->IsInstance()) - << "Expect output struct info of call_tir to be either TupleStructInfo or " - "TensorStructInfo, but got " - << out_sinfo; - - const auto& tuple_sinfo = Downcast(out_sinfo); - Array sinfo_fields; - for (const auto& si : tuple_sinfo->fields) { - ICHECK(si->IsInstance()) - << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " - "output structinfo, but got " - << si; - auto sinfo = Downcast(si); - auto shape_arr = GetShapeFromTensorStructInfo(sinfo); - sinfo_fields.push_back( - TensorStructInfo(ShapeExpr(shape_arr), sinfo->dtype, VDevice(target_, 0, scope[0]))); - } - return TupleStructInfo(sinfo_fields); - } - - Map>> scope_info_; - Map producer_sinfo; - IRModule mod_; - Target target_; - BlockBuilder builder_; -}; - +/* + * \brief generates consumer information for each var + * \return scope_info is a map which contain for each var the corresponding call nodes that + * consume it and corresponding scope it expects this input to be. + * \return call_scope_info is a map of each call_node and array holding scope infor for each input. + */ class CollectConsumerScopeInfo : public ExprVisitor { public: using ExprVisitor::VisitExpr_; @@ -305,13 +214,125 @@ class CollectConsumerScopeInfo : public ExprVisitor { /* Map of each Var consumption by a call node and its scope */ Map>> scope_info; - /* A map of call node and cope infor for each argument it consunes */ + /* A map of call node and scope info for each argument it consunes */ Map> call_scope_info; Map arg_to_binding; IRModule mod_; Target target_; }; +/* + * \brief producer scope information consolidated based on consumer demands. + * \return producer_info which is a map of each call node and corresponding out StructInfo + * This pass considers all consumers and their scope demand. + * Any mismatches here introduces copies as needed. + */ +class CollectProduserScopeInfo : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + Map Collect(const IRModule& mod, Function func, + const Map>>& scope_info, + const Target& target, const BlockBuilder& builder) { + mod_ = mod; + scope_info_ = scope_info; + target_ = target; + builder_ = builder; + VisitExpr(func->body); + + return producer_sinfo; + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { + ExprVisitor::VisitBinding_(binding, call); + + static const Op& call_tir_op = Op::Get("relax.call_tir"); + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + out_sinfo = call->sinfo_args[0]; + } else { + tvm::OpAttrMap op_map_infer_struct_info_ = + Op::GetAttrMap("FInferStructInfo"); + + auto* op_ptr = call->op.as(); + Op op = GetRef(op_ptr); + ICHECK(op_map_infer_struct_info_.count(op)) + << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; + out_sinfo = op_map_infer_struct_info_[op](GetRef(call), builder_); + } + + std::unordered_map scope_count; + + // Decide the final scope based on the max consumer demand. Rest will use to_device. + auto arg_var = binding->var.as(); + if (scope_info_.find(GetRef(arg_var)) != scope_info_.end()) { + for (const auto& val : scope_info_[GetRef(arg_var)]) { + auto call_node = Downcast(val.first); + if (scope_count.find(val.second[0]) == scope_count.end()) { + scope_count.insert({val.second[0], 1}); + } else { + auto curr_count = scope_count[val.second[0]]; + scope_count.emplace(val.second[0], curr_count + 1); + } + } + } + String final_scope = "global"; + int count = 0; + for (const auto& sval : scope_count) { + if (sval.second > count) { + final_scope = sval.first; + count = sval.second; + } + } + // Applying same scope for outputs + StructInfo updated_ret_sinfo = UpdateStructInfo(out_sinfo, {final_scope}); + producer_sinfo.Set(GetRef(call), updated_ret_sinfo); + } + + private: + StructInfo UpdateStructInfo(const StructInfo& out_sinfo, Array scope) { + if (out_sinfo->IsInstance()) { + auto tensor_sinfo = Downcast(out_sinfo); + auto shape_arr = GetShapeFromTensorStructInfo(tensor_sinfo); + return TensorStructInfo(ShapeExpr(shape_arr), tensor_sinfo->dtype, + VDevice(target_, 0, scope[0])); + } + + ICHECK(out_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << out_sinfo; + + const auto& tuple_sinfo = Downcast(out_sinfo); + Array sinfo_fields; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + auto shape_arr = GetShapeFromTensorStructInfo(sinfo); + sinfo_fields.push_back( + TensorStructInfo(ShapeExpr(shape_arr), sinfo->dtype, VDevice(target_, 0, scope[0]))); + } + return TupleStructInfo(sinfo_fields); + } + + Map>> scope_info_; + Map producer_sinfo; + IRModule mod_; + Target target_; + BlockBuilder builder_; +}; + +/* + * \brief main pass that injects hint_on_device for each argument based on producer, + * consumer indormations. This also attributes ret StructInfo for each call node. + * This pass also calls the ReliaseVdevice that formalizes the hints by appropriately injecting + * Vdevice copies as needed. + */ + class DefineVDevice : ExprMutator { public: explicit DefineVDevice(const Target& target) : target_(target) {} diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index 9b9b3f6926d2..c1856cbb4fd3 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -84,7 +84,7 @@ def verify(mod, expected): mod = tvm.relax.transform.LegalizeOps(skip_ops=skip_ops)(mod) mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) mod = tvm.relax.transform.AnnotateCustomMemoryScope(tgt)(mod) - # There is apossibility of some skipped ops above might not use 5D layouts. + # There is a possibility of some skipped ops above might not use 5D layouts. mod = tvm.relax.transform.LegalizeOps()(mod) mod = tvm.relax.transform.LegalizeOps( {"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo}, From b9a228ac8bcd0ea0a9f09f949b963da173714290 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 26 Feb 2025 19:39:28 +0530 Subject: [PATCH 15/31] Lint --- python/tvm/tir/analysis/analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 44b7e4b8b18d..cc23bb939588 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Union import tvm -from tvm import Object, _ffi +from tvm import _ffi from tvm.ir import IRModule from tvm.tir.expr import Var from tvm.tir.stmt import Block, BufferRegion, PrimExpr From a518972814a6be628f22b4f6ae5d726ec5ba4265 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 26 Feb 2025 21:02:01 +0530 Subject: [PATCH 16/31] Update only VDevice. --- src/relax/transform/legalize_ops.cc | 53 ++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 05265a9aaa72..9083e74a5f7b 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -36,6 +36,12 @@ namespace relax { TVM_REGISTER_PASS_CONFIG_OPTION("relax.transform.apply_legalize_ops", Bool); +static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { + auto shape = tensor_sinfo->GetShape(); + ICHECK(shape.defined()); + return shape.value(); +} + /*! * \brief Check if a given Tensor/Shape/TupleStructInfo contains shapes whose * values are all known. @@ -156,7 +162,7 @@ class LegalizeMutator : public ExprMutator { return std::nullopt; } - Expr UpdateOutStructInfo(Expr expr, Call& visited_call) { + Expr UpdateVDeviceOutStructInfo(Expr expr, Call& visited_call) { static const auto& infer_struct_info_map = Op::GetAttrMap("FInferStructInfo"); static const Op& call_tir_op = Op::Get("relax.call_tir"); auto* op_node = visited_call->op.as(); @@ -180,8 +186,45 @@ class LegalizeMutator : public ExprMutator { return expr; } - StructInfo updated_ret_sinfo = infer_struct_info_map[op](visited_call, builder_); - return Call(call_tir_op, call->args, call->attrs, {updated_ret_sinfo}); + StructInfo out_sinfo = call->sinfo_args[0]; + StructInfo infered_sinfo = infer_struct_info_map[op](visited_call, builder_); + + if (out_sinfo->IsInstance()) { + auto out_tsinfo = Downcast(out_sinfo); + auto infered_tsinfo = Downcast(infered_sinfo); + auto shape_arr = GetShapeFromTensorStructInfo(out_tsinfo); + if (infered_tsinfo->vdevice.defined()) { + out_sinfo = TensorStructInfo(ShapeExpr(shape_arr), out_tsinfo->dtype, + infered_tsinfo->vdevice.value()); + } + } else if (out_sinfo->IsInstance()) { + const auto& tuple_sinfo = Downcast(out_sinfo); + const auto& infered_tuple_sinfo = Downcast(infered_sinfo); + Array sinfo_fields; + int index = 0; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto tsinfo = Downcast(si); + auto shape_arr = GetShapeFromTensorStructInfo(tsinfo); + auto infered_tsinfo = Downcast(infered_tuple_sinfo->fields[index]); + if (infered_tsinfo->vdevice.defined()) { + sinfo_fields.push_back(TensorStructInfo(ShapeExpr(shape_arr), tsinfo->dtype, + infered_tsinfo->vdevice.value())); + } else { + sinfo_fields.push_back(tsinfo); + } + ++index; + } + out_sinfo = TupleStructInfo(sinfo_fields); + } + + if (out_sinfo->IsInstance()) { + LOG(WARNING) << "New Struct Info:" << Downcast(out_sinfo); + } + return Call(call_tir_op, call->args, call->attrs, {out_sinfo}); } Expr BindTarget(Expr expr) { @@ -270,7 +313,7 @@ class LegalizeMutator : public ExprMutator { auto op = ffi::GetRef(op_node); if (skip_ops_.defined()) { - for (const auto name: skip_ops_.value()) { + for (const auto name : skip_ops_.value()) { if (name == op->name) { return visited_call; } @@ -382,7 +425,7 @@ class LegalizeMutator : public ExprMutator { } Expr legalized = legalization_func(builder_, visited_call); - legalized = UpdateOutStructInfo(legalized, visited_call); + legalized = UpdateVDeviceOutStructInfo(legalized, visited_call); // Append the target attribute to any PrimFunc generated in // legalization. From 709b50817bf84f71e44d4c62af921747138d1e91 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 26 Feb 2025 22:28:38 +0530 Subject: [PATCH 17/31] Docs. --- include/tvm/relax/transform.h | 2 +- .../transform/annotate_custom_storage.cc | 221 +++++++++++++++++- 2 files changed, 219 insertions(+), 4 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index d6fe526ea528..a2ea717553aa 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -244,9 +244,9 @@ TVM_DLL Pass FoldConstant(); * * \param cmap The customized operator legalization function map. The customized function * will override the default one. + * \param skip_ops The list operator names which need to be skipped from legalization * \param enable_warning A boolean value indicating if to print warnings for TIR functions not * showing up in the database. - * \param add_attributes A boolean value indicating adding of call attributes to TIR functions * \return The Pass. */ TVM_DLL Pass LegalizeOps(ffi::Optional> cmap, ffi::Optional> skip_ops, diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc index 3a36962f91e0..8f086a1f8904 100644 --- a/src/relax/transform/annotate_custom_storage.cc +++ b/src/relax/transform/annotate_custom_storage.cc @@ -20,7 +20,220 @@ * \file src/relax/transform/annotate_texture_storage.cc * \brief Texture Storage Annotation Pass for Adreno GPU targets. * - * Texture scope annotation and realization for Adreno GPU targets goes by + * Texture realization for Adreno GPU targets requires fundamentally follows + * Stage 1: Transforming the shapes with inner most dimension being 4 + * Stage 2: Annotate appropriate memory_scope hint in VDevice of StructInfo + * Stage 3: TIR lowering does injects texture load/store builtins looking at this scope + * Stage 4: Finally codegen handles appropriate code looking at buffer types and load/store + * builtins. + * + * Stage 1 is generic and straight forward by using convert_layout pass that transforms the + * shapes as well as injecting layout_transform ops as needed. + * + * Stage 2 This pass is responsible for injeting appropriate VDevice into StructInfo and + * adding any copies if there is a conflict between producer and consuner scopes. + * + * After convert_layout the mod looks like below + * @I.ir_module + * class Module: + * @R.function + * def main( + * x: R.Tensor((2, 64, 56, 56), dtype="float32"), + * w: R.Tensor((32, 64, 3, 3), dtype="float32") + * ) -> R.Tensor((2, 32, 54, 54), dtype="float32"): + * with R.dataflow(): + * lv: R.Tensor((2, 16, 56, 56, 4), dtype="float32") = R.layout_transform( + * x, + * index_map=T.index_map( + * lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4))) + * lv1: R.Tensor((8, 64, 3, 3, 4), dtype="float32") = R.layout_transform( + * w, + * index_map=T.index_map( + * lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4))) + * lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d( + * lv, + * lv1, + * data_layout="NCHW4c", + * kernel_layout="OIHW4o", + * out_layout="NCHW4c", + * out_dtype="float32" + * ) + * gv: R.Tensor((2, 32, 54, 54), dtype="float32") = R.layout_transform( + * lv2, + * index_map=T.index_map( + * lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3))) + * R.output(gv) + * return gv + * + * Here, the param layout transforms are injected properly and the conv2d op is operating + * in 5D shapes. + * + * Now, the scope annotation decisions are done by + * - For op_pattern < kCommReduce we just look for shape being 5D and inner dimsion = 4 + * - For op_pattern > kCommReduce we make decisions selectively. Currently we do enable texture + * scope for Conv2D, PoolOps. + * The trick here is whiel this pass is in action we need op_pattern information for ops that are + * below kCommReduce as well op attrbuted for seletive ops like Conv2D and PoolOps. + * op_pattern is available after legalization and TIROpPattern pass does an analysis. However, + * op specific attributes doesn't exist after legalization. + * + * To solve this issue, we go legalization in parts. + * At first, we call legalization by skipping the list of ops we wanted not to legalize. + * LigalizeOps is enhanced to accept skip_ops for this purpose. + * After legalization and AnnotateTIROpPattern this way the mod liiks like + * + * class Module: + * @R.function + * def main( + * x: R.Tensor((2, 64, 56, 56), dtype="float32"), + * w: R.Tensor((32, 64, 3, 3), dtype="float32") + * ) -> R.Tensor((2, 32, 54, 54), dtype="float32"): + * with R.dataflow(): + * lv = R.call_tir(cls.te_layout_transform, (x,), + * out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32") + * ) + * lv1 = R.call_tir(cls.te_layout_transform1, (w,), + * out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32") + * ) + * lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d( + * lv, + * lv1, + * data_layout="NCHW4c", + * kernel_layout="OIHW4o", + * out_layout="NCHW4c", + * out_dtype="float32" + * ) + * gv = R.call_tir(cls.te_layout_transform2, (lv2,), + * out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32") + * ) + * R.output(gv) + * return gv + * + * Here, the legalized prim functions does have op_pattern attribute. + * We now have what we wanted to run this pass. + * + * This pass in principle does scope annotation based on sonsumer priotiry. i.e. + * For any tensor object we tries to assign scope based on the sonsuner requirement. + * The conflicts and multiple consumers for same tensor are handled by injecting + * appropriate copies. + * 1: CollectConsumerScopeInfo: Visitor collects all consumer demand for each input + * 2: CollectProducerScopeInfo: Visitor does finalizes the scope for each input and output based + * on consumer scope information. It does evaluating mutiple consumer cases and conflicts. + * 3: DefineVDevice: Pass does injects hint_on_device for each argument. It also tries to update + * out StructInfo containing VDevice information. This update for tir calls is straight forward + * as sinfo_args in CallNode is meant for this purpose. This sinfo_args for other calls by + * design is invalid as we do this by "FInferStructInfo". + * Another issue we have with "FInferStructInfo" per op is they can't decide this + * memory scope information which is done by this pass based on consumer demand. + * Hence, we are going to use the sinfo_args to indicate this information. + * So, this pass attributes sinfo_args for regumar calls too and FInferStructInfo implmentation + * do take VDevice information fro this hint. This also solves the issue of mixed VDevice + * for arguments of an op. + * After these steps the mod looks like + * + * class Module: + * @R.function + * def main( + * x: R.Tensor((2, 64, 56, 56), dtype="float32"), + * w: R.Tensor((32, 64, 3, 3), dtype="float32") + * ) -> R.Tensor((2, 32, 54, 54), dtype="float32"): + * with R.dataflow(): + * lv: R.Tensor((2, 64, 56, 56), dtype="float32") = R.hint_on_device( + * x, R.device(dev_type=4, dev_id=0), "global" + * ) + * lv_1 = R.call_tir(cls.te_layout_transform, (lv,), + * out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32", + * vdevice="opencl:0:global.texture-nhwc" + * ) + * ) + * lv1: R.Tensor((32, 64, 3, 3), dtype="float32") = R.hint_on_device( + * w, R.device(dev_type=4, dev_id=0), "global" + * ) + * lv1_1 = R.call_tir(cls.te_layout_transform1, (lv1,), + * out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32", + * vdevice="opencl:2:global.texture-weight" + * ) + * ) + * lv2: R.Tensor((2, 16, 56, 56, 4), dtype="float32", + * vdevice="opencl:0:global.texture-nhwc" + * ) = R.hint_on_device(lv_1, R.device(dev_type=4, dev_id=0), "global.texture-nhwc") + * lv3: R.Tensor((8, 64, 3, 3, 4), dtype="float32", + * vdevice="opencl:2:global.texture-weight" + * ) = R.hint_on_device(lv1_1, R.device(dev_type=4, dev_id=0), "global.texture-weight") + * lv2_1: R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * vdevice="opencl:1:global" + & ) = R.nn.conv2d( + * lv2, lv3, + * data_layout="NCHW4c", kernel_layout="OIHW4o", + * out_layout="NCHW4c", out_dtype="float32", + * sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * vdevice="opencl:1:global"), + * ) + * ) + * lv4: R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * vdevice="opencl:1:global" + * ) = R.hint_on_device(lv2_1, R.device(dev_type=4, dev_id=0), "global") + * gv = R.call_tir(cls.te_layout_transform2, (lv4,), + * out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global") + * ) + * R.output(gv) + * return gv + * + * What we have above is hint_on_device injections and out_sinfo for all calls. + * Now, we apply RealizeVDevice to formalize the hints. Follwed by we also call + * RemoveRedundantAssignments that removes redundant assignments like + * + * lv: R.Tensor((2, 64, 56, 56), dtype="float32", vdevice="opencl:1:global") = x + * lv1: R.Tensor((32, 64, 3, 3), dtype="float32", vdevice="opencl:1:global") = w + * + * These assignments are result of hint_on_device not realizing any copy while consumer and + * producer has same memory scope or vdevice. These assignments do impact operator fusion. + * + * Now the mod looks like, + * + * class Module: + * @R.function + * def main( + * x: R.Tensor((2, 64, 56, 56), dtype="float32"), + * w: R.Tensor((32, 64, 3, 3), dtype="float32") + * ) -> R.Tensor((2, 32, 54, 54), dtype="float32"): + * with R.dataflow(): + * lv = R.call_tir(cls.te_layout_transform, (x,), + * out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32", + * vdevice="opencl:0:global.texture-nhwc" + * ) + * ) + * lv1 = R.call_tir(cls.te_layout_transform1, (w,), + * out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32", + * vdevice="opencl:2:global.texture-weight" + * ) + * ) + * lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * vdevice="opencl:1:global" + * ) = R.nn.conv2d( + * lv2, lv3, + * data_layout="NCHW4c", kernel_layout="OIHW4o", + * out_layout="NCHW4c", out_dtype="float32", + * sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32", + * vdevice="opencl:1:global"), + * ) + * ) + * gv = R.call_tir(cls.te_layout_transform2, (lv4,), + * out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global") + * ) + * R.output(gv) + * return gv + * + * Followed by, the compilation pipeline calls + * - legalization of the remainng ops: This legalization do forwards the annotated out_sinfo + * VDevice information to tir_calls + * - AnnotateTIROpPattern : TIROp Patterns for newly legalizes ops + * - Fusion + * - OptimizeToVDeviceForScopeChange: There existed some ToVDevice copies from texture to buffer + * This pass removes the copes and updates producer scope to global. + * - SpecializePrimFuncBasedOnCallSite: Finally we updates the Buffer Var maps according to + * VDevice scopes. + * */ #include @@ -227,7 +440,7 @@ class CollectConsumerScopeInfo : public ExprVisitor { * This pass considers all consumers and their scope demand. * Any mismatches here introduces copies as needed. */ -class CollectProduserScopeInfo : public ExprVisitor { +class CollectProducerScopeInfo : public ExprVisitor { public: using ExprVisitor::VisitExpr_; @@ -349,7 +562,7 @@ class DefineVDevice : ExprMutator { auto info = CollectConsumerScopeInfo().Collect(mod_, Downcast(func), target_); call_scope_info_ = info.first; scope_info_ = info.second; - producer_sinfo_ = CollectProduserScopeInfo().Collect(mod_, Downcast(func), + producer_sinfo_ = CollectProducerScopeInfo().Collect(mod_, Downcast(func), scope_info_, target_, builder_); relax::Function update_func = Downcast(VisitExpr(func)); updates_->Add(gv, update_func); @@ -364,7 +577,9 @@ class DefineVDevice : ExprMutator { mod_.CopyOnWrite()->global_infos.Set("vdevice", global_vdevices_); mod_ = relax::transform::DeadCodeElimination()(mod_); + LOG(WARNING) << "RealizeVDevice:" << mod_; mod_ = relax::transform::RealizeVDevice()(mod_); + LOG(WARNING) << "RealizeVDevice after:" << mod_; mod_ = relax::transform::RemoveRedundantAssignments()(mod_); return mod_; From 7161141fa9565475b8b7a990bb90654bf6b296dc Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 26 Feb 2025 23:07:08 +0530 Subject: [PATCH 18/31] Test case rollback and fixes. --- src/relax/transform/legalize_ops.cc | 2 +- .../transform/optimize_to_vdevice_for_scope_change.cc | 7 +++---- src/relax/transform/remove_redundant_assignments.cc | 2 +- tests/python/relax/test_transform_fuse_tir.py | 4 ++-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 9083e74a5f7b..af5f9c2ca8fc 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -162,7 +162,7 @@ class LegalizeMutator : public ExprMutator { return std::nullopt; } - Expr UpdateVDeviceOutStructInfo(Expr expr, Call& visited_call) { + Expr UpdateVDeviceOutStructInfo(Expr expr, const Call& visited_call) { static const auto& infer_struct_info_map = Op::GetAttrMap("FInferStructInfo"); static const Op& call_tir_op = Op::Get("relax.call_tir"); auto* op_node = visited_call->op.as(); diff --git a/src/relax/transform/optimize_to_vdevice_for_scope_change.cc b/src/relax/transform/optimize_to_vdevice_for_scope_change.cc index f2b4b1fd371d..68da901aea52 100644 --- a/src/relax/transform/optimize_to_vdevice_for_scope_change.cc +++ b/src/relax/transform/optimize_to_vdevice_for_scope_change.cc @@ -102,8 +102,7 @@ class CollectConsumerDetails : public ExprVisitor { public: using ExprVisitor::VisitExpr_; - Map> Collect( - const IRModule& mod, Function func, const Target& target) { + Map> Collect(const IRModule& mod, Function func, const Target& target) { mod_ = mod; target_ = target; VisitExpr(func->body); @@ -159,7 +158,6 @@ class CollectConsumerDetails : public ExprVisitor { } private: - /* Map of each Var consumption by a call node */ Map> consumers; Map arg_to_binding; @@ -172,7 +170,8 @@ namespace transform { Pass OptimizeToVDeviceForScopeChange() { auto pass_func = [=](Function func, IRModule mod, PassContext pc) { /* here Target doesn't matter as the consumers we use only to find multiple consumers */ - auto consumers = CollectConsumerDetails().Collect(mod, Downcast(func), Target("opencl")); + auto consumers = + CollectConsumerDetails().Collect(mod, Downcast(func), Target("opencl")); auto [pattern, rewriter] = CreatePatterns(consumers); return RewriteCall(pattern, rewriter, func); }; diff --git a/src/relax/transform/remove_redundant_assignments.cc b/src/relax/transform/remove_redundant_assignments.cc index a33feba668e6..ebf03db4bc6d 100644 --- a/src/relax/transform/remove_redundant_assignments.cc +++ b/src/relax/transform/remove_redundant_assignments.cc @@ -44,7 +44,7 @@ class RemoveRedundantAssignments : public ExprMutator { public: using ExprMutator::VisitExpr_; - IRModule Run(IRModule& mod) { + IRModule Run(const IRModule& mod) { mod_ = mod; for (const auto& [gv, func] : mod_->functions) { if (func->IsInstance()) { diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 8b93e22e0752..8e583b3dd4cc 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -1119,7 +1119,7 @@ def fused_concatenate_transpose2( (T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32" ), ): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) T_concat_handle_intermediate = T.alloc_buffer( (T.int64(2), T.int64(4), T.int64(64), T.int64(64)) ) @@ -1309,7 +1309,7 @@ def fused_reshape( (T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32" ), ): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) + T.func_attr({"tir.noalias": True}) # with T.block("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)): with T.block("T_reshape"): From 55c1ae13a17988e1ebb88595cdb372d437213349 Mon Sep 17 00:00:00 2001 From: Siva Date: Thu, 27 Feb 2025 08:20:10 +0530 Subject: [PATCH 19/31] remove log prints --- src/relax/transform/annotate_custom_storage.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc index 8f086a1f8904..750c80db9fe9 100644 --- a/src/relax/transform/annotate_custom_storage.cc +++ b/src/relax/transform/annotate_custom_storage.cc @@ -577,9 +577,7 @@ class DefineVDevice : ExprMutator { mod_.CopyOnWrite()->global_infos.Set("vdevice", global_vdevices_); mod_ = relax::transform::DeadCodeElimination()(mod_); - LOG(WARNING) << "RealizeVDevice:" << mod_; mod_ = relax::transform::RealizeVDevice()(mod_); - LOG(WARNING) << "RealizeVDevice after:" << mod_; mod_ = relax::transform::RemoveRedundantAssignments()(mod_); return mod_; From 45bf9033d36fb1fde4c6d5acc3fad9f63624d9ac Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 25 Mar 2025 22:34:51 +0530 Subject: [PATCH 20/31] review comments --- CMakeLists.txt | 1 + include/tvm/relax/expr.h | 6 +- include/tvm/relax/transform.h | 9 +- python/tvm/relax/transform/__init__.py | 3 +- python/tvm/relax/transform/transform.py | 17 +- .../adreno}/annotate_custom_storage.cc | 14 +- .../adreno/fold_vdevice_scope_change.cc} | 16 +- src/relax/op/op_common.h | 5 + src/relax/transform/legalize_ops.cc | 25 +- .../transform/remove_redundant_assignments.cc | 124 ------- .../test_transform_annotate_custom_scope.py | 3 +- ...est_transform_fold_vdevice_scope_change.py | 282 ++++++++++++++ ...m_specialize_primfunc_based_on_callsite.py | 344 ++++++++++++++++++ 13 files changed, 668 insertions(+), 181 deletions(-) rename src/relax/{transform => backend/adreno}/annotate_custom_storage.cc (98%) rename src/relax/{transform/optimize_to_vdevice_for_scope_change.cc => backend/adreno/fold_vdevice_scope_change.cc} (93%) delete mode 100644 src/relax/transform/remove_redundant_assignments.cc create mode 100644 tests/python/relax/test_transform_fold_vdevice_scope_change.py create mode 100644 tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 6713a7cbb5c7..4b9112e265f2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,6 +307,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/relax/analysis/*.cc src/relax/transform/*.cc src/relax/backend/vm/*.cc + src/relax/backend/adreno/*.cc src/relax/backend/task_extraction.cc src/relax/backend/pattern_registry.cc src/relax/utils.cc diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index d746de9c1672..9b5a3176f413 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -156,9 +156,13 @@ class CallNode : public ExprNode { /*! * \brief The structure info arguments of a CallNode. - * sinfo_args is designed to be non-empty only for intrinsic op (e.g., + * sinfo_args is by default designed to be non-empty only for intrinsic op (e.g., * call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main * usage of structure info inference. + * + * Regular ops also at times may have sinfo_args defined to specialize partial + * or complete structure info. Like VDevice customization with mixed input memory_scopes. + * The customized pass can set this info and operator specific inference will respect it. */ ffi::Array sinfo_args; diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index a2ea717553aa..e4b1118928e2 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -698,19 +698,12 @@ TVM_DLL Pass AnnotateCustomMemoryScope(Target target); */ TVM_DLL Pass SpecializePrimFuncBasedOnCallSite(); -/*! - * \brief This pass removes redundant assignment statements. These stmts are result of other pass - * like hint_on_device processed by RealizeVDevice may leave them. The subsequent pass like - * fuse_ops fail to fuse in this case. - */ -TVM_DLL Pass RemoveRedundantAssignments(); - /* * \brief This is a texture specific pass that can optimize unnecessary to_device copies. * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly * store into global scope avoiding unnecessary device copy. */ -TVM_DLL Pass OptimizeToVDeviceForScopeChange(); +TVM_DLL Pass FoldVDeviceScopeChange(); } // namespace transform } // namespace relax diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index f256dda4b176..30a3590dcb53 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -85,8 +85,7 @@ VMShapeLower, AnnotateCustomMemoryScope, SpecializePrimFuncBasedOnCallSite, - OptimizeToVDeviceForScopeChange, - RemoveRedundantAssignments, + FoldVDeviceScopeChange, dataflowblock_pass, function_pass, ) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index a07501f567ec..bebe0da93f31 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1637,7 +1637,7 @@ def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass: return _ffi_api.SpecializePrimFuncBasedOnCallSite() # type: ignore -def OptimizeToVDeviceForScopeChange() -> tvm.ir.transform.Pass: +def FoldVDeviceScopeChange() -> tvm.ir.transform.Pass: """This pass is a texture specific pass that can optimize unnecessary to_device copies. Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly store into global scope avoiding unnecessary device copy. @@ -1647,20 +1647,7 @@ def OptimizeToVDeviceForScopeChange() -> tvm.ir.transform.Pass: ret: tvm.ir.transform.Pass The registered pass for allocating workspace. """ - return _ffi_api.OptimizeToVDeviceForScopeChange() # type: ignore - - -def RemoveRedundantAssignments() -> tvm.ir.transform.Pass: - """This pass removes redundant assignment statements. These stmts are result of other pass - like hint_on_device processed by RealizeVDevice may leave them. The subsequent pass like - fuse_ops fail to fuse in this case - - Returns - ------- - ret: tvm.ir.transform.Pass - The registered pass for allocating workspace. - """ - return _ffi_api.RemoveRedundantAssignments() # type: ignore + return _ffi_api.FoldVDeviceScopeChange() # type: ignore def _wrap_class_function_pass(pass_cls, pass_info): diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc similarity index 98% rename from src/relax/transform/annotate_custom_storage.cc rename to src/relax/backend/adreno/annotate_custom_storage.cc index 750c80db9fe9..3b2a1fac4285 100644 --- a/src/relax/transform/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file src/relax/transform/annotate_texture_storage.cc + * \file src/relax/backend/adreno/annotate_texture_storage.cc * \brief Texture Storage Annotation Pass for Adreno GPU targets. * * Texture realization for Adreno GPU targets requires fundamentally follows @@ -181,7 +181,7 @@ * * What we have above is hint_on_device injections and out_sinfo for all calls. * Now, we apply RealizeVDevice to formalize the hints. Follwed by we also call - * RemoveRedundantAssignments that removes redundant assignments like + * CanonicalizeBindings that removes redundant assignments like * * lv: R.Tensor((2, 64, 56, 56), dtype="float32", vdevice="opencl:1:global") = x * lv1: R.Tensor((32, 64, 3, 3), dtype="float32", vdevice="opencl:1:global") = w @@ -229,7 +229,7 @@ * VDevice information to tir_calls * - AnnotateTIROpPattern : TIROp Patterns for newly legalizes ops * - Fusion - * - OptimizeToVDeviceForScopeChange: There existed some ToVDevice copies from texture to buffer + * - FoldVDeviceScopeChange: There existed some ToVDevice copies from texture to buffer * This pass removes the copes and updates producer scope to global. * - SpecializePrimFuncBasedOnCallSite: Finally we updates the Buffer Var maps according to * VDevice scopes. @@ -247,9 +247,9 @@ #include -#include "../op/tensor/manipulate.h" -#include "infer_layout_utils.h" -#include "utils.h" +#include "../../op/tensor/manipulate.h" +#include "../../transform/infer_layout_utils.h" +#include "../../transform/utils.h" namespace tvm { namespace relax { @@ -578,7 +578,7 @@ class DefineVDevice : ExprMutator { mod_ = relax::transform::DeadCodeElimination()(mod_); mod_ = relax::transform::RealizeVDevice()(mod_); - mod_ = relax::transform::RemoveRedundantAssignments()(mod_); + mod_ = relax::transform::CanonicalizeBindings()(mod_); return mod_; } diff --git a/src/relax/transform/optimize_to_vdevice_for_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc similarity index 93% rename from src/relax/transform/optimize_to_vdevice_for_scope_change.cc rename to src/relax/backend/adreno/fold_vdevice_scope_change.cc index 68da901aea52..1ad3aafb2df9 100644 --- a/src/relax/transform/optimize_to_vdevice_for_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file src/relax/transform/optimize_to_vdevice_for_scope_change.cc + * \file src/relax/backend/adreno/optimize_to_vdevice_for_scope_change.cc * \brief This is a texture specific pass that can optimize unnecessary to_device copies. * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly * store into global scope avoiding unnecessary device copy. @@ -34,9 +34,9 @@ #include -#include "../op/tensor/manipulate.h" -#include "infer_layout_utils.h" -#include "utils.h" +#include "../../op/tensor/manipulate.h" +#include "../../transform/infer_layout_utils.h" +#include "../../transform/utils.h" namespace tvm { namespace relax { @@ -167,7 +167,7 @@ class CollectConsumerDetails : public ExprVisitor { namespace transform { -Pass OptimizeToVDeviceForScopeChange() { +Pass FoldVDeviceScopeChange() { auto pass_func = [=](Function func, IRModule mod, PassContext pc) { /* here Target doesn't matter as the consumers we use only to find multiple consumers */ auto consumers = @@ -175,10 +175,10 @@ Pass OptimizeToVDeviceForScopeChange() { auto [pattern, rewriter] = CreatePatterns(consumers); return RewriteCall(pattern, rewriter, func); }; - return CreateFunctionPass(pass_func, 1, "OptimizeToVDeviceForScopeChange", {}); + return CreateFunctionPass(pass_func, 1, "FoldVDeviceScopeChange", {}); } -TVM_REGISTER_GLOBAL("relax.transform.OptimizeToVDeviceForScopeChange") - .set_body_typed(OptimizeToVDeviceForScopeChange); +TVM_REGISTER_GLOBAL("relax.transform.FoldVDeviceScopeChange") + .set_body_typed(FoldVDeviceScopeChange); } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 0e0c4254e34b..5a556cbd7413 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -351,6 +351,11 @@ inline ffi::Optional InferBinaryArithOpOutVDevice(const Call& call, } }; + /* + * This is the case where the output VDevice defined by a customization pass. + * Like targets that supports mixed VDevices (like differed by memory_scope for Adreno) + * and have specialized derivation for output VDevice. + */ if (call->sinfo_args.size() > 0) { return get_vdevice(call->sinfo_args[0]); } diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index af5f9c2ca8fc..f8ad0a7dcbb0 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -31,6 +31,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -68,13 +70,15 @@ class LegalizeMutator : public ExprMutator { public: explicit LegalizeMutator(const IRModule& mod, const ffi::Optional>& cmap, const ffi::Optional> skip_ops, bool enable_warning) - : ExprMutator(mod), - mod_(std::move(mod)), - enable_warning_(enable_warning), - skip_ops_(skip_ops) { + : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) { if (cmap) { cmap_ = cmap.value(); } + if (skip_ops.defined()) { + for (const auto name : skip_ops.value()) { + skip_ops_.insert(Op::Get(name)); + } + } } IRModule Transform() { @@ -221,9 +225,6 @@ class LegalizeMutator : public ExprMutator { out_sinfo = TupleStructInfo(sinfo_fields); } - if (out_sinfo->IsInstance()) { - LOG(WARNING) << "New Struct Info:" << Downcast(out_sinfo); - } return Call(call_tir_op, call->args, call->attrs, {out_sinfo}); } @@ -312,12 +313,8 @@ class LegalizeMutator : public ExprMutator { } auto op = ffi::GetRef(op_node); - if (skip_ops_.defined()) { - for (const auto name : skip_ops_.value()) { - if (name == op->name) { - return visited_call; - } - } + if (skip_ops_.find(op) != skip_ops_.end()) { + return visited_call; } bool shapes_are_known_if_required = [&]() -> bool { @@ -473,7 +470,7 @@ class LegalizeMutator : public ExprMutator { /*! * \brief List of ops to be skipped from legalization */ - Optional> skip_ops_; + std::set skip_ops_; }; namespace transform { diff --git a/src/relax/transform/remove_redundant_assignments.cc b/src/relax/transform/remove_redundant_assignments.cc deleted file mode 100644 index ebf03db4bc6d..000000000000 --- a/src/relax/transform/remove_redundant_assignments.cc +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file src/relax/transform/remove_redundant_assignments.cc - * \brief This pass removes redundant assignment statements. These stmts are result of other pass - * like hint_on_device processed by RealizeVDevice may leave them. The subsequent pass like - * fuse_ops fail to fuse in this case. - */ - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "../op/tensor/manipulate.h" -#include "infer_layout_utils.h" -#include "utils.h" - -namespace tvm { -namespace relax { - -class RemoveRedundantAssignments : public ExprMutator { - public: - using ExprMutator::VisitExpr_; - - IRModule Run(const IRModule& mod) { - mod_ = mod; - for (const auto& [gv, func] : mod_->functions) { - if (func->IsInstance()) { - const auto& base_func = mod_->Lookup(gv); - // Only non primitive relax functions - if (base_func->HasNonzeroAttr(attr::kPrimitive)) { - continue; - } - relax::Function update_func = Downcast(VisitExpr(func)); - updates_->Add(gv, update_func); - } - } - mod_.CopyOnWrite()->Update(updates_); - return mod_; - } - - void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final { - redundant_map.Set(GetRef(binding->var.get()), GetRef(var)); - } - - void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) override { - redundant_map.Set(GetRef(binding->var.get()), GetRef(val)); - } - - Expr VisitExpr_(const CallNode* call_node) final { - auto call = Downcast(ExprMutator::VisitExpr_(call_node)); - static const Op& call_tir_op = Op::Get("relax.call_tir"); - - Tuple args; - - if (call->op == call_tir_op) { - args = Downcast(call->args[1]); - } else { - args = Tuple(call->args); - } - Array new_args; - - for (auto& arg : args->fields) { - if (redundant_map.find(arg) != redundant_map.end()) { - new_args.push_back(redundant_map[arg]); - } else { - new_args.push_back(arg); - } - } - if (call->op == call_tir_op) { - return Call(call_tir_op, {call->args[0], Tuple(new_args)}, call->attrs, - {call->sinfo_args[0]}); - } else { - if (call->sinfo_args.size() > 0) { - return Call(call->op, new_args, call->attrs, {call->sinfo_args[0]}); - } else { - return Call(call->op, new_args, call->attrs); - } - } - } - - private: - Map redundant_map; - IRModule updates_; - IRModule mod_; -}; - -namespace transform { - -Pass RemoveRedundantAssignments() { - runtime::TypedPackedFunc pass_func = - [=](IRModule mod, PassContext pc) { return relax::RemoveRedundantAssignments().Run(mod); }; - return CreateModulePass(/*pass_function=*/pass_func, - /*opt_level=*/0, - /*pass_name=*/"RemoveRedundantAssignments", - /*required=*/{}); -} -TVM_REGISTER_GLOBAL("relax.transform.RemoveRedundantAssignments") - .set_body_typed(RemoveRedundantAssignments); -} // namespace transform -} // namespace relax -} // namespace tvm diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index c1856cbb4fd3..785bf1371707 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -95,7 +95,7 @@ def verify(mod, expected): mod = tvm.relax.transform.FuseOps()(mod) mod = tvm.relax.transform.FuseTIR()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) - mod = tvm.relax.transform.OptimizeToVDeviceForScopeChange()(mod) + mod = tvm.relax.transform.FoldVDeviceScopeChange()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) mod = tvm.relax.transform.Normalize()(mod) @@ -1196,4 +1196,3 @@ def main( if __name__ == "__main__": tvm.testing.main() - # test_conv2d_relu_sub_indexed() diff --git a/tests/python/relax/test_transform_fold_vdevice_scope_change.py b/tests/python/relax/test_transform_fold_vdevice_scope_change.py new file mode 100644 index 000000000000..274dd10ccb33 --- /dev/null +++ b/tests/python/relax/test_transform_fold_vdevice_scope_change.py @@ -0,0 +1,282 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import ir as I, relax as R, tir as T +from tvm.ir.module import IRModule + + +def verify(input, expected): + mod = tvm.relax.transform.FoldVDeviceScopeChange()(input) + tvm.ir.assert_structural_equal(mod, expected) + + +def test_maxpool2d_scope_folding(): + @I.ir_module + class Input: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def max_pool2d_opencl( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + pool_max: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) + ): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads( + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( + -340282346638528859811704183484516925440.0 + ) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ], + ) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) + T.reads(x[v_self, v_i0, v_i1, v_i2]) + T.writes( + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] + ) + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] = x[v_self, v_i0, v_i1, v_i2] + + @T.prim_func(private=True) + def te_layout_transform2( + lv2: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2, i3 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) + ): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) + T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) + T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) + te_layout_transform[v_self, v_i3, v_i1, v_i2] = lv2[ + v_self, v_i0, v_i1, v_i2, v_i3 + ] + + @R.function + def main( + x: R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"): # noqa: F722 + cls = Input + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv2 = R.call_tir( + cls.max_pool2d_opencl, + (lv,), + out_sinfo=R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv5: R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = R.to_vdevice(lv2, dst_vdevice="opencl:1:global") + gv2 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def max_pool2d_opencl( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + pool_max: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) + ): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads( + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( + -340282346638528859811704183484516925440.0 + ) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ], + ) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) + T.reads(x[v_self, v_i0, v_i1, v_i2]) + T.writes( + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] + ) + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] = x[v_self, v_i0, v_i1, v_i2] + + @T.prim_func(private=True) + def te_layout_transform2( + lv2: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2, i3 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) + ): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) + T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) + T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) + te_layout_transform[v_self, v_i3, v_i1, v_i2] = lv2[ + v_self, v_i0, v_i1, v_i2, v_i3 + ] + + @R.function + def main( + x: R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"): # noqa: F722 + cls = Expected + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv5 = R.call_tir( + cls.max_pool2d_opencl, + (lv,), + out_sinfo=R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:1:global" + ), + ) + gv2 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py new file mode 100644 index 000000000000..d92570025fce --- /dev/null +++ b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py @@ -0,0 +1,344 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import ir as I, relax as R, tir as T +from tvm.relax.transform.legalize_ops import adreno as legalize_adreno +from tvm.ir.module import IRModule +from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor + + +@visitor +class ValidateBufferScopes(PyExprVisitor): # pylint: disable=abstract-method + def __init__(self, is_matched: bool) -> None: + self.is_matched = is_matched + + def visit(self, mod: IRModule) -> None: + """Entry point""" + self.mod = mod + for key, func in mod.functions_items(): + if isinstance(func, relax.Function): + self.visit_expr(func) + + def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed + if call.op.name == "relax.call_tir": + pfunc = self.mod[call.args[0]] + if not self.is_matched: + # All scopes should be global in before pass + for _, buf in pfunc.buffer_map.items(): + assert ( + "global" == buf.data.type_annotation.storage_scope + ), f"expected to be global scoped, but got {val.data.type_annotation.storage_scope}" + else: + for idx, arg in enumerate(call.args[1]): + arg_sinfo = arg.struct_info + assert isinstance( + arg_sinfo, relax.TensorStructInfo + ), f"Expected TensorStructInfo but git {type(arg_sinfo)}" + buf = pfunc.buffer_map[pfunc.params[idx]] + assert ( + arg_sinfo.vdevice.memory_scope == buf.data.type_annotation.storage_scope + ), f"scope mismatched after specialization {arg_sinfo.vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + buf = pfunc.buffer_map[pfunc.params[-1]] + assert ( + call.sinfo_args[0].vdevice.memory_scope + == buf.data.type_annotation.storage_scope + ), f"scope mismatched after specialization {call.sinfo_args[0].vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" + else: + assert isinstance( + call.sinfo_args[0], relax.TupleStructInfo + ), f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" + for idx, sinfo in enumerate(call.sinfo_args[0].fields): + buf = pfunc.buffer_map[pfunc.params[len(call.args[1]) + idx]] + assert ( + sinfo.vdevice.memory_scope == buf.data.type_annotation.storage_scope + ), f"scope mismatched after specialization {sinfo.vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}" + + +def verify(input): + ValidateBufferScopes(False).visit(input) + mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(input) + ValidateBufferScopes(True).visit(mod) + + +def test_single_arg_return(): + @I.ir_module + class Input: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def max_pool2d_opencl( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + pool_max: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) + ): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads( + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( + -340282346638528859811704183484516925440.0 + ) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ], + ) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) + T.reads(x[v_self, v_i0, v_i1, v_i2]) + T.writes( + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] + ) + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] = x[v_self, v_i0, v_i1, v_i2] + + @T.prim_func(private=True) + def te_layout_transform2( + lv2: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2, i3 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) + ): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) + T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) + T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) + te_layout_transform[v_self, v_i3, v_i1, v_i2] = lv2[ + v_self, v_i0, v_i1, v_i2, v_i3 + ] + + @R.function + def main( + x: R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"): # noqa: F722 + cls = Input + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv2 = R.call_tir( + cls.max_pool2d_opencl, + (lv,), + out_sinfo=R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv5: R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = R.to_vdevice(lv2, dst_vdevice="opencl:1:global") + gv2 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + ) + R.output(gv2) + return gv2 + + verify(Input) + + +def test_multi_arg_return(): + @I.ir_module + class Input: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def conv2d_NCHWc_OIHWo_opencl( + lv: T.Buffer((T.int64(2), T.int64(4), T.int64(28), T.int64(28), T.int64(4)), "float32"), + lv1: T.Buffer((T.int64(1), T.int64(16), T.int64(3), T.int64(3), T.int64(4)), "float32"), + conv2d_NCHWc_OIHWo: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + conv2d_NCHWc_OIHWo[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def fused_relu_concatenate_split( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + T_split_sections_intermediate: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + T_split_sections_intermediate_1: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + T_split_sections_intermediate[0, 0, 0, 0, 0] = T.float32(0.0) + T_split_sections_intermediate_1[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(16), T.int64(28), T.int64(28)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(28), T.int64(28), T.int64(4)), "float32" + ), + ): + te_layout_transform[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def te_layout_transform1( + w: T.Buffer((T.int64(4), T.int64(16), T.int64(3), T.int64(3)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(1), T.int64(16), T.int64(3), T.int64(3), T.int64(4)), "float32" + ), + ): + te_layout_transform[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def te_layout_transform2( + lv3: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32" + ), + ): + te_layout_transform[0, 0, 0, 0] = T.float32(0.0) + + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + w: R.Tensor((4, 16, 3, 3), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ): + cls = Input + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 4, 28, 28, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv1 = R.call_tir( + cls.te_layout_transform1, + (w,), + out_sinfo=R.Tensor( + (1, 16, 3, 3, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + gv = R.call_tir( + cls.conv2d_NCHWc_OIHWo_opencl, + (lv, lv1), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv_1 = R.call_tir( + cls.fused_relu_concatenate_split, + (gv,), + out_sinfo=[ + R.Tensor((2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global"), + R.Tensor((2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global"), + ], + ) + lv3: R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = lv_1[0] + lv4 = R.call_tir( + cls.te_layout_transform2, + (lv3,), + out_sinfo=R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), + ) + lv5: R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = lv_1[1] + lv6 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), + ) + gv4: R.Tuple( + R.Tensor( + (2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ), + R.Tensor( + (2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ), + ) = (lv4, lv6) + R.output(gv4) + return gv4 + + verify(Input) + + +if __name__ == "__main__": + tvm.testing.main() From e60ac3341a736ec1daa1f39b2cbaeef3664f921e Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 31 Mar 2025 22:55:29 +0530 Subject: [PATCH 21/31] Review comments. Adreno transforms under relax.backend.adreno.tansforms across cpp and python. --- include/tvm/relax/backend/adreno/transform.h | 65 +++++++++++++++++++ include/tvm/relax/transform.h | 18 ----- python/tvm/relax/backend/adreno/__init__.py | 3 + .../backend/adreno/transform/__init__.py | 22 +++++++ .../backend/adreno/transform/_ffi_api.py | 19 ++++++ .../backend/adreno/transform/transform.py | 50 ++++++++++++++ python/tvm/relax/transform/__init__.py | 2 - python/tvm/relax/transform/transform.py | 27 -------- .../backend/adreno/annotate_custom_storage.cc | 14 ++-- .../adreno/fold_vdevice_scope_change.cc | 12 ++-- src/relax/transform/legalize_ops.cc | 13 ++-- .../test_transform_annotate_custom_scope.py | 4 +- 12 files changed, 184 insertions(+), 65 deletions(-) create mode 100644 include/tvm/relax/backend/adreno/transform.h create mode 100644 python/tvm/relax/backend/adreno/transform/__init__.py create mode 100644 python/tvm/relax/backend/adreno/transform/_ffi_api.py create mode 100644 python/tvm/relax/backend/adreno/transform/transform.py diff --git a/include/tvm/relax/backend/adreno/transform.h b/include/tvm/relax/backend/adreno/transform.h new file mode 100644 index 000000000000..531391181c5a --- /dev/null +++ b/include/tvm/relax/backend/adreno/transform.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/backend/adreno/transform.h + * \brief Adreno GPU specific transformation passes. + */ +#ifndef TVM_RELAX_BACKEND_ADRENO_TRANSFORM_H_ +#define TVM_RELAX_BACKEND_ADRENO_TRANSFORM_H_ + +#include +#include +namespace tvm { +namespace relax { +namespace backend { +namespace adreno { +namespace transform { + +using Pass = tvm::transform::Pass; +using PassInfo = tvm::transform::PassInfo; +using PassContext = tvm::transform::PassContext; +using Function = tvm::relax::Function; +using DataflowBlock = tvm::relax::DataflowBlock; + +/*! + * \brief This pass is designed to annotate the memory scope information via VDevice attribute. + * This pass need operator attrbutes which in general vanish aftre legalization. + * FuseOps and FuseTIR are modified to pass on the operator specific attributes and also + * op_pattern details as part of the PrimFunc. This pass is Adreno specific and annotates each + * BindingVar with appropriate HintInDevice. RealizeVDevice pass followed by handles these hints. + * Followed by this pass we also invoke SpecializePrimFuncBasedOnCallSite which updates the + * var_buffer_map based on this new VDevice information. + */ +TVM_DLL Pass AnnotateCustomMemoryScope(Target target); + +/* + * \brief This is a texture specific pass that can optimize unnecessary to_device copies. + * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly + * store into global scope avoiding unnecessary device copy. + */ +TVM_DLL Pass FoldVDeviceScopeChange(); + +} // namespace transform +} // namespace adreno +} // namespace backend +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_ADRENO_TRANSFORM_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index e4b1118928e2..b627dc35482b 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -680,17 +680,6 @@ TVM_DLL Pass RewriteCUDAGraph(); */ TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark); -/*! - * \brief This pass is designed to annotate the memory scope information via VDevice attribute. - * This pass need operator attrbutes which in general vanish aftre legalization. - * FuseOps and FuseTIR are modified to pass on the operator specific attributes and also - * op_pattern details as part of the PrimFunc. This pass is Adreno specific and annotates each - * BindingVar with appropriate HintInDevice. RealizeVDevice pass followed by handles these hints. - * Followed by this pass we also invoke SpecializePrimFuncBasedOnCallSite which updates the - * var_buffer_map based on this new VDevice information. - */ -TVM_DLL Pass AnnotateCustomMemoryScope(Target target); - /*! * \brief This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. * Primarily used to update the VDevice information if any changes occured from the caller. @@ -698,13 +687,6 @@ TVM_DLL Pass AnnotateCustomMemoryScope(Target target); */ TVM_DLL Pass SpecializePrimFuncBasedOnCallSite(); -/* - * \brief This is a texture specific pass that can optimize unnecessary to_device copies. - * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly - * store into global scope avoiding unnecessary device copy. - */ -TVM_DLL Pass FoldVDeviceScopeChange(); - } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/backend/adreno/__init__.py b/python/tvm/relax/backend/adreno/__init__.py index b3364f2f4b4a..b97ea399ab19 100644 --- a/python/tvm/relax/backend/adreno/__init__.py +++ b/python/tvm/relax/backend/adreno/__init__.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. """The Relax Adreno backend compilation pipeline and other passes.""" + +from . import transform + from .pipeline import ( finalize_passes, get_default_pipeline, diff --git a/python/tvm/relax/backend/adreno/transform/__init__.py b/python/tvm/relax/backend/adreno/transform/__init__.py new file mode 100644 index 000000000000..abeb56ac488c --- /dev/null +++ b/python/tvm/relax/backend/adreno/transform/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Adreno Relax transformations. """ + +from .transform import ( + AnnotateCustomMemoryScope, + FoldVDeviceScopeChange, +) diff --git a/python/tvm/relax/backend/adreno/transform/_ffi_api.py b/python/tvm/relax/backend/adreno/transform/_ffi_api.py new file mode 100644 index 000000000000..7ed23bd57b19 --- /dev/null +++ b/python/tvm/relax/backend/adreno/transform/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for Adreno transform""" +import tvm._ffi + +tvm._ffi._init_api("relax.backend.adreno.transform", __name__) diff --git a/python/tvm/relax/backend/adreno/transform/transform.py b/python/tvm/relax/backend/adreno/transform/transform.py new file mode 100644 index 000000000000..9a01d7be97dd --- /dev/null +++ b/python/tvm/relax/backend/adreno/transform/transform.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Adreno Relax transformation passes.""" +from typing import Optional + +import tvm.ir +from tvm.target import Target + +from . import _ffi_api + + +def AnnotateCustomMemoryScope(target: Optional[Target] = None) -> tvm.ir.transform.Pass: + """Allocate the memory scope information. This is Adreno specific pass to annotate + The memory scope information and realize the same with RealizeVDevice pass followed by + updating the Prim Function var_buffer mapping using SpecializePrimFuncBasedOnCallSite. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.AnnotateCustomMemoryScope(target) # type: ignore + + +def FoldVDeviceScopeChange() -> tvm.ir.transform.Pass: + """This pass is a texture specific pass that can optimize unnecessary to_device copies. + Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly + store into global scope avoiding unnecessary device copy. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.FoldVDeviceScopeChange() # type: ignore diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 30a3590dcb53..dacbc667be2b 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -83,9 +83,7 @@ UpdateVDevice, VMBuiltinLower, VMShapeLower, - AnnotateCustomMemoryScope, SpecializePrimFuncBasedOnCallSite, - FoldVDeviceScopeChange, dataflowblock_pass, function_pass, ) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index bebe0da93f31..46efc17e3d4f 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -30,7 +30,6 @@ from tvm.relax.dpl import DFPattern from tvm.runtime import Tensor, Object from tvm.tir import IndexMap, PrimFunc -from tvm.target import Target from . import _ffi_api from .legalize_ops.common import LegalizeFunc @@ -1611,19 +1610,6 @@ def AllocateWorkspace() -> tvm.ir.transform.Pass: return _ffi_api.AllocateWorkspace() # type: ignore -def AnnotateCustomMemoryScope(target: Optional[Target] = None) -> tvm.ir.transform.Pass: - """Allocate the memory scope information. This is Adreno specific pass to annotate - The memory scope information and realize the same with RealizeVDevice pass followed by - updating the Prim Function var_buffer mapping using SpecializePrimFuncBasedOnCallSite. - - Returns - ------- - ret: tvm.ir.transform.Pass - The registered pass for allocating workspace. - """ - return _ffi_api.AnnotateCustomMemoryScope(target) # type: ignore - - def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass: """This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. Primarily used to update the VDevice information if any changes occured from the caller. @@ -1637,19 +1623,6 @@ def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass: return _ffi_api.SpecializePrimFuncBasedOnCallSite() # type: ignore -def FoldVDeviceScopeChange() -> tvm.ir.transform.Pass: - """This pass is a texture specific pass that can optimize unnecessary to_device copies. - Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly - store into global scope avoiding unnecessary device copy. - - Returns - ------- - ret: tvm.ir.transform.Pass - The registered pass for allocating workspace. - """ - return _ffi_api.FoldVDeviceScopeChange() # type: ignore - - def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc index 3b2a1fac4285..63a98ecbedad 100644 --- a/src/relax/backend/adreno/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -238,11 +238,11 @@ #include #include +#include #include #include #include #include -#include #include #include @@ -253,6 +253,8 @@ namespace tvm { namespace relax { +namespace backend { +namespace adreno { using tvm::tir::Buffer; @@ -730,17 +732,21 @@ class DefineVDevice : ExprMutator { namespace transform { Pass AnnotateCustomMemoryScope(Target target) { - runtime::TypedPackedFunc pass_func = - [=](IRModule mod, PassContext pc) { return relax::DefineVDevice(target).Run(mod); }; + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext pc) { + return tvm::relax::backend::adreno::DefineVDevice(target).Run(mod); + }; return CreateModulePass(/*pass_function=*/pass_func, /*opt_level=*/0, /*pass_name=*/"AnnotateCustomMemoryScope", /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.AnnotateCustomMemoryScope") +TVM_REGISTER_GLOBAL("relax.backend.adreno.transform.AnnotateCustomMemoryScope") .set_body_typed(AnnotateCustomMemoryScope); } // namespace transform +} // namespace adreno +} // namespace backend } // namespace relax } // namespace tvm diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc index 1ad3aafb2df9..bb3a863743e9 100644 --- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file src/relax/backend/adreno/optimize_to_vdevice_for_scope_change.cc + * \file src/relax/backend/adreno/fold_vdevice_scope_change.cc * \brief This is a texture specific pass that can optimize unnecessary to_device copies. * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly * store into global scope avoiding unnecessary device copy. @@ -25,11 +25,11 @@ #include #include +#include #include #include #include #include -#include #include #include @@ -40,6 +40,8 @@ namespace tvm { namespace relax { +namespace backend { +namespace adreno { namespace { std::tuple)>> CreatePatterns( @@ -175,10 +177,12 @@ Pass FoldVDeviceScopeChange() { auto [pattern, rewriter] = CreatePatterns(consumers); return RewriteCall(pattern, rewriter, func); }; - return CreateFunctionPass(pass_func, 1, "FoldVDeviceScopeChange", {}); + return tvm::relax::transform::CreateFunctionPass(pass_func, 1, "FoldVDeviceScopeChange", {}); } -TVM_REGISTER_GLOBAL("relax.transform.FoldVDeviceScopeChange") +TVM_REGISTER_GLOBAL("relax.backend.adreno.transform.FoldVDeviceScopeChange") .set_body_typed(FoldVDeviceScopeChange); } // namespace transform +} // namespace adreno +} // namespace backend } // namespace relax } // namespace tvm diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index f8ad0a7dcbb0..e10d770a14fb 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -166,8 +166,8 @@ class LegalizeMutator : public ExprMutator { return std::nullopt; } - Expr UpdateVDeviceOutStructInfo(Expr expr, const Call& visited_call) { - static const auto& infer_struct_info_map = Op::GetAttrMap("FInferStructInfo"); + Expr UpdateVDeviceOutStructInfo(Expr expr, const Call& visited_call, + const StructInfo& infered_sinfo) { static const Op& call_tir_op = Op::Get("relax.call_tir"); auto* op_node = visited_call->op.as(); @@ -177,10 +177,6 @@ class LegalizeMutator : public ExprMutator { } auto op = GetRef(op_node); - if (!infer_struct_info_map.count(op)) { - return expr; - } - if (!expr->IsInstance()) { return expr; } @@ -191,7 +187,6 @@ class LegalizeMutator : public ExprMutator { } StructInfo out_sinfo = call->sinfo_args[0]; - StructInfo infered_sinfo = infer_struct_info_map[op](visited_call, builder_); if (out_sinfo->IsInstance()) { auto out_tsinfo = Downcast(out_sinfo); @@ -299,6 +294,7 @@ class LegalizeMutator : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { Call visited_call = Downcast(this->VisitExprPostOrder_(call)); + static const auto& infer_struct_info_map = Op::GetAttrMap("FInferStructInfo"); static const auto& legalize_map = Op::GetAttrMap("FLegalize"); static const auto& call_packed_map = Op::GetAttrMap("FCallPacked"); static const auto& requires_arg_shapes_map = Op::GetAttrMap("RequiresArgumentShapes"); @@ -422,7 +418,8 @@ class LegalizeMutator : public ExprMutator { } Expr legalized = legalization_func(builder_, visited_call); - legalized = UpdateVDeviceOutStructInfo(legalized, visited_call); + StructInfo infered_sinfo = infer_struct_info_map[op](GetRef(call), builder_); + legalized = UpdateVDeviceOutStructInfo(legalized, visited_call, infered_sinfo); // Append the target attribute to any PrimFunc generated in // legalization. diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py index 785bf1371707..02f264cca30e 100644 --- a/tests/python/relax/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/test_transform_annotate_custom_scope.py @@ -83,7 +83,7 @@ def verify(mod, expected): mod = tvm.relax.transform.FoldConstant()(mod) mod = tvm.relax.transform.LegalizeOps(skip_ops=skip_ops)(mod) mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) - mod = tvm.relax.transform.AnnotateCustomMemoryScope(tgt)(mod) + mod = tvm.relax.backend.adreno.transform.AnnotateCustomMemoryScope(tgt)(mod) # There is a possibility of some skipped ops above might not use 5D layouts. mod = tvm.relax.transform.LegalizeOps()(mod) mod = tvm.relax.transform.LegalizeOps( @@ -95,7 +95,7 @@ def verify(mod, expected): mod = tvm.relax.transform.FuseOps()(mod) mod = tvm.relax.transform.FuseTIR()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) - mod = tvm.relax.transform.FoldVDeviceScopeChange()(mod) + mod = tvm.relax.backend.adreno.transform.FoldVDeviceScopeChange()(mod) mod = tvm.relax.transform.DeadCodeElimination()(mod) mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) mod = tvm.relax.transform.Normalize()(mod) From b95193785cbecff120bb57f4fddac566dc820559 Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 1 Apr 2025 06:56:29 +0530 Subject: [PATCH 22/31] Test case fix. --- tests/python/relax/test_transform_fold_vdevice_scope_change.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_transform_fold_vdevice_scope_change.py b/tests/python/relax/test_transform_fold_vdevice_scope_change.py index 274dd10ccb33..b461f39dd744 100644 --- a/tests/python/relax/test_transform_fold_vdevice_scope_change.py +++ b/tests/python/relax/test_transform_fold_vdevice_scope_change.py @@ -23,7 +23,7 @@ def verify(input, expected): - mod = tvm.relax.transform.FoldVDeviceScopeChange()(input) + mod = tvm.relax.backend.adreno.transform.FoldVDeviceScopeChange()(input) tvm.ir.assert_structural_equal(mod, expected) From a9ea6ba2ab244d62c23e7100cd0ff958c92899bc Mon Sep 17 00:00:00 2001 From: Siva Date: Fri, 2 May 2025 09:27:14 +0530 Subject: [PATCH 23/31] Move adreno tests to specific folder --- .../relax/{ => adreno}/test_transform_annotate_custom_scope.py | 0 .../{ => adreno}/test_transform_fold_vdevice_scope_change.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/python/relax/{ => adreno}/test_transform_annotate_custom_scope.py (100%) rename tests/python/relax/{ => adreno}/test_transform_fold_vdevice_scope_change.py (100%) diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/adreno/test_transform_annotate_custom_scope.py similarity index 100% rename from tests/python/relax/test_transform_annotate_custom_scope.py rename to tests/python/relax/adreno/test_transform_annotate_custom_scope.py diff --git a/tests/python/relax/test_transform_fold_vdevice_scope_change.py b/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py similarity index 100% rename from tests/python/relax/test_transform_fold_vdevice_scope_change.py rename to tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py From a021f588e9116522bbddb2b21a8f956109058101 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 4 Jun 2025 23:05:15 +0530 Subject: [PATCH 24/31] Changes due to recent mainline ffi --- include/tvm/relax/backend/adreno/transform.h | 2 ++ include/tvm/relax/transform.h | 4 ++-- include/tvm/runtime/tensor.h | 10 ++++++++ .../backend/adreno/transform/_ffi_api.py | 4 ++-- python/tvm/tir/analysis/analysis.py | 3 +-- .../backend/adreno/annotate_custom_storage.cc | 7 +++--- .../adreno/fold_vdevice_scope_change.cc | 6 ++--- src/relax/op/op.cc | 10 ++++---- src/relax/transform/legalize_ops.cc | 3 ++- .../specialize_primfunc_based_on_callsite.cc | 7 +++--- src/runtime/contrib/clml/clml_runtime.cc | 10 ++++---- src/runtime/tensor.cc | 24 +++++++++++++++++-- src/tir/schedule/analysis/analysis.cc | 2 ++ 13 files changed, 62 insertions(+), 30 deletions(-) diff --git a/include/tvm/relax/backend/adreno/transform.h b/include/tvm/relax/backend/adreno/transform.h index 531391181c5a..891a19187739 100644 --- a/include/tvm/relax/backend/adreno/transform.h +++ b/include/tvm/relax/backend/adreno/transform.h @@ -37,6 +37,8 @@ using PassInfo = tvm::transform::PassInfo; using PassContext = tvm::transform::PassContext; using Function = tvm::relax::Function; using DataflowBlock = tvm::relax::DataflowBlock; +using tvm::relax::transform::CreateFunctionPass; +using tvm::transform::CreateModulePass; /*! * \brief This pass is designed to annotate the memory scope information via VDevice attribute. diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index b627dc35482b..b50aa280a476 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -249,8 +249,8 @@ TVM_DLL Pass FoldConstant(); * showing up in the database. * \return The Pass. */ -TVM_DLL Pass LegalizeOps(ffi::Optional> cmap, ffi::Optional> skip_ops, - bool enable_warning = false, bool add_attributes = false); +TVM_DLL Pass LegalizeOps(ffi::Optional> cmap, + ffi::Optional> skip_ops, bool enable_warning = false); /*! * \brief Propagate virtual device information. diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h index e32101aac2dd..615cfd8cccfe 100644 --- a/include/tvm/runtime/tensor.h +++ b/include/tvm/runtime/tensor.h @@ -178,6 +178,16 @@ class Tensor : public tvm::ffi::Tensor { */ TVM_DLL static void CopyToBytes(const DLTensor* from, void* to, size_t nbytes, TVMStreamHandle stream = nullptr); + + /*! + * \brief Function to copy data from one array to a byte buffer. + * \param from The source array. + * \param to The target byte buffer. + * \param nbytes The size of the data buffer. + * \param stream The stream used in copy. + */ + TVM_DLL static void CopyFromBytes(const DLTensor* to, void* from, size_t nbytes, + TVMStreamHandle stream = nullptr); }; /*! diff --git a/python/tvm/relax/backend/adreno/transform/_ffi_api.py b/python/tvm/relax/backend/adreno/transform/_ffi_api.py index 7ed23bd57b19..7a19e3380feb 100644 --- a/python/tvm/relax/backend/adreno/transform/_ffi_api.py +++ b/python/tvm/relax/backend/adreno/transform/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for Adreno transform""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.backend.adreno.transform", __name__) +tvm.ffi._init_api("relax.backend.adreno.transform", __name__) diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index cc23bb939588..8a84d3ee51fa 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -19,7 +19,6 @@ from typing import Dict, List, Optional, Union import tvm -from tvm import _ffi from tvm.ir import IRModule from tvm.tir.expr import Var from tvm.tir.stmt import Block, BufferRegion, PrimExpr @@ -303,7 +302,7 @@ def find_anchor_block(mod: IRModule) -> Block: def has_if_then_else(stmt: Stmt) -> bool: - return _ffi.get_global_func("tir.schedule.HasIfThenElse")(stmt) + return tvm.ffi.get_global_func("tir.schedule.HasIfThenElse")(stmt) def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]: diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc index 63a98ecbedad..3896acca908d 100644 --- a/src/relax/backend/adreno/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -716,7 +716,7 @@ class DefineVDevice : ExprMutator { return vdevice->target; } } - return NullOpt; + return std::nullopt; } const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); @@ -732,8 +732,7 @@ class DefineVDevice : ExprMutator { namespace transform { Pass AnnotateCustomMemoryScope(Target target) { - runtime::TypedPackedFunc pass_func = [=](IRModule mod, - PassContext pc) { + auto pass_func = [=](IRModule mod, PassContext pc) { return tvm::relax::backend::adreno::DefineVDevice(target).Run(mod); }; return CreateModulePass(/*pass_function=*/pass_func, @@ -742,7 +741,7 @@ Pass AnnotateCustomMemoryScope(Target target) { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.backend.adreno.transform.AnnotateCustomMemoryScope") +TVM_FFI_REGISTER_GLOBAL("relax.backend.adreno.transform.AnnotateCustomMemoryScope") .set_body_typed(AnnotateCustomMemoryScope); } // namespace transform diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc index bb3a863743e9..73c1e51acb90 100644 --- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -44,7 +44,7 @@ namespace backend { namespace adreno { namespace { -std::tuple)>> CreatePatterns( +std::tuple)>> CreatePatterns( Map> consumers) { auto pat_gv = WildcardPattern(); @@ -177,9 +177,9 @@ Pass FoldVDeviceScopeChange() { auto [pattern, rewriter] = CreatePatterns(consumers); return RewriteCall(pattern, rewriter, func); }; - return tvm::relax::transform::CreateFunctionPass(pass_func, 1, "FoldVDeviceScopeChange", {}); + return CreateFunctionPass(pass_func, 1, "FoldVDeviceScopeChange", {}); } -TVM_REGISTER_GLOBAL("relax.backend.adreno.transform.FoldVDeviceScopeChange") +TVM_FFI_REGISTER_GLOBAL("relax.backend.adreno.transform.FoldVDeviceScopeChange") .set_body_typed(FoldVDeviceScopeChange); } // namespace transform } // namespace adreno diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 345832da8b61..c1750cd545fc 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1518,14 +1518,12 @@ Expr MakeHintOnDevice(Expr data, Device device, String memory_scope = "global") TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("relax.op.hint_on_device", [](ffi::PackedArgs args, ffi::Any* rv), { - Expr data = args[0].cast(); - Device device = args[1].cast(); + refl::GlobalDef().def_packed("relax.op.hint_on_device", [](ffi::PackedArgs args, ffi::Any* ret) -> Expr , { if (args.size() == 3) { - String scope = args[2].cast(); - *rv = MakeHintOnDevice(data, device, scope); + *ret = + MakeHintOnDevice(args[0].cast(), args[1].cast(), args[2].cast()); } else { - *rv = MakeHintOnDevice(data, device); + *ret = MakeHintOnDevice(args[0].cast(), args[1].cast()); } }); } diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index e10d770a14fb..c159fa934779 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -472,7 +472,8 @@ class LegalizeMutator : public ExprMutator { namespace transform { -Pass LegalizeOps(ffi::Optional> cmap, ffi::Optional> skip_ops, +Pass LegalizeOps(ffi::Optional> cmap, + ffi::Optional> skip_ops, bool enable_warning) { auto pass_func = [=](IRModule mod, PassContext pc) { bool apply_legalize_ops = diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc index fe2cc9329860..0e4ce22b537f 100644 --- a/src/relax/transform/specialize_primfunc_based_on_callsite.cc +++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc @@ -155,15 +155,16 @@ class SpecializeTIRCallArgs : ExprMutator { namespace transform { Pass SpecializePrimFuncBasedOnCallSite() { - runtime::TypedPackedFunc pass_func = - [=](IRModule mod, PassContext pc) { return relax::SpecializeTIRCallArgs().Run(mod); }; + auto pass_func = [=](IRModule mod, PassContext pc) { + return relax::SpecializeTIRCallArgs().Run(mod); + }; return CreateModulePass(/*pass_function=*/pass_func, /*opt_level=*/0, /*pass_name=*/"SpecializePrimFuncBasedOnCallSite", /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.SpecializePrimFuncBasedOnCallSite") +TVM_FFI_REGISTER_GLOBAL("relax.transform.SpecializePrimFuncBasedOnCallSite") .set_body_typed(SpecializePrimFuncBasedOnCallSite); } // namespace transform diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index c166d0fb4bed..c4475618107e 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -315,7 +315,7 @@ class CLMLRuntime : public JSONRuntimeBase { const auto f = tvm::ffi::Function::GetGlobal("runtime.SaveParams"); if (f.has_value()) { - std::string dump_bytes = (*f)(dump_tensors); + std::string dump_bytes = (*f)(dump_tensors).cast(); std::ostringstream oss; /*TODO(Siva) HEX encoding doubles the size, look for better encode that can cross the RPC. */ for (size_t i = 0; i < dump_bytes.size(); ++i) { @@ -466,7 +466,7 @@ class CLMLRuntime : public JSONRuntimeBase { cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); int dtype_size = cl_dtype == CL_FLOAT ? 4 : 2; void* tmpptr = reinterpret_cast(malloc(isize * dtype_size)); - TVMTensorCopyToBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), + NDArray::CopyToBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), isize * dtype_size); CopyDataToCLMLTensor(layer_.inputs[nid], tmpptr); free(tmpptr); @@ -481,7 +481,7 @@ class CLMLRuntime : public JSONRuntimeBase { if (cws->workspace->IsProfiling(cws->tentry->device)) { Timer t; auto f = tvm::ffi::Function::GetGlobal(std::string("profiling.timer.opencl")); - t = f->operator()(cws->tentry->device); + t = f->operator()(cws->tentry->device).cast(); t->Start(); queue = CLML_QUEUE; evts.resize(evts.size() + 1); @@ -502,7 +502,7 @@ class CLMLRuntime : public JSONRuntimeBase { if (cws->workspace->IsProfiling(cws->tentry->device)) { Timer t; auto f = tvm::ffi::Function::GetGlobal(std::string("profiling.timer.opencl")); - t = f->operator()(cws->tentry->device); + t = f->operator()(cws->tentry->device).cast(); t->Start(); queue = CLML_QUEUE; evts.resize(evts.size() + 1); @@ -553,7 +553,7 @@ class CLMLRuntime : public JSONRuntimeBase { void* tmpptr = reinterpret_cast(malloc(osize * dtype_size)); CopyDataFromCLMLTensor(layer_.outputs[0], tmpptr); - TVMTensorCopyFromBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), + NDArray::CopyFromBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), osize * dtype_size); free(tmpptr); } diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc index f44e7a882a11..3895db979894 100644 --- a/src/runtime/tensor.cc +++ b/src/runtime/tensor.cc @@ -97,8 +97,28 @@ void Tensor::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); } -Tensor Tensor::Empty(ffi::Shape shape, DLDataType dtype, Device dev, - ffi::Optional mem_scope) { +void NDArray::CopyFromBytes(const DLTensor* handle, void* data, size_t nbytes, + TVMStreamHandle stream) { + size_t arr_size = GetDataSize(*handle); + ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; + ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; + + DLTensor from; + from.data = const_cast(data); + from.device = Device{kDLCPU, 0}; + from.ndim = handle->ndim; + from.dtype = handle->dtype; + from.shape = handle->shape; + from.strides = nullptr; + from.byte_offset = 0; + + DeviceAPI::Get(handle->device)->CopyDataFromTo(&from, const_cast(handle), stream); + // Synchronize in case data become unavailable later. + DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); +} + +NDArray NDArray::Empty(ffi::Shape shape, DLDataType dtype, Device dev, + ffi::Optional mem_scope) { struct DeviceAPIAlloc { void AllocData(DLTensor* tensor, ffi::Optional mem_scope) { tensor->data = DeviceAPI::Get(tensor->device) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 7cc15a407ead..4fbdef1eafea 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2164,6 +2164,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else { return "O"; } + }) + .def("tir.schedule.HasIfThenElse", [](const Stmt& stmt) -> bool { return HasIfThenElse(stmt); }); } From 5c3dd5f3500f0c6d2034be1ce77f1f6d55542ea1 Mon Sep 17 00:00:00 2001 From: Siva Date: Wed, 25 Jun 2025 14:04:57 +0530 Subject: [PATCH 25/31] Remove VDevice from legalization. Let tir emit handle it. --- .../legalize_ops/adreno/convolution.py | 1 + python/tvm/relax/utils.py | 20 +++--- src/relax/transform/legalize_ops.cc | 67 ------------------- .../test_transform_annotate_custom_scope.py | 20 ++++-- 4 files changed, 26 insertions(+), 82 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/adreno/convolution.py b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py index eb0bf30cfbf2..959e43778024 100644 --- a/python/tvm/relax/transform/legalize_ops/adreno/convolution.py +++ b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py @@ -32,5 +32,6 @@ def conv2d_NCHWc_OIHWo(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: layout=call.attrs.data_layout, out_layout=call.attrs.out_layout, # out_dtype=call.attrs.out_dtype, + sinfo_args=call.sinfo_args, primfunc_name_hint="conv2d_NCHWc_OIHWo", ) diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 7ce188f780c3..76897eefd707 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -347,6 +347,7 @@ def _shape_with_old_tir_var( ) primfunc_attrs = kwargs.pop("primfunc_attrs", None) + custom_out_sinfo = kwargs.pop("sinfo_args", []) te_args = _convert_te_arg(args) te_kwargs = _convert_te_arg(kwargs) @@ -371,14 +372,17 @@ def _shape_with_old_tir_var( # with old set of variables. tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} - output_sinfo = [ - TensorStructInfo( - _shape_with_old_tir_var(out.shape, tir_var_inverse_map), - out.dtype, - _get_vdevice(args), - ) - for out in outs - ] + if len(custom_out_sinfo) == 1: + output_sinfo = custom_out_sinfo[0] + else: + output_sinfo = [ + TensorStructInfo( + _shape_with_old_tir_var(out.shape, tir_var_inverse_map), + out.dtype, + _get_vdevice(args), + ) + for out in outs + ] tir_vars = None if len(unbound_tir_vars) > 0: diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index c159fa934779..a49d448e9369 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -38,12 +38,6 @@ namespace relax { TVM_REGISTER_PASS_CONFIG_OPTION("relax.transform.apply_legalize_ops", Bool); -static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { - auto shape = tensor_sinfo->GetShape(); - ICHECK(shape.defined()); - return shape.value(); -} - /*! * \brief Check if a given Tensor/Shape/TupleStructInfo contains shapes whose * values are all known. @@ -166,63 +160,6 @@ class LegalizeMutator : public ExprMutator { return std::nullopt; } - Expr UpdateVDeviceOutStructInfo(Expr expr, const Call& visited_call, - const StructInfo& infered_sinfo) { - static const Op& call_tir_op = Op::Get("relax.call_tir"); - auto* op_node = visited_call->op.as(); - - // Not an OpNode - if (op_node == nullptr) { - return expr; - } - auto op = GetRef(op_node); - - if (!expr->IsInstance()) { - return expr; - } - - auto call = Downcast(expr); - if (call->op != call_tir_op) { - return expr; - } - - StructInfo out_sinfo = call->sinfo_args[0]; - - if (out_sinfo->IsInstance()) { - auto out_tsinfo = Downcast(out_sinfo); - auto infered_tsinfo = Downcast(infered_sinfo); - auto shape_arr = GetShapeFromTensorStructInfo(out_tsinfo); - if (infered_tsinfo->vdevice.defined()) { - out_sinfo = TensorStructInfo(ShapeExpr(shape_arr), out_tsinfo->dtype, - infered_tsinfo->vdevice.value()); - } - } else if (out_sinfo->IsInstance()) { - const auto& tuple_sinfo = Downcast(out_sinfo); - const auto& infered_tuple_sinfo = Downcast(infered_sinfo); - Array sinfo_fields; - int index = 0; - for (const auto& si : tuple_sinfo->fields) { - ICHECK(si->IsInstance()) - << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " - "output structinfo, but got " - << si; - auto tsinfo = Downcast(si); - auto shape_arr = GetShapeFromTensorStructInfo(tsinfo); - auto infered_tsinfo = Downcast(infered_tuple_sinfo->fields[index]); - if (infered_tsinfo->vdevice.defined()) { - sinfo_fields.push_back(TensorStructInfo(ShapeExpr(shape_arr), tsinfo->dtype, - infered_tsinfo->vdevice.value())); - } else { - sinfo_fields.push_back(tsinfo); - } - ++index; - } - out_sinfo = TupleStructInfo(sinfo_fields); - } - - return Call(call_tir_op, call->args, call->attrs, {out_sinfo}); - } - Expr BindTarget(Expr expr) { if (!expr->IsInstance()) { // FLegalize returned something other than a relax::Call. This @@ -294,7 +231,6 @@ class LegalizeMutator : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { Call visited_call = Downcast(this->VisitExprPostOrder_(call)); - static const auto& infer_struct_info_map = Op::GetAttrMap("FInferStructInfo"); static const auto& legalize_map = Op::GetAttrMap("FLegalize"); static const auto& call_packed_map = Op::GetAttrMap("FCallPacked"); static const auto& requires_arg_shapes_map = Op::GetAttrMap("RequiresArgumentShapes"); @@ -418,9 +354,6 @@ class LegalizeMutator : public ExprMutator { } Expr legalized = legalization_func(builder_, visited_call); - StructInfo infered_sinfo = infer_struct_info_map[op](GetRef(call), builder_); - legalized = UpdateVDeviceOutStructInfo(legalized, visited_call, infered_sinfo); - // Append the target attribute to any PrimFunc generated in // legalization. legalized = BindTarget(legalized); diff --git a/tests/python/relax/adreno/test_transform_annotate_custom_scope.py b/tests/python/relax/adreno/test_transform_annotate_custom_scope.py index 02f264cca30e..24b4cf66b888 100644 --- a/tests/python/relax/adreno/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/adreno/test_transform_annotate_custom_scope.py @@ -45,23 +45,29 @@ def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-re assert isinstance( arg_sinfo, relax.TensorStructInfo ), f"Expected TensorStructInfo but git {type(arg_sinfo)}" + call_mem_scope = ( + "global" if not arg_sinfo.vdevice else arg_sinfo.vdevice.memory_scope + ) assert ( - arg_sinfo.vdevice.memory_scope - == self.scope_info[call.args[0].name_hint][0][idx] + call_mem_scope == self.scope_info[call.args[0].name_hint][0][idx] ), f"Scope mismatched for argument {idx} in {call.args[0].name_hint}" if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + call_mem_scope = ( + "global" + if not call.sinfo_args[0].vdevice + else call.sinfo_args[0].vdevice.memory_scope + ) assert ( - call.sinfo_args[0].vdevice.memory_scope - == self.scope_info[call.args[0].name_hint][1][0] + call_mem_scope == self.scope_info[call.args[0].name_hint][1][0] ), f"Scope mismatched for return scope: {call.args[0].name_hint}" else: assert isinstance( call.sinfo_args[0], relax.TupleStructInfo ), f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" for idx, sinfo in enumerate(call.sinfo_args[0].fields): + call_mem_scope = "global" if not sinfo.vdevice else sinfo.vdevice.memory_scope assert ( - sinfo.vdevice.memory_scope - == self.scope_info[call.args[0].name_hint][1][idx] + call_mem_scope == self.scope_info[call.args[0].name_hint][1][idx] ), f"Scope mismatched for return scope for {idx} in {call.args[0].name_hint}" @@ -917,7 +923,7 @@ def main( ["global"], ), "te_layout_transform4": (["global"], ["global"]), - "conv2d_opencl": (["global", "global"], ["global"]), + "conv2d": (["global", "global"], ["global"]), "te_layout_transform5": (["global"], ["global"]), "concatenate": (["global", "global"], ["global"]), } From f7dfd2effc9eba40391f122856dcefd442fd4e63 Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 21 Jul 2025 12:43:18 +0530 Subject: [PATCH 26/31] Rebase after reflection changes. --- .../backend/adreno/annotate_custom_storage.cc | 14 ++++++++------ .../backend/adreno/fold_vdevice_scope_change.cc | 7 +++++-- src/relax/op/op.cc | 5 ++--- .../specialize_primfunc_based_on_callsite.cc | 8 +++++--- src/tir/schedule/analysis/analysis.cc | 1 + 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc index 3896acca908d..bd6daf9d3228 100644 --- a/src/relax/backend/adreno/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -358,13 +358,13 @@ class CollectConsumerScopeInfo : public ExprVisitor { op_attrs = val.value(); } } - return std::move(op_attrs); + return op_attrs; } template Optional ExtractPattern(const T& func) { Optional op_pat = func->template GetAttr("op_pattern"); - return std::move(op_pat); + return op_pat; } bool SupportsTexture(const Array& op_attrs, Integer op_pattern) { @@ -705,7 +705,7 @@ class DefineVDevice : ExprMutator { Expr new_arg = Call(hint_on_device_op_, {arg}, Attrs{std::move(attrs)}, {}); - return std::move(new_arg); + return new_arg; } Optional GetTarget(const StructInfo& sinfo) { @@ -741,9 +741,11 @@ Pass AnnotateCustomMemoryScope(Target target) { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.backend.adreno.transform.AnnotateCustomMemoryScope") - .set_body_typed(AnnotateCustomMemoryScope); - +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.backend.adreno.transform.AnnotateCustomMemoryScope", + AnnotateCustomMemoryScope); +}); } // namespace transform } // namespace adreno } // namespace backend diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc index 73c1e51acb90..2dc72d305b5f 100644 --- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -179,8 +179,11 @@ Pass FoldVDeviceScopeChange() { }; return CreateFunctionPass(pass_func, 1, "FoldVDeviceScopeChange", {}); } -TVM_FFI_REGISTER_GLOBAL("relax.backend.adreno.transform.FoldVDeviceScopeChange") - .set_body_typed(FoldVDeviceScopeChange); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.backend.adreno.transform.FoldVDeviceScopeChange", + FoldVDeviceScopeChange); +}); } // namespace transform } // namespace adreno } // namespace backend diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index c1750cd545fc..c423b90f6964 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1518,10 +1518,9 @@ Expr MakeHintOnDevice(Expr data, Device device, String memory_scope = "global") TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("relax.op.hint_on_device", [](ffi::PackedArgs args, ffi::Any* ret) -> Expr , { + refl::GlobalDef().def_packed("relax.op.hint_on_device", [](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 3) { - *ret = - MakeHintOnDevice(args[0].cast(), args[1].cast(), args[2].cast()); + *ret = MakeHintOnDevice(args[0].cast(), args[1].cast(), args[2].cast()); } else { *ret = MakeHintOnDevice(args[0].cast(), args[1].cast()); } diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc index 0e4ce22b537f..a6c49c2f4f84 100644 --- a/src/relax/transform/specialize_primfunc_based_on_callsite.cc +++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc @@ -164,9 +164,11 @@ Pass SpecializePrimFuncBasedOnCallSite() { /*required=*/{}); } -TVM_FFI_REGISTER_GLOBAL("relax.transform.SpecializePrimFuncBasedOnCallSite") - .set_body_typed(SpecializePrimFuncBasedOnCallSite); - +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.SpecializePrimFuncBasedOnCallSite", + SpecializePrimFuncBasedOnCallSite); +}); } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 4fbdef1eafea..7a8c6a1a8e85 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2169,6 +2169,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { return HasIfThenElse(stmt); }); } +>>>>>>> f1d68472e (Rebase after reflection changes.) } // namespace tir } // namespace tvm From 1f173a64055090cc6d82fd3648a5b7ba562f072c Mon Sep 17 00:00:00 2001 From: Siva Date: Sat, 26 Jul 2025 09:33:58 +0530 Subject: [PATCH 27/31] Cross compiler options to work with configure --- tests/scripts/task_build_adreno_bins.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/scripts/task_build_adreno_bins.sh b/tests/scripts/task_build_adreno_bins.sh index 91886281806d..8b85a27277e0 100755 --- a/tests/scripts/task_build_adreno_bins.sh +++ b/tests/scripts/task_build_adreno_bins.sh @@ -51,8 +51,7 @@ echo set\(USE_OPENCL_GTEST ON\) >> config.cmake echo set\(USE_OPENCL_EXTN_QCOM ON\) >> config.cmake -cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake" \ - -DANDROID_ABI=arm64-v8a \ +cmake -DANDROID_ABI=arm64-v8a \ -DANDROID_PLATFORM=android-28 \ -DCMAKE_SYSTEM_VERSION=1 \ -DCMAKE_FIND_ROOT_PATH="${ADRENO_OPENCL}" \ From 49730ac87f27c7883e25f4ecf9eeeec08898db20 Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 3 Nov 2025 14:12:47 +0530 Subject: [PATCH 28/31] Rebase --- include/tvm/relax/attrs/op.h | 2 +- .../backend/adreno/annotate_custom_storage.cc | 112 +++++++++--------- .../adreno/fold_vdevice_scope_change.cc | 37 +++--- src/relax/op/op.cc | 5 +- src/relax/op/tensor/manipulate.cc | 4 +- src/relax/transform/realize_vdevice.cc | 2 +- .../specialize_primfunc_based_on_callsite.cc | 16 +-- src/runtime/contrib/clml/clml_runtime.cc | 16 +-- src/runtime/tensor.cc | 8 +- src/tir/schedule/analysis/analysis.cc | 1 - .../test_transform_annotate_custom_scope1.py | 2 + 11 files changed, 103 insertions(+), 102 deletions(-) diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index e504a36bf3c9..54640901ff53 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -111,7 +111,7 @@ struct HintOnDeviceAttrs : public AttrsNodeReflAdapter { refl::ObjectDef() .def_ro("device_type", &HintOnDeviceAttrs::device_type, "The device type where the data is supposed to be executed.") - .def_ro("index", &HintOnDeviceAttrs::index, "The device id."); + .def_ro("index", &HintOnDeviceAttrs::index, "The device id.") .def_ro("memory_scope", &HintOnDeviceAttrs::memory_scope, "The device memory scope."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.HintOnDeviceAttrs", HintOnDeviceAttrs, diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc index bd6daf9d3228..447d6babf143 100644 --- a/src/relax/backend/adreno/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -258,7 +258,7 @@ namespace adreno { using tvm::tir::Buffer; -static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { +static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { auto shape = tensor_sinfo->GetShape(); ICHECK(shape.defined()); return shape.value(); @@ -274,7 +274,7 @@ class CollectConsumerScopeInfo : public ExprVisitor { public: using ExprVisitor::VisitExpr_; - std::pair>, Map>>> Collect( + std::pair>, ffi::Map>>> Collect( const IRModule& mod, Function func, const Target& target) { mod_ = mod; target_ = target; @@ -299,17 +299,17 @@ class CollectConsumerScopeInfo : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* tuple_get_item_node) final { - if (arg_to_binding.find(GetRef(binding->var.get())) == arg_to_binding.end()) { - arg_to_binding.Set(GetRef(binding->var.get()), - GetRef(tuple_get_item_node->tuple.get())); + if (arg_to_binding.find(ffi::GetRef(binding->var.get())) == arg_to_binding.end()) { + arg_to_binding.Set(ffi::GetRef(binding->var.get()), + ffi::GetRef(tuple_get_item_node->tuple.get())); } } void VisitExpr_(const CallNode* call) final { static const Op& call_tir_op = Op::Get("relax.call_tir"); GlobalVar gv; - Array op_attrs; - Optional op_pattern = Integer(static_cast(OpPatternKind::kOpaque)); + ffi::Array op_attrs; + ffi::Optional op_pattern = Integer(static_cast(OpPatternKind::kOpaque)); Tuple func_args; if (call->op == call_tir_op) { @@ -326,35 +326,35 @@ class CollectConsumerScopeInfo : public ExprVisitor { bool is_texture_supported = SupportsTexture(op_attrs, op_pattern.value()); - Array arg_scope; + ffi::Array arg_scope; for (auto arg : func_args->fields) { auto sinfo = GetStructInfo(arg); if (auto tensor_sinfo = sinfo.as()) { auto scope = is_texture_supported ? Scope(GetShapeFromTensorStructInfo(tensor_sinfo.value())) : "global"; - Map> ent_call; + ffi::Map> ent_call; const VarNode* arg_var = arg.as(); - if (scope_info.find(GetRef(arg_var)) != scope_info.end()) { - ent_call = scope_info[GetRef(arg_var)]; + if (scope_info.find(ffi::GetRef(arg_var)) != scope_info.end()) { + ent_call = scope_info[ffi::GetRef(arg_var)]; } - ent_call.Set(GetRef(call), {scope}); - scope_info.Set(GetRef(arg_var), ent_call); + ent_call.Set(ffi::GetRef(call), {scope}); + scope_info.Set(ffi::GetRef(arg_var), ent_call); arg_scope.push_back(scope); } } - call_scope_info.Set(GetRef(call), arg_scope); + call_scope_info.Set(ffi::GetRef(call), arg_scope); } private: template - Array ExtractAttrs(const T& func) { - Array op_attrs; - Optional attrs = func->template GetAttr("op_attrs"); + ffi::Array ExtractAttrs(const T& func) { + ffi::Array op_attrs; + ffi::Optional attrs = func->template GetAttr("op_attrs"); if (attrs) { if (auto val = attrs.value().as()) { op_attrs.push_back(val.value()); - } else if (auto val = attrs.value().as>()) { + } else if (auto val = attrs.value().as>()) { op_attrs = val.value(); } } @@ -362,12 +362,12 @@ class CollectConsumerScopeInfo : public ExprVisitor { } template - Optional ExtractPattern(const T& func) { - Optional op_pat = func->template GetAttr("op_pattern"); + ffi::Optional ExtractPattern(const T& func) { + ffi::Optional op_pat = func->template GetAttr("op_pattern"); return op_pat; } - bool SupportsTexture(const Array& op_attrs, Integer op_pattern) { + bool SupportsTexture(const ffi::Array& op_attrs, Integer op_pattern) { if (op_pattern.IntValue() < OpPatternKind::kCommReduce) return true; for (auto attr : op_attrs) { @@ -391,7 +391,7 @@ class CollectConsumerScopeInfo : public ExprVisitor { return false; } - std::string Scope(Array shape) { + std::string Scope(ffi::Array shape) { // currently we support only textures been made from 5d tensors // 5d requirement is not limitation of textures in general, it is limitation how // we are representing memory scopes/layout and flattening of textures in tir @@ -428,10 +428,10 @@ class CollectConsumerScopeInfo : public ExprVisitor { } /* Map of each Var consumption by a call node and its scope */ - Map>> scope_info; + ffi::Map>> scope_info; /* A map of call node and scope info for each argument it consunes */ - Map> call_scope_info; - Map arg_to_binding; + ffi::Map> call_scope_info; + ffi::Map arg_to_binding; IRModule mod_; Target target_; }; @@ -446,8 +446,8 @@ class CollectProducerScopeInfo : public ExprVisitor { public: using ExprVisitor::VisitExpr_; - Map Collect(const IRModule& mod, Function func, - const Map>>& scope_info, + ffi::Map Collect(const IRModule& mod, Function func, + const ffi::Map>>& scope_info, const Target& target, const BlockBuilder& builder) { mod_ = mod; scope_info_ = scope_info; @@ -471,18 +471,18 @@ class CollectProducerScopeInfo : public ExprVisitor { Op::GetAttrMap("FInferStructInfo"); auto* op_ptr = call->op.as(); - Op op = GetRef(op_ptr); + Op op = ffi::GetRef(op_ptr); ICHECK(op_map_infer_struct_info_.count(op)) << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; - out_sinfo = op_map_infer_struct_info_[op](GetRef(call), builder_); + out_sinfo = op_map_infer_struct_info_[op](ffi::GetRef(call), builder_); } - std::unordered_map scope_count; + std::unordered_map scope_count; // Decide the final scope based on the max consumer demand. Rest will use to_device. auto arg_var = binding->var.as(); - if (scope_info_.find(GetRef(arg_var)) != scope_info_.end()) { - for (const auto& val : scope_info_[GetRef(arg_var)]) { + if (scope_info_.find(ffi::GetRef(arg_var)) != scope_info_.end()) { + for (const auto& val : scope_info_[ffi::GetRef(arg_var)]) { auto call_node = Downcast(val.first); if (scope_count.find(val.second[0]) == scope_count.end()) { scope_count.insert({val.second[0], 1}); @@ -492,7 +492,7 @@ class CollectProducerScopeInfo : public ExprVisitor { } } } - String final_scope = "global"; + ffi::String final_scope = "global"; int count = 0; for (const auto& sval : scope_count) { if (sval.second > count) { @@ -502,11 +502,11 @@ class CollectProducerScopeInfo : public ExprVisitor { } // Applying same scope for outputs StructInfo updated_ret_sinfo = UpdateStructInfo(out_sinfo, {final_scope}); - producer_sinfo.Set(GetRef(call), updated_ret_sinfo); + producer_sinfo.Set(ffi::GetRef(call), updated_ret_sinfo); } private: - StructInfo UpdateStructInfo(const StructInfo& out_sinfo, Array scope) { + StructInfo UpdateStructInfo(const StructInfo& out_sinfo, ffi::Array scope) { if (out_sinfo->IsInstance()) { auto tensor_sinfo = Downcast(out_sinfo); auto shape_arr = GetShapeFromTensorStructInfo(tensor_sinfo); @@ -520,7 +520,7 @@ class CollectProducerScopeInfo : public ExprVisitor { << out_sinfo; const auto& tuple_sinfo = Downcast(out_sinfo); - Array sinfo_fields; + ffi::Array sinfo_fields; for (const auto& si : tuple_sinfo->fields) { ICHECK(si->IsInstance()) << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " @@ -534,8 +534,8 @@ class CollectProducerScopeInfo : public ExprVisitor { return TupleStructInfo(sinfo_fields); } - Map>> scope_info_; - Map producer_sinfo; + ffi::Map>> scope_info_; + ffi::Map producer_sinfo; IRModule mod_; Target target_; BlockBuilder builder_; @@ -572,7 +572,7 @@ class DefineVDevice : ExprMutator { } mod_.CopyOnWrite()->Update(updates_); - Array global_vdevices_; + ffi::Array global_vdevices_; for (auto vdev : vdevices_) { global_vdevices_.push_back(vdev.as().value()); } @@ -606,8 +606,8 @@ class DefineVDevice : ExprMutator { // return call; } - Array new_args; - StructInfo updated_ret_sinfo = producer_sinfo_[GetRef(call_node)]; + ffi::Array new_args; + StructInfo updated_ret_sinfo = producer_sinfo_[ffi::GetRef(call_node)]; if (updated_ret_sinfo->IsInstance()) { auto tensor_sinfo = Downcast(updated_ret_sinfo); @@ -625,7 +625,7 @@ class DefineVDevice : ExprMutator { << updated_ret_sinfo; const auto& tuple_sinfo = Downcast(updated_ret_sinfo); - Array sinfo_fields; + ffi::Array sinfo_fields; for (const auto& si : tuple_sinfo->fields) { ICHECK(si->IsInstance()) << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " @@ -652,9 +652,9 @@ class DefineVDevice : ExprMutator { for (auto arg : func_args->fields) { auto sinfo = GetStructInfo(arg); if (auto tensor_sinfo = sinfo.as()) { - String scope = "global"; - if (call_scope_info_.find(GetRef(call_node)) != call_scope_info_.end()) { - scope = call_scope_info_[GetRef(call_node)][arg_idx]; + ffi::String scope = "global"; + if (call_scope_info_.find(ffi::GetRef(call_node)) != call_scope_info_.end()) { + scope = call_scope_info_[ffi::GetRef(call_node)][arg_idx]; } new_args.push_back(HintArg(arg, scope)); arg_idx++; @@ -685,7 +685,7 @@ class DefineVDevice : ExprMutator { return (vdevices_.back()); } - Expr HintArg(const Expr& arg, String scope) { + Expr HintArg(const Expr& arg, ffi::String scope) { if (arg->IsInstance()) { if (auto tsinfo = arg->struct_info_.as()) { if (!tsinfo->vdevice.defined()) { @@ -697,10 +697,10 @@ class DefineVDevice : ExprMutator { } } } - ObjectPtr attrs = make_object(); + ObjectPtr attrs = ffi::make_object(); const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); - attrs->dev_type = vdev->target->GetTargetDeviceType(); - attrs->dev_id = vdev->vdevice_id; + attrs->device_type = vdev->target->GetTargetDeviceType(); + attrs->index = vdev->vdevice_id; attrs->memory_scope = vdev->memory_scope; Expr new_arg = Call(hint_on_device_op_, {arg}, Attrs{std::move(attrs)}, {}); @@ -708,7 +708,7 @@ class DefineVDevice : ExprMutator { return new_arg; } - Optional GetTarget(const StructInfo& sinfo) { + ffi::Optional GetTarget(const StructInfo& sinfo) { auto tinfo = sinfo.as(); if (tinfo->vdevice.defined()) { auto vdevice = tinfo->vdevice.value(); @@ -723,10 +723,10 @@ class DefineVDevice : ExprMutator { IRModule mod_; IRModule updates_; Target target_; - Array vdevices_; - Map>> scope_info_; - Map producer_sinfo_; - Map> call_scope_info_; + ffi::Array vdevices_; + ffi::Map>> scope_info_; + ffi::Map producer_sinfo_; + ffi::Map> call_scope_info_; }; namespace transform { @@ -741,11 +741,11 @@ Pass AnnotateCustomMemoryScope(Target target) { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.backend.adreno.transform.AnnotateCustomMemoryScope", AnnotateCustomMemoryScope); -}); +} } // namespace transform } // namespace adreno } // namespace backend diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc index 2dc72d305b5f..5d95336fca15 100644 --- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -44,15 +44,15 @@ namespace backend { namespace adreno { namespace { -std::tuple)>> CreatePatterns( - Map> consumers) { +std::tuple)>> CreatePatterns( + ffi::Map> consumers) { auto pat_gv = WildcardPattern(); auto pat_inp = WildcardPattern(); auto pat_call_tir = IsOp("relax.call_tir")(pat_gv, pat_inp); auto pattern_out = IsOp("relax.to_vdevice")(pat_call_tir); - auto rewriter = [=](Expr expr, Map matches) -> Expr { + auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { const auto* call_tir = matches[pat_call_tir].as(); ICHECK(call_tir) << "InternalError: " << "Match of relax.call_tir operator should produce Call, " @@ -76,8 +76,8 @@ std::tuple)>> Crea if (!tir_out_sinfo->vdevice.defined()) return expr; const VarNode* arg_var = out->args[0].as(); - if (consumers.find(GetRef(arg_var)) != consumers.end()) { - if (consumers[GetRef(arg_var)].size() > 1) { + if (consumers.find(ffi::GetRef(arg_var)) != consumers.end()) { + if (consumers[ffi::GetRef(arg_var)].size() > 1) { /* Don't do to_device optimization as we are not the only consumer */ return expr; } @@ -104,7 +104,7 @@ class CollectConsumerDetails : public ExprVisitor { public: using ExprVisitor::VisitExpr_; - Map> Collect(const IRModule& mod, Function func, const Target& target) { + ffi::Map> Collect(const IRModule& mod, Function func, const Target& target) { mod_ = mod; target_ = target; VisitExpr(func->body); @@ -127,9 +127,9 @@ class CollectConsumerDetails : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* tuple_get_item_node) final { - if (arg_to_binding.find(GetRef(binding->var.get())) == arg_to_binding.end()) { - arg_to_binding.Set(GetRef(binding->var.get()), - GetRef(tuple_get_item_node->tuple.get())); + if (arg_to_binding.find(ffi::GetRef(binding->var.get())) == arg_to_binding.end()) { + arg_to_binding.Set(ffi::GetRef(binding->var.get()), + ffi::GetRef(tuple_get_item_node->tuple.get())); } } @@ -146,23 +146,23 @@ class CollectConsumerDetails : public ExprVisitor { for (auto arg : func_args->fields) { auto sinfo = GetStructInfo(arg); if (auto tensor_sinfo = sinfo.as()) { - Array call_list; + ffi::Array call_list; const VarNode* arg_var = arg.as(); - if (consumers.find(GetRef(arg_var)) != consumers.end()) { - call_list = consumers[GetRef(arg_var)]; + if (consumers.find(ffi::GetRef(arg_var)) != consumers.end()) { + call_list = consumers[ffi::GetRef(arg_var)]; } - call_list.push_back(GetRef(call)); - consumers.Set(GetRef(arg_var), call_list); + call_list.push_back(ffi::GetRef(call)); + consumers.Set(ffi::GetRef(arg_var), call_list); } } } private: /* Map of each Var consumption by a call node */ - Map> consumers; - Map arg_to_binding; + ffi::Map> consumers; + ffi::Map arg_to_binding; IRModule mod_; Target target_; }; @@ -179,11 +179,12 @@ Pass FoldVDeviceScopeChange() { }; return CreateFunctionPass(pass_func, 1, "FoldVDeviceScopeChange", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ + +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.backend.adreno.transform.FoldVDeviceScopeChange", FoldVDeviceScopeChange); -}); +} } // namespace transform } // namespace adreno } // namespace backend diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index c423b90f6964..ee6ee1619f9f 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1506,12 +1506,11 @@ TVM_REGISTER_OP("relax.hint_on_device") .set_attr("FInferStructInfo", InferHintOnDeviceStructInfo) .set_attr("FPurity", Bool(true)); -Expr MakeHintOnDevice(Expr data, Device device, String memory_scope = "global") { +Expr MakeHintOnDevice(Expr data, Device device, ffi::String memory_scope = "global") { static const Op& op = Op::Get("relax.hint_on_device"); ObjectPtr attrs = ffi::make_object(); attrs->device_type = static_cast(device.device_type); attrs->index = device.device_id; - attrs->memory_scope = vdevice->memory_scope; attrs->memory_scope = memory_scope; return Call(op, {data}, Attrs(attrs), {}); } @@ -1520,7 +1519,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("relax.op.hint_on_device", [](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 3) { - *ret = MakeHintOnDevice(args[0].cast(), args[1].cast(), args[2].cast()); + *ret = MakeHintOnDevice(args[0].cast(), args[1].cast(), args[2].cast()); } else { *ret = MakeHintOnDevice(args[0].cast(), args[1].cast()); } diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 7c2c220dcf44..1d612c35e44e 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -361,8 +361,8 @@ InferLayoutOutput InferLayoutConcat( << call->op << " expects the input to be a Tuple of Tensors. However, the given input is " << call->args[0]->struct_info_; - auto t_sinfo = GetRef(field_tensor_sinfo); - Optional t_shape = GetRef(t_sinfo->shape.as()); + auto t_sinfo = ffi::GetRef(field_tensor_sinfo); + ffi::Optional t_shape = ffi::GetRef(t_sinfo->shape.as()); LayoutDecision curr_layout = nlayout_array[i].LeafValue(); if (!CanProveLayoutTransform(curr_layout->layout, in_layout->layout, t_shape.value()->values)) { diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index b9456052264a..7f1042d57ecc 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -56,7 +56,7 @@ class VDeviceLookup { ICHECK(attrs); int32_t device_type = attrs->device_type; int32_t device_id = attrs->index; - String memory_scope = attrs->memory_scope; + ffi::String memory_scope = attrs->memory_scope; CHECK(opt_vdevices_.defined()) << "ValueError: The target VDevice in the GlobalInfos was not found."; diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc index a6c49c2f4f84..6258e14b666d 100644 --- a/src/relax/transform/specialize_primfunc_based_on_callsite.cc +++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc @@ -40,7 +40,7 @@ namespace relax { using tvm::tir::Buffer; -static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { +static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { auto shape = tensor_sinfo->GetShape(); ICHECK(shape.defined()); return shape.value(); @@ -81,7 +81,7 @@ class SpecializeTIRCallArgs : ExprMutator { auto gv = Downcast(call->args[0]); auto pfunc = Downcast(mod_->Lookup(gv)); auto args = Downcast(call->args[1])->fields; - Map> param_map; + ffi::Map> param_map; for (size_t i = 0; i < args.size(); ++i) { auto sinfo = GetStructInfo(args[i]); @@ -89,11 +89,11 @@ class SpecializeTIRCallArgs : ExprMutator { << "Expected Tensor struct Info for call :" << call->op; auto tensor_sinfo = Downcast(sinfo); CHECK(tensor_sinfo->shape.defined()) << "Shape undefined for call:" << call->args[0]; - String scope = "global"; + ffi::String scope = "global"; if (tensor_sinfo->vdevice.defined()) { scope = tensor_sinfo->vdevice.value()->memory_scope; } - String name; + ffi::String name; if (args[i]->IsInstance()) { name = Downcast(args[i])->name_hint(); } else { @@ -104,7 +104,7 @@ class SpecializeTIRCallArgs : ExprMutator { tensor_sinfo->dtype, name, scope); param_map.Set(pfunc->params[i], buffer); } - String scope = "global"; + ffi::String scope = "global"; auto out_sinfo = call->sinfo_args[0]; if (out_sinfo->IsInstance()) { auto sinfo = Downcast(out_sinfo); @@ -121,7 +121,7 @@ class SpecializeTIRCallArgs : ExprMutator { << out_sinfo; const auto& tuple_sinfo = Downcast(out_sinfo); - Array sinfo_fields; + ffi::Array sinfo_fields; int index = 0; for (const auto& si : tuple_sinfo->fields) { ICHECK(si->IsInstance()) @@ -164,11 +164,11 @@ Pass SpecializePrimFuncBasedOnCallSite() { /*required=*/{}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.transform.SpecializePrimFuncBasedOnCallSite", SpecializePrimFuncBasedOnCallSite); -}); +} } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index c4475618107e..d1cf6b2808b0 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -315,7 +315,7 @@ class CLMLRuntime : public JSONRuntimeBase { const auto f = tvm::ffi::Function::GetGlobal("runtime.SaveParams"); if (f.has_value()) { - std::string dump_bytes = (*f)(dump_tensors).cast(); + std::string dump_bytes = (*f)(dump_tensors).cast(); std::ostringstream oss; /*TODO(Siva) HEX encoding doubles the size, look for better encode that can cross the RPC. */ for (size_t i = 0; i < dump_bytes.size(); ++i) { @@ -349,7 +349,7 @@ class CLMLRuntime : public JSONRuntimeBase { evts.resize(evts.size() + 1); evt = &(evts.back()); } - std::unordered_map metrics; + std::unordered_map metrics; std::string shape_str; std::vector shape = nodes_[nid].GetOpShape()[0]; DLDataType tvm_dtype = nodes_[nid].GetOpDataType()[0]; @@ -366,7 +366,7 @@ class CLMLRuntime : public JSONRuntimeBase { } for (size_t i = 0; i < this->layer_.function.size(); ++i) { - std::unordered_map metrics; + std::unordered_map metrics; auto node = this->layer_.op_node_map[this->layer_.function[i]].second; std::string shape_str; for (uint32_t j = 0; j < node.GetInputs().size(); ++j) { @@ -407,7 +407,7 @@ class CLMLRuntime : public JSONRuntimeBase { evt = &(evts.back()); } - std::unordered_map metrics; + std::unordered_map metrics; std::string shape_str; std::vector shape = nodes_[eid].GetOpShape()[0]; DLDataType tvm_dtype = nodes_[eid].GetOpDataType()[0]; @@ -466,8 +466,8 @@ class CLMLRuntime : public JSONRuntimeBase { cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); int dtype_size = cl_dtype == CL_FLOAT ? 4 : 2; void* tmpptr = reinterpret_cast(malloc(isize * dtype_size)); - NDArray::CopyToBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), - isize * dtype_size); + Tensor::CopyToBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), + isize * dtype_size); CopyDataToCLMLTensor(layer_.inputs[nid], tmpptr); free(tmpptr); } @@ -553,8 +553,8 @@ class CLMLRuntime : public JSONRuntimeBase { void* tmpptr = reinterpret_cast(malloc(osize * dtype_size)); CopyDataFromCLMLTensor(layer_.outputs[0], tmpptr); - NDArray::CopyFromBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), - osize * dtype_size); + Tensor::CopyFromBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), + osize * dtype_size); free(tmpptr); } } diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc index 3895db979894..4ef744452c3c 100644 --- a/src/runtime/tensor.cc +++ b/src/runtime/tensor.cc @@ -97,8 +97,8 @@ void Tensor::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); } -void NDArray::CopyFromBytes(const DLTensor* handle, void* data, size_t nbytes, - TVMStreamHandle stream) { +void Tensor::CopyFromBytes(const DLTensor* handle, void* data, size_t nbytes, + TVMStreamHandle stream) { size_t arr_size = GetDataSize(*handle); ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; @@ -117,8 +117,8 @@ void NDArray::CopyFromBytes(const DLTensor* handle, void* data, size_t nbytes, DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); } -NDArray NDArray::Empty(ffi::Shape shape, DLDataType dtype, Device dev, - ffi::Optional mem_scope) { +Tensor Tensor::Empty(ffi::Shape shape, DLDataType dtype, Device dev, + ffi::Optional mem_scope) { struct DeviceAPIAlloc { void AllocData(DLTensor* tensor, ffi::Optional mem_scope) { tensor->data = DeviceAPI::Get(tensor->device) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 7a8c6a1a8e85..4fbdef1eafea 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2169,7 +2169,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { return HasIfThenElse(stmt); }); } ->>>>>>> f1d68472e (Rebase after reflection changes.) } // namespace tir } // namespace tvm diff --git a/tests/python/relax/adreno/test_transform_annotate_custom_scope1.py b/tests/python/relax/adreno/test_transform_annotate_custom_scope1.py index 8c629eecc8fc..7cf4ebd717f6 100644 --- a/tests/python/relax/adreno/test_transform_annotate_custom_scope1.py +++ b/tests/python/relax/adreno/test_transform_annotate_custom_scope1.py @@ -102,6 +102,7 @@ def verify(mod, expected): print(mod) ValidateScope(expected).visit(mod) + def test_conv2d_conv2d_fallback_to_buffer_conv2d(): """ layout_transform (NCHW->NCHW4c) @@ -158,5 +159,6 @@ def main( } verify(Input, Expected) + if __name__ == "__main__": tvm.testing.main() From b57491b067e7d0b1112d561b356d0ccc070b8ddf Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 3 Nov 2025 14:59:49 +0530 Subject: [PATCH 29/31] Lint --- include/tvm/relax/transform.h | 3 ++- .../backend/adreno/annotate_custom_storage.cc | 12 +++++---- .../adreno/fold_vdevice_scope_change.cc | 3 ++- src/relax/op/op.cc | 3 ++- src/relax/op/tensor/manipulate.cc | 3 ++- src/relax/transform/legalize_ops.cc | 9 ++++--- src/tir/schedule/analysis/analysis.cc | 26 +++++++++---------- 7 files changed, 33 insertions(+), 26 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index b50aa280a476..646784afd52f 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -250,7 +250,8 @@ TVM_DLL Pass FoldConstant(); * \return The Pass. */ TVM_DLL Pass LegalizeOps(ffi::Optional> cmap, - ffi::Optional> skip_ops, bool enable_warning = false); + ffi::Optional> skip_ops, + bool enable_warning = false); /*! * \brief Propagate virtual device information. diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc index 447d6babf143..887b81872940 100644 --- a/src/relax/backend/adreno/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -274,8 +274,9 @@ class CollectConsumerScopeInfo : public ExprVisitor { public: using ExprVisitor::VisitExpr_; - std::pair>, ffi::Map>>> Collect( - const IRModule& mod, Function func, const Target& target) { + std::pair>, + ffi::Map>>> + Collect(const IRModule& mod, Function func, const Target& target) { mod_ = mod; target_ = target; VisitExpr(func->body); @@ -446,9 +447,10 @@ class CollectProducerScopeInfo : public ExprVisitor { public: using ExprVisitor::VisitExpr_; - ffi::Map Collect(const IRModule& mod, Function func, - const ffi::Map>>& scope_info, - const Target& target, const BlockBuilder& builder) { + ffi::Map Collect( + const IRModule& mod, Function func, + const ffi::Map>>& scope_info, + const Target& target, const BlockBuilder& builder) { mod_ = mod; scope_info_ = scope_info; target_ = target; diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc index 5d95336fca15..c59beae78e96 100644 --- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -104,7 +104,8 @@ class CollectConsumerDetails : public ExprVisitor { public: using ExprVisitor::VisitExpr_; - ffi::Map> Collect(const IRModule& mod, Function func, const Target& target) { + ffi::Map> Collect(const IRModule& mod, Function func, + const Target& target) { mod_ = mod; target_ = target; VisitExpr(func->body); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index ee6ee1619f9f..54f9da4c786f 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1519,7 +1519,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("relax.op.hint_on_device", [](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 3) { - *ret = MakeHintOnDevice(args[0].cast(), args[1].cast(), args[2].cast()); + *ret = MakeHintOnDevice(args[0].cast(), args[1].cast(), + args[2].cast()); } else { *ret = MakeHintOnDevice(args[0].cast(), args[1].cast()); } diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 1d612c35e44e..fb3455f2dd58 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -362,7 +362,8 @@ InferLayoutOutput InferLayoutConcat( << " expects the input to be a Tuple of Tensors. However, the given input is " << call->args[0]->struct_info_; auto t_sinfo = ffi::GetRef(field_tensor_sinfo); - ffi::Optional t_shape = ffi::GetRef(t_sinfo->shape.as()); + ffi::Optional t_shape = + ffi::GetRef(t_sinfo->shape.as()); LayoutDecision curr_layout = nlayout_array[i].LeafValue(); if (!CanProveLayoutTransform(curr_layout->layout, in_layout->layout, t_shape.value()->values)) { diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index a49d448e9369..75e0776418ed 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -62,8 +62,10 @@ bool KnowAllShapeValues(const StructInfo& sinfo) { class LegalizeMutator : public ExprMutator { public: - explicit LegalizeMutator(const IRModule& mod, const ffi::Optional>& cmap, - const ffi::Optional> skip_ops, bool enable_warning) + explicit LegalizeMutator(const IRModule& mod, + const ffi::Optional>& cmap, + const ffi::Optional> skip_ops, + bool enable_warning) : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) { if (cmap) { cmap_ = cmap.value(); @@ -406,8 +408,7 @@ class LegalizeMutator : public ExprMutator { namespace transform { Pass LegalizeOps(ffi::Optional> cmap, - ffi::Optional> skip_ops, - bool enable_warning) { + ffi::Optional> skip_ops, bool enable_warning) { auto pass_func = [=](IRModule mod, PassContext pc) { bool apply_legalize_ops = pc->GetConfig("relax.transform.apply_legalize_ops").value_or(Bool(true))->value; diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 4fbdef1eafea..75cbd5f3e4c1 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2155,19 +2155,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { auto block_sref = sch->GetSRef(block); return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); }) - .def("tir.schedule.GetLoopIterType", [](Schedule sch, LoopRV loop) -> ffi::String { - IterVarType kind = GetLoopIterType(sch->GetSRef(loop)); - if (kind == kDataPar) { - return "S"; - } else if (kind == kCommReduce) { - return "R"; - } else { - return "O"; - } - }) - .def("tir.schedule.HasIfThenElse", [](const Stmt& stmt) -> bool { - return HasIfThenElse(stmt); - }); + .def("tir.schedule.GetLoopIterType", + [](Schedule sch, LoopRV loop) -> ffi::String { + IterVarType kind = GetLoopIterType(sch->GetSRef(loop)); + if (kind == kDataPar) { + return "S"; + } else if (kind == kCommReduce) { + return "R"; + } else { + return "O"; + } + }) + .def("tir.schedule.HasIfThenElse", + [](const Stmt& stmt) -> bool { return HasIfThenElse(stmt); }); } } // namespace tir From ddab6abfc66b2a27430b169053de14bb9cc30da6 Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 3 Nov 2025 23:11:13 +0530 Subject: [PATCH 30/31] ffi fix --- python/tvm/relax/backend/adreno/transform/_ffi_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/backend/adreno/transform/_ffi_api.py b/python/tvm/relax/backend/adreno/transform/_ffi_api.py index 7a19e3380feb..d665ba02a70e 100644 --- a/python/tvm/relax/backend/adreno/transform/_ffi_api.py +++ b/python/tvm/relax/backend/adreno/transform/_ffi_api.py @@ -16,4 +16,4 @@ """FFI APIs for Adreno transform""" import tvm.ffi -tvm.ffi._init_api("relax.backend.adreno.transform", __name__) +tvm.ffi.init_ffi_api("relax.backend.adreno.transform", __name__) From ce3104c3d2060a60fe54f740bac85059b9aed08c Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 4 Nov 2025 08:35:48 +0530 Subject: [PATCH 31/31] remove unused test --- .../test_transform_annotate_custom_scope1.py | 164 ------------------ 1 file changed, 164 deletions(-) delete mode 100644 tests/python/relax/adreno/test_transform_annotate_custom_scope1.py diff --git a/tests/python/relax/adreno/test_transform_annotate_custom_scope1.py b/tests/python/relax/adreno/test_transform_annotate_custom_scope1.py deleted file mode 100644 index 7cf4ebd717f6..000000000000 --- a/tests/python/relax/adreno/test_transform_annotate_custom_scope1.py +++ /dev/null @@ -1,164 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm -from tvm import relax -import tvm.testing -from tvm.script.parser import ir as I, relax as R, tir as T -from tvm.relax.transform.legalize_ops import adreno as legalize_adreno -from tvm.ir.module import IRModule -from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor - - -@visitor -class ValidateScope(PyExprVisitor): # pylint: disable=abstract-method - def __init__(self, scope_info: dict) -> None: - self.scope_info = scope_info - self.matched = True - - def visit(self, mod: IRModule) -> None: - """Entry point""" - for _, func in mod.functions_items(): - if isinstance(func, relax.Function): - self.visit_expr(func) - return self.matched - - def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed - if call.op.name == "relax.call_tir": - # if call.args[0].name_hint in self.scope_info: - for idx, arg in enumerate(call.args[1]): - arg_sinfo = arg.struct_info - assert isinstance( - arg_sinfo, relax.TensorStructInfo - ), f"Expected TensorStructInfo but git {type(arg_sinfo)}" - assert ( - arg_sinfo.vdevice.memory_scope - == self.scope_info[call.args[0].name_hint][0][idx] - ), f"Scope mismatched for argument {idx} in {call.args[0].name_hint}" - if isinstance(call.sinfo_args[0], relax.TensorStructInfo): - assert ( - call.sinfo_args[0].vdevice.memory_scope - == self.scope_info[call.args[0].name_hint][1][0] - ), f"Scope mismatched for return scope: {call.args[0].name_hint}" - else: - assert isinstance( - call.sinfo_args[0], relax.TupleStructInfo - ), f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" - for idx, sinfo in enumerate(call.sinfo_args[0].fields): - assert ( - sinfo.vdevice.memory_scope - == self.scope_info[call.args[0].name_hint][1][idx] - ), f"Scope mismatched for return scope for {idx} in {call.args[0].name_hint}" - - -def verify(mod, expected): - tgt = tvm.target.Target("opencl --device=adreno", host="llvm") - skip_ops = [ - "relax.nn.conv2d", - "relax.nn.max_pool2d", - "relax.nn.adaptive_avg_pool2d", - # "relax.nn.layer_norm", - ] - with tgt: - mod = tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) - mod = tvm.relax.transform.DecomposeOpsForInference()(mod) - mod = tvm.relax.transform.FoldConstant()(mod) - desired_layouts = {"relax.nn.conv2d": ["NCHW4c", "OIHW4o", "NCHW4c"]} - mod = tvm.relax.transform.ConvertLayout(desired_layouts)(mod) - mod = tvm.relax.transform.Normalize()(mod) - mod = tvm.relax.transform.FoldConstant()(mod) - mod = tvm.relax.transform.LegalizeOps(skip_ops=skip_ops)(mod) - mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) - mod = tvm.relax.backend.adreno.transform.AnnotateCustomMemoryScope(tgt)(mod) - # There is a possibility of some skipped ops above might not use 5D layouts. - mod = tvm.relax.transform.LegalizeOps()(mod) - mod = tvm.relax.transform.LegalizeOps( - {"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo}, - )(mod) - # Lets get pattern info for newly legalized ops - mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) - mod = tvm.relax.transform.FoldConstant()(mod) - mod = tvm.relax.transform.FuseOps()(mod) - mod = tvm.relax.transform.FuseTIR()(mod) - mod = tvm.relax.transform.DeadCodeElimination()(mod) - mod = tvm.relax.backend.adreno.transform.FoldVDeviceScopeChange()(mod) - mod = tvm.relax.transform.DeadCodeElimination()(mod) - mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) - mod = tvm.relax.transform.Normalize()(mod) - print(mod) - ValidateScope(expected).visit(mod) - - -def test_conv2d_conv2d_fallback_to_buffer_conv2d(): - """ - layout_transform (NCHW->NCHW4c) - | <- texture - conv2d (1) <- textures as output - / \ - conv2d (2) conv2d (3) <- conv2d (2) emits texture, conv2d (3) emits buffer - \ / <- concat shouldn't support textures here - concatenation - | <- buffer - layout_transform (NCHW4c->NCHW) - """ - - @I.ir_module - class Input: - @R.function - def main( - x: R.Tensor((2, 32, 40, 40), "float32"), - w1: R.Tensor((96, 32, 2, 2), "float32"), - w2: R.Tensor((32, 96, 2, 2), "float32"), - w3: R.Tensor((5, 96, 2, 2), "float32"), - bias1: R.Tensor((1, 96, 1, 1), "float32"), - bias2: R.Tensor((1, 32, 1, 1), "float32"), - ) -> R.Tensor(None, "float32", ndim=4): - with R.dataflow(): - gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") - gv1 = R.add(gv, bias1) - gv2 = R.nn.relu(gv1) - gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") - gv4 = R.add(gv3, bias2) - gv5 = R.nn.relu(gv4) - gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") - gv7 = R.concat((gv3, gv6), axis=1) - R.output(gv7) - return gv7 - - Expected = { - "te_layout_transform": (["global"], ["global.texture-weight"]), - "te_layout_transform1": (["global"], ["global.texture-weight"]), - "te_layout_transform2": (["global"], ["global.texture-weight"]), - "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( - ["global.texture-weight", "global.texture-weight", "global.texture-weight"], - ["global"], - ), - "te_layout_transform3": (["global"], ["global.texture-weight"]), - "conv2d_NCHWc_OIHWo1_opencl": ( - ["global.texture-weight", "global.texture-weight"], - ["global"], - ), - "te_layout_transform4": (["global"], ["global"]), - "conv2d_opencl": (["global", "global"], ["global"]), - "te_layout_transform5": (["global"], ["global"]), - "concatenate": (["global", "global"], ["global"]), - } - verify(Input, Expected) - - -if __name__ == "__main__": - tvm.testing.main()