Skip to content
2 changes: 1 addition & 1 deletion backends/arm/operators/op_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def define_node(
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)

Expand Down
111 changes: 110 additions & 1 deletion backends/arm/test/ops/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@

from typing import Tuple

import pytest
import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_a16w8_quantization_config,
TOSAQuantizer,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test import common, conftest

from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
Expand All @@ -18,6 +23,8 @@
TosaPipelineINT,
VgfPipeline,
)
from executorch.backends.arm.tosa.specification import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize

aten_op = "torch.ops.aten.slice.Tensor"
exir_op = "executorch_exir_dialects_edge__ops_aten_slice_copy"
Expand Down Expand Up @@ -119,3 +126,105 @@ def test_slice_tensor_vgf_INT(test_data: torch.Tensor):
tosa_version="TOSA-1.0+INT",
)
pipeline.run()


def get_symmetric_a16w8_slice_quantizer(per_channel_quantization=False):
tosa_version = conftest.get_option("tosa_version")
tosa_profiles = {
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
}

quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
quantizer.set_global(
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
)

return Quantize(
quantizer,
get_symmetric_a16w8_quantization_config(
is_per_channel=per_channel_quantization
),
)


@common.parametrize("test_data", test_data_suite)
@pytest.mark.xfail(
reason="missing int16 slice ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13976"
)
def test_slice_tensor_16a8w_tosa_INT(test_data: torch.Tensor):
"""Test slice operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = TosaPipelineINT[input_t1](
Slice(),
test_data(),
aten_op,
exir_op=[],
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
tosa_extensions=["int16"],
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_slice_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
reason="Vela compilation fails with 'Invalid arguments' for int16 slice operations"
)
def test_slice_tensor_16a8w_u55_INT16(test_data: torch.Tensor):
"""Test slice operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = EthosU55PipelineINT[input_t1](
Slice(),
test_data(),
aten_ops=[],
exir_ops=[],
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
run_on_fvp=True,
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_slice_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
reason="Vela compilation fails with 'Invalid arguments' for int16 slice operations"
)
def test_slice_tensor_16a8w_u85_INT16(test_data: torch.Tensor):
"""Test slice operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
per_channel_quantization = False

pipeline = EthosU85PipelineINT[input_t1](
Slice(),
test_data(),
aten_ops=[],
exir_ops=[],
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
run_on_fvp=True,
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_slice_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()
Loading