From d127c81637c000f9059c17b69cb10e9653d4b1c7 Mon Sep 17 00:00:00 2001
From: cyita
Date: Thu, 31 Oct 2024 11:25:23 +0800
Subject: [PATCH 1/6] prefill use sdp

---
 .../transformers/npu_models/mp_models_base.py | 34 +++++++++++++------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index 1550d6837f6..d6e5e8af63b 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -201,7 +201,11 @@ def attention(self,
         query_states = self.transpose(query_states, [0, 2, 1, 3])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
         if self.transpose_value:
-            value_states = self.transpose(value_states, [0, 2, 3, 1])
+            new_value_states = self.transpose(value_states, [0, 2, 3, 1])
+            if mode == "prefill":
+                value_states = self.transpose(value_states, [0, 2, 1, 3])
+            else:
+                value_states = new_value_states
         else:
             value_states = self.transpose(value_states, [0, 2, 1, 3])

@@ -216,7 +220,6 @@ def attention(self,
             head_dim=head_dim,
         )
         new_key_states = key_states
-        new_value_states = value_states

         if mode == "decode":
             key_states = self.concat(past_key, key_states, axis=-2)
@@ -239,15 +242,24 @@ def attention(self,
                                           kv_seq_len=kv_seq_len,
                                           head_dim=head_dim,
                                           transpose=self.transpose_value)
-        attn_weight = self.matmul(query_states, key_states, False, True) / (
-            math.sqrt(head_dim)
-        )
-        attention_mask = self.convert_to_fp16(attention_mask)
-        attn_weight = self.eltwise_add(attn_weight, attention_mask)
-        attn_weight = self.convert_to_fp32(attn_weight)
-        attn_weight = self.softmax(attn_weight, -1)
-        attn_weight = self.convert_to_fp16(attn_weight)
-        attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)
+        if mode == "prefill":
+            value_states = self.convert_to_fp32(value_states)
+            key_states = self.convert_to_fp32(key_states)
+            query_states = self.convert_to_fp32(query_states)
+            attention_mask = self.convert_to_fp32(attention_mask)
+            attn_output = self.scaled_dot_product_attention(
+                query_states, key_states,value_states, attention_mask, False)
+            attn_output = self.convert_to_fp16(attn_output)
+        else:
+            attn_weight = self.matmul(query_states, key_states, False, True) / (
+                math.sqrt(head_dim)
+            )
+            attention_mask = self.convert_to_fp16(attention_mask)
+            attn_weight = self.eltwise_add(attn_weight, attention_mask)
+            attn_weight = self.convert_to_fp32(attn_weight)
+            attn_weight = self.softmax(attn_weight, -1)
+            attn_weight = self.convert_to_fp16(attn_weight)
+            attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)

         attn_output = self.transpose(attn_output, [0, 2, 1, 3])
         attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])
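[PATCH 1/6] replaces the hand-built attention pipeline (matmul, mask add, softmax, matmul) with a single fused scaled-dot-product-attention op on the prefill path, converting Q/K/V and the mask to fp32 around the call. The sketch below is a plain-PyTorch illustration of the equivalence this relies on; the shapes, the additive-mask convention, and torch.nn.functional.scaled_dot_product_attention are stand-ins for the NPU factory ops in the hunk, not the patch's actual implementation.

    # Plain-PyTorch sketch (not the NPU graph ops above): the fused SDP call computes the
    # same result as the explicit matmul -> mask add -> softmax -> matmul pipeline it replaces.
    # Shapes and the additive-mask convention here are illustrative assumptions.
    import math
    import torch
    import torch.nn.functional as F

    batch, n_heads, seq_len, head_dim = 1, 32, 16, 128
    q = torch.randn(batch, n_heads, seq_len, head_dim)
    k = torch.randn(batch, n_heads, seq_len, head_dim)
    v = torch.randn(batch, n_heads, seq_len, head_dim)
    # additive mask: 0 where attention is allowed, -inf above the diagonal (causal prompt mask)
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    mask = mask.expand(batch, 1, seq_len, seq_len)

    # explicit pipeline, mirroring the removed matmul/softmax/matmul code path
    scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)
    weights = torch.softmax(scores + mask, dim=-1)
    manual_out = weights @ v

    # fused path, mirroring the new scaled_dot_product_attention call
    fused_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    assert torch.allclose(manual_out, fused_out, atol=1e-5)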
From b7435c7765afb79cf0ae6ad36dacf47a1afd52b5 Mon Sep 17 00:00:00 2001
From: cyita
Date: Thu, 31 Oct 2024 14:50:06 +0800
Subject: [PATCH 2/6] add param

---
 .../ipex_llm/transformers/npu_models/mp_models_base.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index d6e5e8af63b..cacee391bc5 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -138,7 +138,6 @@ def attention(self,
                   v_bias=None):
         hidden_size = num_heads * head_dim
         num_key_value_groups = num_heads // num_key_value_heads
-        groupsize = hidden_size // self.n_splits_linear
         if self.n_splits_linear == 1:
             query_states = self.linear(
                 hidden_states,
@@ -200,9 +199,12 @@ def attention(self,

         query_states = self.transpose(query_states, [0, 2, 1, 3])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
+        use_ov_sdp = (mode == "prefill") and (self.group_size != 0)
+        # use_ov_sdp = (mode == "prefill")
+        print(f"-------------------- use_ov_sdp: {use_ov_sdp}, groupsize: {self.group_size}")
         if self.transpose_value:
             new_value_states = self.transpose(value_states, [0, 2, 3, 1])
-            if mode == "prefill":
+            if use_ov_sdp:
                 value_states = self.transpose(value_states, [0, 2, 1, 3])
             else:
                 value_states = new_value_states
@@ -241,8 +243,8 @@ def attention(self,
                                           num_key_value_heads=num_key_value_heads,
                                           kv_seq_len=kv_seq_len,
                                           head_dim=head_dim,
-                                          transpose=self.transpose_value)
-        if mode == "prefill":
+                                          transpose=(self.transpose_value and (not use_ov_sdp)))
+        if use_ov_sdp:
             value_states = self.convert_to_fp32(value_states)
             key_states = self.convert_to_fp32(key_states)
             query_states = self.convert_to_fp32(query_states)
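[PATCH 2/6] gates the new path behind use_ov_sdp and stops pre-transposing the value cache for it (transpose=(self.transpose_value and (not use_ov_sdp))). The extra transposes come down to a layout difference: the manual matmul path can consume V stored as [batch, heads, head_dim, seq] by asking matmul to transpose its second operand, while an SDP op expects the usual [batch, heads, seq, head_dim]. A small plain-PyTorch sketch of that equivalence, with illustrative shapes:

    # Plain-PyTorch sketch of the two value layouts juggled above (shapes are assumptions):
    # multiplying the attention weights against V in the standard [batch, heads, seq, head_dim]
    # layout gives the same result as keeping a pre-transposed [batch, heads, head_dim, seq]
    # copy and transposing it back inside the matmul, which is what transpose_value relies on.
    import torch

    batch, n_heads, seq_len, head_dim = 1, 32, 16, 128
    attn_weight = torch.softmax(torch.randn(batch, n_heads, seq_len, seq_len), dim=-1)
    value_std = torch.randn(batch, n_heads, seq_len, head_dim)      # [0, 2, 1, 3] layout
    value_pre_t = value_std.transpose(-1, -2).contiguous()          # [0, 2, 3, 1] layout

    out_standard = attn_weight @ value_std                          # what an SDP op consumes
    out_transposed = attn_weight @ value_pre_t.transpose(-1, -2)    # manual path with transpose_b

    assert torch.allclose(out_standard, out_transposed, atol=1e-6)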
From a81ccc0b58a2aff70acf8be4f3467a0735db9d59 Mon Sep 17 00:00:00 2001
From: cyita
Date: Thu, 31 Oct 2024 18:14:28 +0800
Subject: [PATCH 3/6] update

---
 .../transformers/npu_models/llama_mp.py       | 25 +++++++++++++++----
 .../transformers/npu_models/mp_models_base.py |  9 +++----
 2 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index 93f1ff36448..5b5210e6a4c 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -110,13 +110,20 @@ def __init__(
         # define input, the order self.parameter matters
         input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))

+        # open llama2 other models need to test
+        use_prefill_sdp = self.intermediate_size == 11008
+
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
                                                   dtype=np.int64)
         else:
-            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
-                                                  dtype=np.int64)
+            if use_prefill_sdp:
+                attention_mask = None
+            else:
+                attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
+                                                       self.seq_len),
+                                                      dtype=np.int64)

         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)

@@ -177,6 +184,7 @@ def __init__(
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 past_key=past_keys[i],
                 past_value=past_values[i],
+                use_prefill_sdp=use_prefill_sdp,
             )
             curr_key_values.append((new_key_states, new_value_states))

@@ -202,6 +210,7 @@ def build_decoder(
         post_attention_layernorm_weight,
         past_key=None,
         past_value=None,
+        use_prefill_sdp=False,
     ):

         residual = hidden_states
@@ -220,6 +229,7 @@ def build_decoder(
             num_key_value_heads=self.num_key_value_heads,
             head_dim=self.head_dim,
             seq_len=self.seq_len,
+            use_prefill_sdp=use_prefill_sdp,
         )
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
@@ -427,6 +437,7 @@ def __init__(
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
+        self.use_prefill_sdp = intermediate_size == 11008

     def forward(
         self,
@@ -451,9 +462,13 @@ def forward(

         seq_len = hidden_states.shape[1]
         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
-                  position_ids.to(torch.int64))
+        if self.use_prefill_sdp:
+            inputs = (hidden_states.to(torch.float16),
+                      position_ids.to(torch.int64))
+        else:
+            inputs = (hidden_states.to(torch.float16),
+                      attention_mask.to(torch.int64),
+                      position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index cacee391bc5..501164d2172 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -135,7 +135,8 @@ def attention(self,
                   seq_len,
                   q_bias=None,
                   k_bias=None,
-                  v_bias=None):
+                  v_bias=None,
+                  use_prefill_sdp=False):
         hidden_size = num_heads * head_dim
         num_key_value_groups = num_heads // num_key_value_heads
         if self.n_splits_linear == 1:
@@ -199,8 +200,7 @@ def attention(self,

         query_states = self.transpose(query_states, [0, 2, 1, 3])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
-        use_ov_sdp = (mode == "prefill") and (self.group_size != 0)
-        # use_ov_sdp = (mode == "prefill")
+        use_ov_sdp = (mode == "prefill") and use_prefill_sdp
         print(f"-------------------- use_ov_sdp: {use_ov_sdp}, groupsize: {self.group_size}")
         if self.transpose_value:
             new_value_states = self.transpose(value_states, [0, 2, 3, 1])
@@ -248,9 +248,8 @@ def attention(self,
             value_states = self.convert_to_fp32(value_states)
             key_states = self.convert_to_fp32(key_states)
             query_states = self.convert_to_fp32(query_states)
-            attention_mask = self.convert_to_fp32(attention_mask)
             attn_output = self.scaled_dot_product_attention(
-                query_states, key_states,value_states, attention_mask, False)
+                query_states, key_states,value_states, None, True)
             attn_output = self.convert_to_fp16(attn_output)
         else:
             attn_weight = self.matmul(query_states, key_states, False, True) / (
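[PATCH 3/6] threads a use_prefill_sdp flag from the model builder down to attention(), drops the prefill attention_mask input entirely, and changes the SDP call to pass None for the mask with the final flag set to True, which reads as switching to built-in causal masking. The sketch below illustrates, in plain PyTorch, why an explicit causal mask becomes redundant once the op applies causal masking itself; names and shapes are illustrative assumptions, not the patch's NPU ops.

    # Plain-PyTorch sketch: with a causal prompt, passing is_causal=True to SDP matches
    # building the additive causal mask explicitly, so the prefill graph no longer needs an
    # attention_mask input at all. Shapes are illustrative assumptions.
    import torch
    import torch.nn.functional as F

    batch, n_heads, seq_len, head_dim = 1, 32, 16, 128
    q = torch.randn(batch, n_heads, seq_len, head_dim)
    k = torch.randn(batch, n_heads, seq_len, head_dim)
    v = torch.randn(batch, n_heads, seq_len, head_dim)

    explicit_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=explicit_mask)
    causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    assert torch.allclose(with_mask, causal, atol=1e-5)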
From 9b9e09da980eb632bcd9c990505f3c9826dee4ef Mon Sep 17 00:00:00 2001
From: cyita
Date: Thu, 31 Oct 2024 18:29:51 +0800
Subject: [PATCH 4/6] fix style

---
 python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py | 5 ++---
 .../src/ipex_llm/transformers/npu_models/mp_models_base.py  | 2 +-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index 5b5210e6a4c..8991b52f1d3 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -121,9 +121,8 @@ def __init__(
             if use_prefill_sdp:
                 attention_mask = None
             else:
-                attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
-                                                       self.seq_len),
-                                                      dtype=np.int64)
+                attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
+                                                      dtype=np.int64)

         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index 501164d2172..47d0fa5df22 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -249,7 +249,7 @@ def attention(self,
             key_states = self.convert_to_fp32(key_states)
             query_states = self.convert_to_fp32(query_states)
             attn_output = self.scaled_dot_product_attention(
-                query_states, key_states,value_states, None, True)
+                query_states, key_states, value_states, None, True)
             attn_output = self.convert_to_fp16(attn_output)
         else:
             attn_weight = self.matmul(query_states, key_states, False, True) / (

From be529bbc8b42aebf36daa72beec4e9d94d42f1db Mon Sep 17 00:00:00 2001
From: cyita
Date: Thu, 31 Oct 2024 18:32:11 +0800
Subject: [PATCH 5/6] fix style

---
 python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index 8991b52f1d3..30e976015ea 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -121,7 +121,8 @@ def __init__(
             if use_prefill_sdp:
                 attention_mask = None
             else:
-                attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
+                attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
+                                                       self.seq_len),
                                                       dtype=np.int64)

         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)

From 46dd9b756e2969307e8065a87b0e8c711815f4fc Mon Sep 17 00:00:00 2001
From: cyita
Date: Fri, 1 Nov 2024 10:54:39 +0800
Subject: [PATCH 6/6] meet comments

---
 python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py    | 2 +-
 .../llm/src/ipex_llm/transformers/npu_models/mp_models_base.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index 30e976015ea..76187872b38 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -110,7 +110,7 @@ def __init__(
         # define input, the order self.parameter matters
         input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))

-        # open llama2 other models need to test
+        # llama2 use ov sdp, other models need to test
         use_prefill_sdp = self.intermediate_size == 11008

         # Self Attention
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index 47d0fa5df22..3ac026aa687 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -201,7 +201,6 @@ def attention(self,
         query_states = self.transpose(query_states, [0, 2, 1, 3])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
         use_ov_sdp = (mode == "prefill") and use_prefill_sdp
-        print(f"-------------------- use_ov_sdp: {use_ov_sdp}, groupsize: {self.group_size}")
         if self.transpose_value:
             new_value_states = self.transpose(value_states, [0, 2, 3, 1])
             if use_ov_sdp:
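[PATCH 6/6] removes the leftover debug print and clarifies the comment: the new prefill SDP path is only enabled when intermediate_size == 11008 (the Llama-2-7B FFN width), with other models left for later testing. A hedged sketch of that kind of per-model gate is below; the helper name and the set of validated widths are illustrative assumptions, not part of the patch, which simply compares intermediate_size inline.

    # Hedged sketch of the per-model gate; the helper name and the validated set are
    # illustrative assumptions (the patch itself just compares intermediate_size inline).
    def should_use_prefill_sdp(intermediate_size: int, mode: str) -> bool:
        """Enable the OV SDP prefill path only for FFN widths that have been validated."""
        validated_ffn_widths = {11008}   # Llama-2-7B; other models still to be tested
        return mode == "prefill" and intermediate_size in validated_ffn_widths

    # usage sketch
    assert should_use_prefill_sdp(11008, "prefill") is True
    assert should_use_prefill_sdp(11008, "decode") is False
    assert should_use_prefill_sdp(13824, "prefill") is False   # e.g. Llama-2-13B's FFN width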