From 06e44f5ef535f7ac4aa7bb831dbe3e8c26a6ec50 Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Tue, 3 Oct 2023 17:22:04 -0700 Subject: [PATCH] Add initial lowering of aten.convolution to tosa.conv2d support --- backends/arm/arm_backend.py | 96 ++++++++++++--- backends/arm/test/test_models.py | 205 ++++++++++++++++++++++++++----- backends/arm/tosa_quant_utils.py | 19 +++ examples/arm/arm_tosa_e2e.py | 18 ++- 4 files changed, 288 insertions(+), 50 deletions(-) diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index f0f285418c6..748e60e2138 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -553,23 +553,39 @@ def preprocess( # noqa: C901 elif exir_ops.edge.aten.convolution.default == node.target: input, weight, bias, stride, pad, dilation, _, _, group = inputs + # Currently only int8 is supported in quantized types. + actual_out_type = ts.DType.INT8 if is_quant_node else outp.dtype + ## Transpose input tensor to NHWC_Order for TOSA NHWC_Order = [0, 2, 3, 1] input_transposed = transpose_helper( - tosa_fb, input, NHWC_Order, outp.dtype + tosa_fb, input, NHWC_Order, actual_out_type ) - ## CONV2DOp + # Get the attributes of convolution. attr = ts.TosaSerializerAttribute() - # PAD pad_attr = [val for val in pad.special for _ in (0, 1)] - # Stride stride_attr = stride.special - # Dilation dilation_attr = dilation.special attr.ConvAttribute(pad_attr, stride_attr, dilation_attr, 0, 0) + # Non-bias case. + if len(node.all_input_nodes) == 2: + # Create a zero bias tensor if not presented + out_channels = weight.shape[0] + bias_name = "bias" + node.name.split("default", 1)[1] + bias = tosa_fb.addConst( + [out_channels], + ts.DType.INT32 if is_quant_node else outp.dtype, + [0] * out_channels, + name=bias_name, + ) + if group.number > 1: + assert ( + is_quant_node is False + ), "quantized depthwise convolution is not supported yet in BI mode" + # Transpose weight to [KH, KW, C, M] weight_HWCM_Order = [2, 3, 0, 1] weight_transposed = transpose_helper( @@ -600,14 +616,17 @@ def preprocess( # noqa: C901 # Transpose weight to [OC, H, W, IC] weight_CHWC_Order = [0, 2, 3, 1] weight_transposed = transpose_helper( - tosa_fb, weight, weight_CHWC_Order, outp.dtype + tosa_fb, weight, weight_CHWC_Order, actual_out_type ) ## TOSA output shape is [NHWO] NHWO_Order = [0, 2, 3, 1] out_shape_TOSA_CONV2D = [outp.shape[i] for i in NHWO_Order] + + # The output type is int32 when input type is int8. conv2d_res = tosa_fb.addIntermediate( - out_shape_TOSA_CONV2D, outp.dtype + out_shape_TOSA_CONV2D, + ts.DType.INT32 if is_quant_node else outp.dtype, ) tosa_fb.addOperator( TosaOp.Op().CONV2D, @@ -624,6 +643,24 @@ def preprocess( # noqa: C901 NOHW_Order = [0, 3, 1, 2] attr_output_transpose = ts.TosaSerializerAttribute() attr_output_transpose.TransposeAttribute(NOHW_Order) + + # For quantized convolution, rescale the output value back to the same + # integer value domain of the next op. Otherwise return float32 output. + if is_quant_node: + # Get scale_factor from input, weight, and output. 
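+ # The int32 accumulator produced by CONV2D carries an effective scale of
+ # input_scale * weight_scale; dividing by output_scale folds the requantization
+ # to the int8 output domain into the single rescale multiplier built below.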
+ _, input_scale, _, _, _, _ = getNodeArgs(node.args[0]) + _, weight_scale, _, _, _, _ = getNodeArgs(node.args[1]) + _, output_scale, _, _, _, _ = getNodeArgs(list(node.users)[0]) + + conv2d_res = tosa_quant_utils.buildRescaleOpConvOutput( + tosa_fb, + conv2d_res, + actual_out_type, + input_scale, + weight_scale, + output_scale, + ) + tosa_fb.addOperator( TosaOp.Op().TRANSPOSE, [conv2d_res.name], @@ -879,7 +916,7 @@ def preprocess( # noqa: C901 p_data = edge_program.state_dict[parameter_name] assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor" - weight_values = p_data.detach().numpy() + parameter_values = p_data.detach().numpy() # Check if they're for quantized nodes consumer_node = list(node.users)[0] @@ -888,14 +925,14 @@ def preprocess( # noqa: C901 consumer_node ) - weight_values_quantized = ( - (weight_values / weight_node_scale.number) + parameter_values_quantized = ( + (parameter_values / weight_node_scale.number) + weight_node_zp.number ).astype(np.int8) tosa_fb.addConst( inputs[0].shape, ts.DType.INT8, - weight_values_quantized, + parameter_values_quantized, name=out, ) elif ( @@ -914,20 +951,45 @@ def preprocess( # noqa: C901 weight_node ) - weight_values_quantized = ( - weight_values / (input_node_scale * weight_node_scale) + parameter_values_quantized = ( + parameter_values / (input_node_scale * weight_node_scale) + ).astype(np.int32) + + tosa_fb.addConst( + inputs[0].shape, + ts.DType.INT32, + parameter_values_quantized, + name=out, + ) + elif ( + consumer_node.target == exir_ops.edge.aten.convolution.default + and list(consumer_node.users)[0].target == tosa_quant_utils.q_op + ): + ( + input_node, + weight_node, + bias_node, + ) = consumer_node.all_input_nodes + + input_node_scale, _ = getQuantNodeArgs(input_node) + weight_node_scale, _ = getQuantNodeArgs(weight_node) + + bias_scales = input_node_scale * weight_node_scale + parameter_values_quantized = ( + parameter_values / bias_scales ).astype(np.int32) tosa_fb.addConst( inputs[0].shape, ts.DType.INT32, - weight_values_quantized, + parameter_values_quantized, name=out, ) else: tosa_fb.addConst( - inputs[0].shape, inputs[0].dtype, weight_values, name=out + inputs[0].shape, inputs[0].dtype, parameter_values, name=out ) + elif out in edge_program.graph_signature.inputs_to_buffers: parameter_name = edge_program.graph_signature.inputs_to_buffers[ node.name @@ -935,9 +997,9 @@ def preprocess( # noqa: C901 p_data = edge_program.state_dict[parameter_name] assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor" - weight_values = p_data.detach().numpy() + buffer_values = p_data.detach().numpy() tosa_fb.addConst( - inputs[0].shape, inputs[0].dtype, weight_values, name=out + inputs[0].shape, inputs[0].dtype, buffer_values, name=out ) else: tensor = ts.TosaSerializerTensor( diff --git a/backends/arm/test/test_models.py b/backends/arm/test/test_models.py index 46a57a601b8..7a510660283 100644 --- a/backends/arm/test/test_models.py +++ b/backends/arm/test/test_models.py @@ -9,10 +9,16 @@ from enum import Enum +import numpy as np + import torch TestList = {} +# Seed the RNG a convenient number so that we get the same random tests for each test each time +seed = 42 +rng = np.random.default_rng(seed) + def register_test(cls): TestList[cls.__name__] = cls() @@ -124,25 +130,49 @@ class simple_linear(torch.nn.Module): def __init__(self): super().__init__() - torch.manual_seed(42) + torch.manual_seed(seed) self.fc = torch.nn.Linear(20, 30) def forward(self, x): x = self.fc(x) return x - # @register_test - class 
simple_conv2d(torch.nn.Module): + """Currenly we compare the quantized result directly with the floating point result, to avoid a noticable + precision difference due to wide random numerical distribution, generate small random value range for + convolution testing instead for now""" + + @register_test + class simple_conv2d_2x2_3x1x40x40_non_bias(torch.nn.Module): + data = torch.from_numpy( + np.float32(rng.integers(low=10, high=20, size=(3, 1, 40, 40))) + ) inputs = { - TosaProfile.BI: ( - torch.ones( - 1, - 3, - 256, - 256, - ), - ), - TosaProfile.MI: (torch.ones(1, 3, 256, 256),), + TosaProfile.BI: (data,), + TosaProfile.MI: (data,), + } + + def __init__(self): + super().__init__() + self.conv2d = torch.nn.Conv2d( + in_channels=1, out_channels=3, kernel_size=2, stride=1, bias=False + ) + with torch.no_grad(): + self.conv2d.weight.copy_( + torch.from_numpy( + np.float32(rng.integers(low=1, high=10, size=(1, 1, 2, 2))) + ) + ) + + def forward(self, x): + x = self.conv2d(x) + return x + + @register_test + class simple_conv2d_3x3_1x3x256x256_st1(torch.nn.Module): + data = torch.ones(1, 3, 256, 256) + inputs = { + TosaProfile.BI: (data,), + TosaProfile.MI: (data,), } def __init__(self): @@ -150,16 +180,140 @@ def __init__(self): self.conv2d = torch.nn.Conv2d( in_channels=3, out_channels=10, kernel_size=3, stride=1 ) + with torch.no_grad(): + self.conv2d.weight.copy_( + torch.from_numpy( + np.float32(rng.integers(low=1, high=4, size=(10, 3, 3, 3))) + ) + ) + self.conv2d.bias.copy_( + torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(10)))) + ) def forward(self, x): x = self.conv2d(x) return x - # @register_test + @register_test + class simple_conv2d_1x1_1x2x128x128_st1(torch.nn.Module): + data = torch.from_numpy( + np.float32(rng.integers(low=10, high=20, size=(1, 2, 128, 128))) + ) + inputs = { + TosaProfile.BI: (data,), + TosaProfile.MI: (data,), + } + + def __init__(self): + super().__init__() + self.conv2d = torch.nn.Conv2d( + in_channels=2, out_channels=1, kernel_size=1, stride=1 + ) + with torch.no_grad(): + self.conv2d.weight.copy_( + torch.from_numpy( + np.float32(rng.integers(low=1, high=4, size=(1, 2, 1, 1))) + ) + ) + self.conv2d.bias.copy_( + torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(1)))) + ) + + def forward(self, x): + x = self.conv2d(x) + return x + + @register_test + class simple_conv2d_2x2_1x1x14x14_st2(torch.nn.Module): + data = torch.from_numpy( + np.float32(rng.integers(low=10, high=20, size=(1, 1, 14, 14))) + ) + inputs = { + TosaProfile.BI: (data,), + TosaProfile.MI: (data,), + } + + def __init__(self): + super().__init__() + self.conv2d = torch.nn.Conv2d( + in_channels=1, out_channels=1, kernel_size=2, stride=2 + ) + with torch.no_grad(): + self.conv2d.weight.copy_( + torch.from_numpy( + np.float32(rng.integers(low=1, high=4, size=(1, 1, 2, 2))) + ) + ) + self.conv2d.bias.copy_( + torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(1)))) + ) + + def forward(self, x): + x = self.conv2d(x) + return x + + @register_test + class simple_conv2d_5x5_3x2x128x128_st1(torch.nn.Module): + data = torch.from_numpy( + np.float32(rng.integers(low=10, high=20, size=(3, 2, 128, 128))) + ) + inputs = { + TosaProfile.BI: (data,), + TosaProfile.MI: (data,), + } + + def __init__(self): + super().__init__() + self.conv2d = torch.nn.Conv2d( + in_channels=2, out_channels=3, kernel_size=5, stride=1 + ) + with torch.no_grad(): + self.conv2d.weight.copy_( + torch.from_numpy( + np.float32(rng.integers(low=1, high=10, size=(1, 1, 5, 5))) + ) + ) + 
self.conv2d.bias.copy_(torch.ones(3, dtype=torch.float)) + + def forward(self, x): + x = self.conv2d(x) + return x + + @register_test + class block_two_conv2d_non_bias(torch.nn.Module): + data = torch.from_numpy( + np.float32(rng.integers(low=10, high=20, size=(1, 3, 256, 256))) + ) + inputs = { + TosaProfile.BI: (data,), + TosaProfile.MI: (data,), + } + + def __init__(self): + super().__init__() + self.conv2d = torch.nn.Conv2d( + in_channels=3, out_channels=10, kernel_size=5, stride=1, bias=False + ) + self.conv2d_2 = torch.nn.Conv2d( + in_channels=10, out_channels=15, kernel_size=5, stride=1, bias=False + ) + with torch.no_grad(): + self.conv2d.weight.copy_(torch.ones(10, 3, 5, 5, dtype=torch.float)) + self.conv2d_2.weight.copy_(torch.ones(15, 10, 5, 5, dtype=torch.float)) + + def forward(self, x): + x = self.conv2d(x) + x = self.conv2d_2(x) + return x + + @register_test class block_two_conv2d(torch.nn.Module): + data = torch.from_numpy( + np.float32(rng.integers(low=10, high=20, size=(1, 3, 256, 256))) + ) inputs = { - TosaProfile.BI: (torch.ones(1, 3, 256, 256),), - TosaProfile.MI: (torch.ones(1, 3, 256, 256),), + TosaProfile.BI: (data,), + TosaProfile.MI: (data,), } def __init__(self): @@ -170,6 +324,11 @@ def __init__(self): self.conv2d_2 = torch.nn.Conv2d( in_channels=10, out_channels=15, kernel_size=5, stride=1 ) + with torch.no_grad(): + self.conv2d.weight.copy_(torch.ones(10, 3, 5, 5, dtype=torch.float)) + self.conv2d.bias.copy_(torch.ones(10)) + self.conv2d_2.weight.copy_(torch.ones(15, 10, 5, 5, dtype=torch.float)) + self.conv2d_2.bias.copy_(torch.ones(15)) def forward(self, x): x = self.conv2d(x) @@ -179,14 +338,6 @@ def forward(self, x): # @register_test class simple_depthwise_conv2d(torch.nn.Module): inputs = { - TosaProfile.BI: ( - torch.ones( - 1, - 3, - 256, - 256, - ), - ), TosaProfile.MI: (torch.ones(1, 3, 256, 256),), } @@ -308,14 +459,6 @@ class block_bottleneck_residual(torch.nn.Module): # Ref: https://arxiv.org/abs/1801.04381 inputs = { - TosaProfile.BI: ( - torch.ones( - 1, - 64, - 81, - 81, - ), - ), TosaProfile.MI: (torch.ones(1, 64, 81, 81),), } diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index d9a6ec9425c..e4d04d41293 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -161,3 +161,22 @@ def buildRescaleFromInt32( ) return + + +""" Creates a TOSA rescale op based on conv2d parameters. """ + + +def buildRescaleOpConvOutput( + tosa_fb, op, output_type, input_scale, weight_scale, output_scale +): + # Only use double round if we are doing 32 bit scaling + double_round = isScale32(output_type) + + # TODO add check to verify if this is a Per-channel quantization. + post_conv2d_scale = (input_scale.number * weight_scale.number) / output_scale.number + + # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0. 
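+ # Illustrative (hypothetical) numbers: with input_scale = 0.02, weight_scale = 0.005
+ # and output_scale = 0.1, post_conv2d_scale = (0.02 * 0.005) / 0.1 = 0.001, i.e. the
+ # int32 accumulator is scaled down by a factor of 1000 when narrowed to the int8
+ # output domain.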
+    rescale_op = buildRescale(
+        tosa_fb, post_conv2d_scale, op, output_type, op.shape, 0, 0, double_round
+    )
+    return rescale_op
diff --git a/examples/arm/arm_tosa_e2e.py b/examples/arm/arm_tosa_e2e.py
index 80f1e19a357..d92595da629 100644
--- a/examples/arm/arm_tosa_e2e.py
+++ b/examples/arm/arm_tosa_e2e.py
@@ -36,7 +36,18 @@
     _check_ir_validity=False,
 )
 
-SUPPORTED_BI_TEST_LIST = ["simple_add", "simple_add_broadcast", "simple_linear"]
+SUPPORTED_BI_TEST_LIST = [
+    "simple_add",
+    "simple_add_broadcast",
+    "simple_linear",
+    "simple_conv2d_3x3_1x3x256x256_st1",
+    "simple_conv2d_1x1_1x2x128x128_st1",
+    "simple_conv2d_2x2_1x1x14x14_st2",
+    "simple_conv2d_5x5_3x2x128x128_st1",
+    "simple_conv2d_2x2_3x1x40x40_non_bias",
+    "block_two_conv2d",
+    "block_two_conv2d_non_bias",
+]
 
 
 def get_input_quantization_params(captured_model):
@@ -242,7 +253,10 @@ def tosa_run_test(op, profile=TosaProfile.MI):  # noqa: C901
     ## TODO: Torch is doing [Q, DQ, Operation (FP32), Q, DQ] for quantization
     ## While TOSA is doing everything in INT8 which is causing a large diff
     ## Between two final results. Need to fix this to have a smaller error margin.
-    if np.allclose(tosa_output, torch_output, rtol=1e-1, atol=1e-1, equal_nan=True):
+    ## The tolerance is set to 1.5e-1 for conv2d testing, as that operation can
+    ## produce a larger difference from the ground-truth floating-point output on
+    ## random input data.
+    if np.allclose(tosa_output, torch_output, rtol=1.5e-1, atol=1.5e-1, equal_nan=True):
         print(
             "\033[92m"
             + "Torch and Tosa Reference results are matching for operator: "