Qualcomm AI Engine Direct - Delegated mutable buffer #6727

Closed
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -64,6 +64,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.prelu.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten._softmax.default, # TODO: Need to find a new solution to do "axis_order" to transform axis.
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
100644 → 100755
@@ -14,6 +14,7 @@
op_ceil,
op_clamp,
op_conv2d,
op_copy,
op_depth_to_space,
op_dequantize,
op_div,
@@ -71,6 +72,7 @@
op_ceil,
op_clamp,
op_conv2d,
op_copy,
op_depth_to_space,
op_dequantize,
op_div,
43 changes: 36 additions & 7 deletions backends/qualcomm/builders/node_visitor.py
@@ -36,6 +36,8 @@
get_parameter,
is_graph_input,
is_graph_output,
is_mutable_buffer_input,
is_mutable_buffer_output,
is_parameter,
)

@@ -214,7 +216,9 @@ def get_tensor_type(
node: torch.fx.Node,
tensor_type: PyQnnWrapper.Qnn_TensorType_t,
) -> PyQnnWrapper.Qnn_TensorType_t:
is_input = is_graph_input(node, self.edge_program)
is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input(
node, self.edge_program
)
is_output = is_graph_output(node)
# handle logic for input/output tensors
if is_input or is_output:
@@ -247,6 +251,33 @@ def get_data_type(

return QNN_TENSOR_TYPE_MAP[tensor.dtype]

def get_tensor_name(
self,
node: torch.fx.Node,
wrapper_idx: int = 0,
):
tensor_name = f"{node.name}_{wrapper_idx}"
# The `input_{id}` prefix is used for sorting inputs at runtime. Because qnn_preprocess runs
# multiple passes, the input order of the QNN graph may differ from the original graph's forward() signature.
# The `mutbuf_{id}` tag maps the input of a mutable buffer to its corresponding output at runtime.
# The `output_` prefix marks graph outputs at runtime, so they are not confused with per_tensor_dump tensors.
if is_mutable_buffer_input(node, self.edge_program):
fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target]
position_index = list(
self.edge_program.graph_signature.buffers_to_mutate.values()
).index(fqn)
tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}"
elif is_graph_input(node, self.edge_program):
tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}"
elif is_mutable_buffer_output(node, self.edge_program):
position_index = list(
self.edge_program.graph_signature.buffers_to_mutate.keys()
).index(node.name)
tensor_name = f"output_mutbuf_{position_index}_{tensor_name}"
elif is_graph_output(node):
tensor_name = f"output_{tensor_name}"
return tensor_name
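
For illustration, here is how the naming scheme might play out on a small graph with one mutable buffer (a sketch; the node and buffer names below are hypothetical, not taken from this PR):

# placeholder "x"          (user input, external id 0)
#     -> "input_0_x_0"
# placeholder "b_cache"    (mutable buffer input, external id 1; its fqn "cache"
#                           sits at position 0 in buffers_to_mutate)
#     -> "input_1_mutbuf_0_b_cache_0"
# node "aten_copy_default" (mutated value of "cache", written back as a graph output)
#     -> "output_mutbuf_0_aten_copy_default_0"
# node "aten_mul_tensor"   (user output)
#     -> "output_aten_mul_tensor_0"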

def define_custom_tensor_wrapper(
self,
node_name: str,
Expand Down Expand Up @@ -307,11 +338,7 @@ def define_tensor(
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
return cached

tensor_name = f"{node.name}_{wrapper_idx}"
if is_graph_input(node, self.edge_program):
tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
if is_graph_output(node):
tensor_name = "output_" + tensor_name
tensor_name = self.get_tensor_name(node, wrapper_idx)
dims = [1] if len(tensor.size()) == 0 else tensor.size()
tensor_type = self.get_tensor_type(node, tensor_type)
quant_encoding, quant_configs = self.get_quant_encoding_conf(
@@ -383,7 +410,9 @@ def generate_node_to_external_map(
# The order in which we visit the placeholder nodes is the same as the *args
# order for the forward(*args) signature of this gm. We use the order of
# the nodes as external_id to extract the right arg from *args at runtime.
if is_graph_input(node, edge_program):
if is_graph_input(node, edge_program) or is_mutable_buffer_input(
node, edge_program
):
node_to_external_map[node] = len(node_to_external_map)
for node in edge_program.graph_module.graph.nodes:
if is_graph_output(node):
63 changes: 63 additions & 0 deletions backends/qualcomm/builders/op_copy.py
@@ -0,0 +1,63 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpReshape, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Copy(NodeVisitor):
target = ["aten.copy.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = node.args[1]
input_tensor = self.get_tensor(input_node, node)
copy_inp_tensor_wrapper = self.define_tensor(
input_node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)

copy_input_tensors = [copy_inp_tensor_wrapper]

if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS):
quant_attrs = quant_attrs.copy()
# The copy node has no QCOM_QUANT_ATTRS of its own after convert_pt2e, so inherit them from the input node
node.meta[QCOM_QUANT_ATTRS] = quant_attrs
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)
copy_output_tensors = [output_tensor_wrapper]

copy_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpReshape.op_name,
)
copy_op.AddInputTensors(copy_input_tensors)
copy_op.AddOutputTensors(copy_output_tensors)

return copy_op
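
For context, a module like the following should exercise this visitor (a sketch under assumptions, not taken from this PR): torch.export functionalizes the in-place copy_ into an aten.copy.default node on the mutable buffer, and the visitor lowers that node as a QNN Reshape, which with an unchanged shape acts as an identity copy.

import torch

class Cache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # a mutable buffer that the QNN delegate can own
        self.register_buffer("cache", torch.zeros(1, 4))

    def forward(self, x):
        # in-place mutation; functionalized to aten.copy.default on export
        self.cache.copy_(self.cache + x)
        return self.cache * 2

ep = torch.export.export(Cache(), (torch.ones(1, 4),))
print(ep.graph_signature.buffers_to_mutate)  # output node name -> "cache"
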
37 changes: 37 additions & 0 deletions backends/qualcomm/builders/utils.py
@@ -75,6 +75,23 @@ def is_graph_input(
return tensor.op == "placeholder" and not is_parameter(tensor, edge_program)


def is_mutable_buffer_input(
tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
"""
Check if the given tensor is a mutable buffer input

Args:
tensor: EdgeIR Tensor that is being checked for mutable buffer input
"""
if tensor.op == "placeholder" and is_buffer(edge_program, tensor):
fqn = edge_program.graph_signature.inputs_to_buffers[tensor.target]
# only report the placeholder as a mutable buffer input if the buffer is actually mutated
if fqn in edge_program.graph_signature.buffers_to_mutate.values():
return True
return False


def is_graph_output(tensor: torch.fx.Node) -> bool:
"""
Check if the given tensor is used as a graph output
@@ -91,6 +108,26 @@ def is_graph_output(tensor: torch.fx.Node) -> bool:
return False


def is_mutable_buffer_output(
tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
"""
Check if the given tensor is a mutable buffer output

Args:
tensor: EdgeIR Tensor that is being checked for mutable buffer output
"""
for user in tensor.users.keys():
# getitem nodes are skipped; see op_skip_ops.py
if user.op == "output" or (
user.target.__name__ == "getitem" and is_graph_output(user)
):
# only report the node as a mutable buffer output if the buffer is actually mutated
if tensor.name in edge_program.graph_signature.buffers_to_mutate.keys():
return True
return False
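
Both helpers key off the export graph signature. An illustrative example (names assumed, matching the sketch in op_copy.py above):

# inputs_to_buffers = {"b_cache": "cache"}            # placeholder name -> buffer fqn
# buffers_to_mutate = {"aten_copy_default": "cache"}  # output node name -> buffer fqn
#
# is_mutable_buffer_input(<placeholder "b_cache">)     -> True, because its fqn
#   "cache" appears in buffers_to_mutate.values()
# is_mutable_buffer_output(<node "aten_copy_default">) -> True, because the node
#   feeds the graph output and its name is a key of buffers_to_mutate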


def is_constant(
tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
2 changes: 1 addition & 1 deletion backends/qualcomm/partition/common_defs.py
100644 → 100755
@@ -13,7 +13,7 @@
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.slice_scatter.default,
exir_ops.edge.aten.copy.default,
exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
]

to_be_implemented_operator = [
6 changes: 5 additions & 1 deletion backends/qualcomm/partition/qnn_partitioner.py
@@ -23,7 +23,7 @@
Partitioner,
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import OperatorSupportBase

@@ -108,6 +108,7 @@ def __init__(
compiler_specs: List[CompileSpec],
skip_node_id_set: set = None,
skip_node_op_set: set = None,
skip_mutable_buffer: bool = True,
):
self.compiler_specs_snapshot = copy.deepcopy(compiler_specs)

@@ -117,6 +118,7 @@
self.partition_tags: Dict[str, DelegationSpec] = {}
self.skip_node_id_set = set() if skip_node_id_set is None else skip_node_id_set
self.skip_node_op_set = set() if skip_node_op_set is None else skip_node_op_set
self.skip_mutable_buffer = skip_mutable_buffer

def generate_partitions(
self, edge_program: torch.export.ExportedProgram
@@ -162,6 +164,8 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult:
if len(partitions) != 0:
self.tag_nodes(partitions, edge_program)
tag_constant_data(edge_program)
if not self.skip_mutable_buffer:
tag_mutated_buffer(edge_program)
for node in edge_program.graph_module.graph.nodes:
if hasattr(node, "meta"):
# pop certain keys in meta for not affecting the passes in compilation
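
With this flag, delegating mutable buffers is opt-in. A usage sketch (assuming the class here is QnnPartitioner and compiler_specs is built elsewhere):

partitioner = QnnPartitioner(
    compiler_specs,
    skip_mutable_buffer=False,  # tag mutated buffers so QNN delegates them
)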
2 changes: 1 addition & 1 deletion backends/qualcomm/quantizer/custom_annotation.py
100644 → 100755
@@ -192,7 +192,7 @@ def get_custom_quant_ios_dtype(

# Tag index put node before copy node, because copy is a skipped node in qnn
if (
exir_ops.edge.aten.index_put.default == node.target
exir_ops.edge.aten.copy.default == node.target
and node.meta["val"].shape == cache_shape
):
return kv_dtype
31 changes: 18 additions & 13 deletions backends/qualcomm/runtime/QnnExecuTorchBackend.cpp
@@ -111,29 +111,34 @@ Error QnnExecuTorchBackend::execute(
std::vector<Qnn_Tensor_t> input_tensor_structs;
std::vector<Qnn_Tensor_t> output_tensor_structs;

int args_index = 0;
input_tensor_structs.reserve(input_tensors.size());
for (int i = 0; i < input_tensors.size(); ++i) {
if (qnn_manager->RegisterMem(
args[i]->toTensor().mutable_data_ptr(), input_tensors[i]) !=
Error::Ok) {
// update data ptr only should be fine
input_tensors[i]->FillDataBuffer(
args[i]->toTensor().const_data_ptr(), false /* copy_data */);
for (const auto& input_tensor : input_tensors) {
if (input_tensor->GetName().find("mutbuf_") == std::string::npos) {
if (qnn_manager->RegisterMem(
args[args_index]->toTensor().mutable_data_ptr(), input_tensor) !=
Error::Ok) {
// updating only the data ptr should be fine
input_tensor->FillDataBuffer(
args[args_index]->toTensor().const_data_ptr(),
false /* copy_data */);
}
args_index++;
}
input_tensor_structs.push_back(input_tensors[i]->CloneTensorStruct());

input_tensor_structs.push_back(input_tensor->CloneTensorStruct());
}

int output_index = input_tensors.size();
for (const auto& output_tensor : output_tensors) {
// pos=0 limits the search to the prefix
if (output_tensor->GetName().rfind("output_", 0) == 0) {
void* mutable_data_ptr =
args[output_index]->toTensor().mutable_data_ptr();
if (output_tensor->GetName().rfind("output_", 0) == 0 &&
output_tensor->GetName().find("mutbuf_") == std::string::npos) {
void* mutable_data_ptr = args[args_index]->toTensor().mutable_data_ptr();
if (qnn_manager->RegisterMem(mutable_data_ptr, output_tensor) !=
Error::Ok) {
output_tensor->FillDataBuffer(mutable_data_ptr, false /* copy_data */);
}
output_index++;
args_index++;
}
output_tensor_structs.push_back(output_tensor->CloneTensorStruct());
}
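
For illustration, the resulting argument binding (tensor names follow the scheme in node_visitor.py; concrete names assumed):

// inputs : "input_0_x_0"                 <- args[0]  (user input)
//          "input_1_mutbuf_0_b_cache_0"  <- skipped  (delegate-owned memory)
// outputs: "output_aten_mul_tensor_0"    <- args[1]  (user output)
//          "output_mutbuf_0_copy_0"      <- skipped  (shares the input's buffer)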
36 changes: 36 additions & 0 deletions backends/qualcomm/runtime/QnnManager.cpp
@@ -19,6 +19,7 @@
#include <cstring>
#include <fstream>
#include <string>
#include <unordered_map>

namespace executorch {
namespace backends {
@@ -36,6 +37,16 @@ bool CompareExportedInput(
return numA < numB;
}

int ExtractMutableBufferNumber(const std::string& name) {
std::string prefix = "mutbuf_";
size_t startPos = name.find(prefix);
if (startPos != std::string::npos) {
startPos += prefix.length();
return std::stoi(name.substr(startPos));
}
return -1;
}
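
For example (names illustrative; std::stoi stops at the first non-digit character):

// ExtractMutableBufferNumber("input_1_mutbuf_0_b_cache_0") == 0
// ExtractMutableBufferNumber("output_mutbuf_0_copy_0")     == 0
// ExtractMutableBufferNumber("input_0_x_0")                == -1  (no "mutbuf_" tag)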

QnnManager::~QnnManager() {
backend_params_ptr_.reset(new BackendConfigParameters());
logger_.reset();
@@ -331,9 +342,22 @@ Error QnnManager::AllocateTensor(const std::string& graph_name) {
std::vector<Qnn_Tensor_t> output_tensors =
backend_params_ptr_->qnn_context_ptr_->GetGraphOutputs(graph_name);

// Map memory addresses between the inputs and outputs of mutable buffers
std::unordered_map<int, const void*> mutable_buffer_id_to_memory_map;

for (auto& tensor : input_tensors) {
std::shared_ptr<TensorWrapper> tensor_wrapper = CreateTensorWrapper(tensor);
tensor_wrapper->UpdateQnnTensorMeta(tensor);

int mutable_buffer_id =
ExtractMutableBufferNumber(tensor_wrapper->GetName());
if (mutable_buffer_id != -1) {
// The delegate maintains the memory for the mutable buffer
tensor_wrapper->AllocateDataBuffer();
mutable_buffer_id_to_memory_map[mutable_buffer_id] =
tensor_wrapper->GetStaticTensorData();
}

input_tensors_[graph_name].emplace_back(std::move(tensor_wrapper));
}
if (!options_->is_from_context_binary()) {
@@ -356,6 +380,18 @@
if (IsTensorDump()) {
tensor_wrapper->AllocateDataBuffer();
}

int mutable_buffer_id =
ExtractMutableBufferNumber(tensor_wrapper->GetName());
if (mutable_buffer_id != -1 &&
mutable_buffer_id_to_memory_map.find(mutable_buffer_id) !=
mutable_buffer_id_to_memory_map.end()) {
// Point the output at the same memory as the matching mutable-buffer input
tensor_wrapper->FillDataBuffer(
mutable_buffer_id_to_memory_map[mutable_buffer_id],
false /* copy_data */);
}

output_tensors_[graph_name].emplace_back(std::move(tensor_wrapper));
}
return Error::Ok;
5 changes: 5 additions & 0 deletions extension/llm/export/partitioner_lib.py
@@ -191,4 +191,9 @@ def get_qnn_partitioner(
),
skip_node_id_set={},
skip_node_op_set=skip_node_op_set,
# TODO: For the fp flow, the mutable buffer behaves unexpectedly:
# the delegated mutable buffer is not removed from the graph output.
# Tracing back, the mutable buffer does not exist in original_program.state_dict,
# so it never gets added to output_specs_to_delete.
skip_mutable_buffer=use_fp16,
)