38 changes: 36 additions & 2 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -7,7 +7,7 @@
# pyre-strict

import logging
from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple
from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple

import executorch.backends.vulkan.utils as utils

@@ -17,6 +17,7 @@
get_op_features,
has_impl,
OpFeatures,
OpKey,
vulkan_supported_ops,
)

@@ -55,11 +56,17 @@ def __init__(
texture_limits: utils.ImageExtents,
buffer_limit: int,
require_dynamic_shape: bool = False,
operator_blocklist: Optional[Set[OpKey]] = None,
operator_allowlist: Optional[Set[OpKey]] = None,
) -> None:
super().__init__()
self.texture_limits: utils.ImageExtents = texture_limits
self.buffer_limit = buffer_limit
self.require_dynamic_shapes = require_dynamic_shape
self.operator_blocklist: Set[OpKey] = (
operator_blocklist if operator_blocklist is not None else set()
)
self.operator_allowlist = operator_allowlist

def op_node_is_compatible( # noqa: C901: Function is too complex
self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -77,6 +84,17 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
assert isinstance(first_arg, torch._ops.OpOverload)
target = first_arg.name()

# Operator allow list is only used for torch ops
if (
utils.is_torch_op_node(node)
and (self.operator_allowlist is not None)
and (target not in self.operator_allowlist)
):
return False, "op is not in allowlist"

if target in self.operator_blocklist:
return False, "op is in blocklist"

# Extract the features for the node's operator, if no override was provided
if features is None:
if not has_impl(target):
@@ -93,7 +111,7 @@ def op_node_is_compatible( # noqa: C901: Function is too complex
if op_repsets.any_is_empty():
return (
False,
"No valid representations for a tensor in the operation",
f"no valid representations for op {utils.node_io_str(node)}",
)

return True, "Op is compatible"
@@ -277,6 +295,8 @@ class VulkanPartitioner(Partitioner):
def __init__(
self,
compile_options: Optional[Dict[str, Any]] = None,
operator_blocklist: Optional[List[OpKey]] = None,
operator_allowlist: Optional[List[OpKey]] = None,
) -> None:
self.options: Dict[str, Any] = {}
if compile_options is not None:
@@ -285,6 +305,18 @@ def __init__(
compile_spec = parse_compile_options(self.options)
self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)

self.operator_blocklist: Set[OpKey] = set()
if operator_blocklist is not None:
for entry in operator_blocklist or []:
self.operator_blocklist.add(entry)

self.operator_allowlist: Optional[Set[OpKey]] = None
if operator_allowlist is not None:
self.operator_allowlist = set()
for entry in operator_allowlist:
assert self.operator_allowlist is not None
self.operator_allowlist.add(entry)

def ops_to_not_decompose(
self, ep: ExportedProgram
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
@@ -308,6 +340,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
texture_limits,
buffer_limit,
require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
operator_blocklist=self.operator_blocklist,
operator_allowlist=self.operator_allowlist,
),
allows_single_node_partition=True,
)
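For context, a minimal usage sketch of the options added above. This is not part of the diff: the import path and the string op keys (in the form returned by OpOverload.name(), e.g. "aten::mm") are assumptions; the exact OpKey values accepted are defined by the op registry.

# Hypothetical sketch, not from this PR; op-key strings are illustrative.
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner

# Keep aten.mm out of the Vulkan partition: matching nodes are rejected by
# op_node_is_compatible with the reason "op is in blocklist".
blocklist_partitioner = VulkanPartitioner(operator_blocklist=["aten::mm"])

# Or delegate only an explicit set of torch ops: any other torch op is
# rejected with "op is not in allowlist".
allowlist_partitioner = VulkanPartitioner(
    operator_allowlist=["aten::mm", "aten::add.Tensor"]
)

Either instance would then be handed to the usual lowering entry point (for example to_edge_transform_and_lower) exactly as before.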
57 changes: 57 additions & 0 deletions backends/vulkan/utils.py
@@ -18,6 +18,8 @@
format_target_name,
)

from executorch.exir.dialects.edge._ops import EdgeOpOverload

from executorch.exir.tensor import TensorSpec

from torch._export.utils import is_buffer, is_param
@@ -54,6 +56,18 @@
MaybeNodeList = Union[torch.fx.Node, List[torch.fx.Node], Tuple[torch.fx.Node]]


def is_torch_op_node(node: torch.fx.Node) -> bool:
if node.op != "call_function":
return False

if isinstance(node.target, EdgeOpOverload):
return True
if isinstance(node.target, torch._ops.OpOverload):
return True

return False


def is_dequant_node(node: torch.fx.Node) -> bool:
if node.op != "call_function":
return False
@@ -1033,6 +1047,49 @@ def get_node_repr(node) -> Union[TensorRepr, TensorReprList]:
##


def get_tensor_val_str(tensor_val: FakeTensor) -> str:
return f"{tensor_val.dtype}: {tensor_val.shape}"


def get_node_val_str(node: torch.fx.Node) -> str:
if is_single_tensor_node(node):
assert isinstance(node.meta["val"], FakeTensor)
return get_tensor_val_str(node.meta["val"])
elif is_tensor_collection_node(node):
assert isinstance(node.meta["val"], (list, tuple))
return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
else:
return str(node.meta["val"])


def get_arg_node_val_str(arg_node: Any) -> str:
if isinstance(arg_node, torch.fx.Node):
return get_node_val_str(arg_node)
elif isinstance(arg_node, (list, tuple)):
return f"[{', '.join(get_arg_node_val_str(n) for n in arg_node)}]"
else:
return str(arg_node)


def node_io_str(node: torch.fx.Node) -> str:
target = node.target
if isinstance(target, EdgeOpOverload):
assert isinstance(target, EdgeOpOverload)
target_name = target.__name__
elif isinstance(target, torch._ops.OpOverload):
assert isinstance(target, torch._ops.OpOverload)
target_name = target.name()
else:
target_name = str(target)

out_str = f"{get_node_val_str(node)} = {target_name}("
for arg in node.args:
out_str += get_arg_node_val_str(arg) + ", "

out_str += " ...)"
return out_str


def update_program_state_dict(
program: ExportedProgram,
buffer_name: str,
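A hedged sketch of how the new debug helpers might be exercised outside the partitioner, for example to inspect why a node was rejected. This is not part of the diff: it assumes the executorch.backends.vulkan.utils import path used by the partitioner and a plain torch.export graph whose call_function targets are torch._ops.OpOverload instances.

# Illustrative only: print the I/O signature string that the partitioner now
# embeds in its "no valid representations" rejection message.
import torch

import executorch.backends.vulkan.utils as utils


class AddMM(torch.nn.Module):
    def forward(self, x, w, b):
        return torch.addmm(b, x, w)


ep = torch.export.export(
    AddMM(), (torch.randn(4, 8), torch.randn(8, 16), torch.randn(16))
)
for node in ep.graph.nodes:
    if utils.is_torch_op_node(node):
        # Prints something like:
        # torch.float32: torch.Size([4, 16]) = aten::addmm(torch.float32: torch.Size([16]), ...,  ...)
        print(utils.node_io_str(node))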