diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 776d1d6e168..302b9af83e2 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -7,7 +7,7 @@
 # pyre-strict
 
 import logging
 
-from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple
+from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple
 
 import executorch.backends.vulkan.utils as utils
@@ -17,6 +17,7 @@
     get_op_features,
     has_impl,
     OpFeatures,
+    OpKey,
     vulkan_supported_ops,
 )
 
@@ -55,11 +56,17 @@ def __init__(
         texture_limits: utils.ImageExtents,
         buffer_limit: int,
         require_dynamic_shape: bool = False,
+        operator_blocklist: Optional[Set[OpKey]] = None,
+        operator_allowlist: Optional[Set[OpKey]] = None,
     ) -> None:
         super().__init__()
         self.texture_limits: utils.ImageExtents = texture_limits
         self.buffer_limit = buffer_limit
         self.require_dynamic_shapes = require_dynamic_shape
+        self.operator_blocklist: Set[OpKey] = (
+            operator_blocklist if operator_blocklist is not None else set()
+        )
+        self.operator_allowlist = operator_allowlist
 
     def op_node_is_compatible(  # noqa: C901: Function is too complex
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -77,6 +84,17 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
         assert isinstance(first_arg, torch._ops.OpOverload)
         target = first_arg.name()
 
+        # Operator allow list is only used for torch ops
+        if (
+            utils.is_torch_op_node(node)
+            and (self.operator_allowlist is not None)
+            and (target not in self.operator_allowlist)
+        ):
+            return False, "op is not in allowlist"
+
+        if target in self.operator_blocklist:
+            return False, "op is in blocklist"
+
         # Extract the features for the node's operator, if no override was provided
         if features is None:
             if not has_impl(target):
@@ -93,7 +111,7 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
         if op_repsets.any_is_empty():
             return (
                 False,
-                "No valid representations for a tensor in the operation",
+                f"no valid representations for op {utils.node_io_str(node)}",
             )
 
         return True, "Op is compatible"
@@ -277,6 +295,8 @@ class VulkanPartitioner(Partitioner):
     def __init__(
         self,
         compile_options: Optional[Dict[str, Any]] = None,
+        operator_blocklist: Optional[List[OpKey]] = None,
+        operator_allowlist: Optional[List[OpKey]] = None,
     ) -> None:
         self.options: Dict[str, Any] = {}
         if compile_options is not None:
@@ -285,6 +305,18 @@ def __init__(
         compile_spec = parse_compile_options(self.options)
         self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)
 
+        self.operator_blocklist: Set[OpKey] = set()
+        if operator_blocklist is not None:
+            for entry in operator_blocklist or []:
+                self.operator_blocklist.add(entry)
+
+        self.operator_allowlist: Optional[Set[OpKey]] = None
+        if operator_allowlist is not None:
+            self.operator_allowlist = set()
+            for entry in operator_allowlist:
+                assert self.operator_allowlist is not None
+                self.operator_allowlist.add(entry)
+
     def ops_to_not_decompose(
         self, ep: ExportedProgram
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
@@ -308,6 +340,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 texture_limits,
                 buffer_limit,
                 require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
+                operator_blocklist=self.operator_blocklist,
+                operator_allowlist=self.operator_allowlist,
             ),
             allows_single_node_partition=True,
        )
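The hunks above add two optional filters to the Vulkan partitioner. Below is a minimal usage sketch, not part of this patch; the model and the specific op keys are illustrative, and valid OpKey entries are whatever keys the Vulkan op registry (vulkan_supported_ops) accepts for the operators in question.

    import torch

    from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
        VulkanPartitioner,
    )
    from executorch.exir import to_edge_transform_and_lower
    from executorch.exir.dialects._ops import ops as exir_ops


    class AddMM(torch.nn.Module):
        def forward(self, x, y, w):
            return (x + y) @ w


    ep = torch.export.export(
        AddMM(), (torch.randn(4, 8), torch.randn(4, 8), torch.randn(8, 2))
    )

    # Never delegate these ops to Vulkan, even if the backend supports them.
    partitioner = VulkanPartitioner(
        operator_blocklist=[exir_ops.edge.aten.mm.default],
    )

    # Or the inverse: delegate only the listed ops and keep everything else
    # on the portable CPU path.
    # partitioner = VulkanPartitioner(
    #     operator_allowlist=[exir_ops.edge.aten.add.Tensor],
    # )

    edge = to_edge_transform_and_lower(ep, partitioner=[partitioner])

Both arguments accept plain Python lists; the constructor copies them into sets, so the per-node membership checks during partitioning stay cheap.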
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index fa45063a4d3..1765f0b5e1c 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -18,6 +18,8 @@
     format_target_name,
 )
 
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+
 from executorch.exir.tensor import TensorSpec
 
 from torch._export.utils import is_buffer, is_param
@@ -54,6 +56,18 @@
 MaybeNodeList = Union[torch.fx.Node, List[torch.fx.Node], Tuple[torch.fx.Node]]
 
 
+def is_torch_op_node(node: torch.fx.Node) -> bool:
+    if node.op != "call_function":
+        return False
+
+    if isinstance(node.target, EdgeOpOverload):
+        return True
+    if isinstance(node.target, torch._ops.OpOverload):
+        return True
+
+    return False
+
+
 def is_dequant_node(node: torch.fx.Node) -> bool:
     if node.op != "call_function":
         return False
@@ -1033,6 +1047,49 @@ def get_node_repr(node) -> Union[TensorRepr, TensorReprList]:
 ##
 
 
+def get_tensor_val_str(tensor_val: FakeTensor) -> str:
+    return f"{tensor_val.dtype}: {tensor_val.shape}"
+
+
+def get_node_val_str(node: torch.fx.Node) -> str:
+    if is_single_tensor_node(node):
+        assert isinstance(node.meta["val"], FakeTensor)
+        return get_tensor_val_str(node.meta["val"])
+    elif is_tensor_collection_node(node):
+        assert isinstance(node.meta["val"], (list, tuple))
+        return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
+    else:
+        return str(node.meta["val"])
+
+
+def get_arg_node_val_str(arg_node: Any) -> str:
+    if isinstance(arg_node, torch.fx.Node):
+        return get_node_val_str(arg_node)
+    elif isinstance(arg_node, (list, tuple)):
+        return f"[{', '.join(get_arg_node_val_str(n) for n in arg_node)}]"
+    else:
+        return str(arg_node)
+
+
+def node_io_str(node: torch.fx.Node) -> str:
+    target = node.target
+    if isinstance(target, EdgeOpOverload):
+        assert isinstance(target, EdgeOpOverload)
+        target_name = target.__name__
+    elif isinstance(target, torch._ops.OpOverload):
+        assert isinstance(target, torch._ops.OpOverload)
+        target_name = target.name()
+    else:
+        target_name = str(target)
+
+    out_str = f"{get_node_val_str(node)} = {target_name}("
+    for arg in node.args:
+        out_str += get_arg_node_val_str(arg) + ", "
+
+    out_str += " ...)"
+    return out_str
+
+
 def update_program_state_dict(
     program: ExportedProgram,
     buffer_name: str,
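For reference, the shape of node_io_str's output follows directly from the helpers above: each tensor renders as "dtype: shape" via get_tensor_val_str, each argument is followed by ", ", and the list is closed with a literal " ...)". For an edge-dialect add node with two float32 [1, 16] inputs the string looks roughly like this (illustrative, one line; EdgeOpOverload targets print via __name__, while torch OpOverload targets would print as aten::add.Tensor via .name()):

    torch.float32: torch.Size([1, 16]) = aten.add.Tensor(torch.float32: torch.Size([1, 16]), torch.float32: torch.Size([1, 16]),  ...)

This is the string the partitioner now embeds in its "no valid representations for op ..." rejection message, making it easier to see which tensor dtypes or shapes caused a node to be skipped.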