Skip to content

Commit a4f93c0

Browse files
authored
int8 lstm fix: remove redundant quant in size pattern (#1414)
1 parent 8b97d29 commit a4f93c0

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

csrc/jit/passes/graph_rewrite.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -783,8 +783,7 @@ void preprocessSizeForQLstm(std::shared_ptr<Graph>& graph) {
783783
op_list_construct_same_states, op_list_construct_diff_states};
784784

785785
auto pattern = at::jit::CodeTemplate(R"(
786-
graph(%x, %scale, %zero_point, %quantize_dtype, %size_dim, %ld, %hidden_size, %scalar_type, %layout, %device, %pin_memory, %weight, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first):
787-
%quantized_input = aten::quantize_per_tensor(%x, %scale, %zero_point, %quantize_dtype)
786+
graph(%quantized_input, %size_dim, %ld, %hidden_size, %scalar_type, %layout, %device, %pin_memory, %weight, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first):
788787
%ret.3 = aten::dequantize(%quantized_input)
789788
%max_batch_size : int = aten::size(%ret.3, %size_dim)
790789
%ret.tensor : Tensor = prim::NumToTensor(%max_batch_size)
@@ -795,8 +794,7 @@ void preprocessSizeForQLstm(std::shared_ptr<Graph>& graph) {
795794
return (%res.1, %res.2, %res.3) )");
796795

797796
auto replacement = at::jit::CodeTemplate(R"(
798-
graph(%x, %scale, %zero_point, %quantize_dtype, %size_dim, %ld, %hidden_size, %scalar_type, %layout, %device, %pin_memory, %weight, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first):
799-
%quantized_input = aten::quantize_per_tensor(%x, %scale, %zero_point, %quantize_dtype)
797+
graph(%quantized_input, %size_dim, %ld, %hidden_size, %scalar_type, %layout, %device, %pin_memory, %weight, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first):
800798
%max_batch_size : int = aten::size(%quantized_input, %size_dim)
801799
%ret.3 = aten::dequantize(%quantized_input)
802800
%ret.tensor : Tensor = prim::NumToTensor(%max_batch_size)

tests/cpu/test_ao_jit_ipex_quantization.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,25 @@ def forward(self, input, hid, mask=None):
289289
graph = self.checkQuantizeTrace(model, [seq, hid, mask])
290290
self.assertGraphContainsExactly(graph, 'aten::lstm', 1)
291291

292+
def test_linear_lstm(self):
293+
class M(nn.Module):
294+
def __init__(self):
295+
super(M, self).__init__()
296+
self.linear = nn.Linear(512, 64)
297+
self.lstm = nn.LSTM(input_size=64, hidden_size=256, num_layers=2)
298+
299+
def forward(self, input, hid=None):
300+
x = self.linear(input)
301+
x = self.lstm(x, hid)
302+
return x
303+
304+
model = M().eval()
305+
seq = torch.randn(24, 1, 512)
306+
307+
graph = self.checkQuantizeTrace(model, [seq], atol=3e-2, rtol=1e-1)
308+
self.assertGraphContainsExactly(graph, 'ipex::quantized_lstm', 1)
309+
self.assertGraphContainsExactly(graph, 'aten::lstm', 0)
310+
292311
class TestIpexQuantizationConvertAPI(JitLlgaTestCase):
293312
def test_inplace_preapre(self):
294313
class M(nn.Module):

0 commit comments

Comments (0)