
Commit df4be6f

fix code style
1 parent 7121f3c commit df4be6f

4 files changed: 30 additions (+30), 34 deletions (-34)


python/llm/src/ipex_llm/transformers/npu_model.py

Lines changed: 9 additions & 11 deletions

@@ -162,18 +162,16 @@ def from_pretrained(cls,
                ggml_tensor_qtype, FP4Params

            if isinstance(model.lm_head, torch.nn.Linear):
-               new_linear = LowBitLinear(
-                   model.lm_head.in_features,
-                   model.lm_head.out_features,
-                   ggml_tensor_qtype["sym_int4"],
-                   False
-               )
+               new_linear = LowBitLinear(model.lm_head.in_features,
+                                         model.lm_head.out_features,
+                                         ggml_tensor_qtype["sym_int4"],
+                                         False)
                paramsLowBit = FP4Params(data=model.lm_head.weight.data,
-                                           requires_grad=False,
-                                           quantized=False,
-                                           _shape=None,
-                                           qtype=ggml_tensor_qtype["sym_int4"],
-                                           in_features=model.lm_head.in_features).to("cpu")
+                                        requires_grad=False,
+                                        quantized=False,
+                                        _shape=None,
+                                        qtype=ggml_tensor_qtype["sym_int4"],
+                                        in_features=model.lm_head.in_features).to("cpu")
                new_linear._parameters['weight'] = paramsLowBit
                model.lm_head = new_linear
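The hunk above only re-wraps call arguments; the pattern it touches is replacing the model's lm_head with a low-bit linear layer whose weight is stored as quantized parameters. Below is a minimal sketch of that pattern using a hypothetical stand-in QuantizedLinear (a simple int8 scheme), not ipex-llm's actual LowBitLinear/FP4Params API:

# Sketch: swap a model's lm_head for a module holding a quantized copy of the weight.
# QuantizedLinear is an illustrative stand-in, not the ipex-llm implementation.
import torch


class QuantizedLinear(torch.nn.Module):
    def __init__(self, linear: torch.nn.Linear):
        super().__init__()
        w = linear.weight.data
        # Per-row symmetric scale for int8 quantization.
        scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.0
        self.register_buffer("q_weight", torch.round(w / scale).to(torch.int8))
        self.register_buffer("scale", scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize before the matmul; a real low-bit kernel would fuse this step.
        w = self.q_weight.float() * self.scale
        return torch.nn.functional.linear(x, w)


model = torch.nn.ModuleDict({"lm_head": torch.nn.Linear(16, 32, bias=False)})
if isinstance(model["lm_head"], torch.nn.Linear):
    model["lm_head"] = QuantizedLinear(model["lm_head"])

print(model["lm_head"](torch.randn(2, 16)).shape)  # torch.Size([2, 32])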

python/llm/src/ipex_llm/transformers/npu_models/kv.py

Lines changed: 8 additions & 6 deletions

@@ -25,8 +25,8 @@ def init_fused_kv_cache(batch_size, num_heads, head_dim, current_length, max_len
                                     max_length, head_dim,
                                     dtype=dtype, device=device)
    value_cache_storage = torch.zeros(batch_size, num_heads,
-                                       max_length, head_dim,
-                                       dtype=dtype, device=device)
+                                      max_length, head_dim,
+                                      dtype=dtype, device=device)

    key_cache = key_cache_storage.as_strided((batch_size, num_heads,
                                              current_length, head_dim),

@@ -57,9 +57,9 @@ class DynamicFusedNormalCache(DynamicCache):
    KV_ALLOC_BLOCK_LENGTH = 256

    def __init__(self) -> None:
-       self.key_cache: Dict[int, torch.Tensor] = {}
+       self.key_cache: Dict[int, torch.Tensor] = {}
        self.value_cache: Dict[int, torch.Tensor] = {}
-       self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+       self._seen_tokens = 0  # Used in `generate` to keep how many tokens the cache has seen

    def update(
        self,

@@ -85,7 +85,8 @@ def update(
        # Update the cache
        # if len(self.key_cache) <= layer_idx:
        if layer_idx not in self.key_cache:
-           max_len = max_seq_length if max_seq_length is not None else key_states.size(2) + self.KV_ALLOC_BLOCK_LENGTH
+           max_len = max_seq_length if max_seq_length is not None else key_states.size(2) + \
+               self.KV_ALLOC_BLOCK_LENGTH
            k_cache, v_cache = init_fused_kv_cache(
                batch_size, num_heads, head_dim,
                0, max_len,

@@ -107,7 +108,8 @@ def update(
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-       """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+       """Returns the sequence length of the cached states.
+       A layer index can be optionally passed."""

        for idx, layer in self.key_cache.items():
            return layer.shape[-2]

python/llm/src/ipex_llm/transformers/npu_models/llama.py

Lines changed: 7 additions & 11 deletions

@@ -232,7 +232,7 @@ def llama_fused_model_forward(

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)
-
+
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
                                           cache_position, past_seen_tokens)

@@ -247,21 +247,17 @@ def llama_fused_model_forward(
    seq_len = hidden_states.size(1)

    if seq_len == 1:
-       # assert hasattr(self, "multi_decoder")
        # multi_decoder = self.layers[(self.layer_end + 1) % num_layers]
        layer_outputs = self.multi_decoder(hidden_states,
-                                            attention_mask=causal_mask,
-                                            position_ids=position_ids,
-                                            past_key_value=past_key_values,
-                                            output_attentions=output_attentions,
-                                            use_cache=use_cache,
-                                            cache_position=cache_position,)
+                                           attention_mask=causal_mask,
+                                           position_ids=position_ids,
+                                           past_key_value=past_key_values,
+                                           output_attentions=output_attentions,
+                                           use_cache=use_cache,
+                                           cache_position=cache_position,)
        hidden_states = layer_outputs[0]

-       assert use_cache
        next_decoder_cache = layer_outputs[1]
-
-       assert not output_attentions
    else:
        for decoder_layer in self.layers:
            if output_hidden_states:
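This hunk is again mostly re-indentation, plus dropping two asserts. The surrounding logic routes single-token decode steps (seq_len == 1) through a fused multi_decoder and longer prefill sequences through the regular per-layer loop. A schematic of that dispatch, with toy decoder modules standing in for the NPU-fused layers (assumed names, not the ipex-llm implementation):

# Schematic dispatch: fused path for 1-token decode, per-layer loop for prefill.
import torch


class ToyDecoderLayer(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


class ToyModel(torch.nn.Module):
    def __init__(self, dim=32, num_layers=4):
        super().__init__()
        self.layers = torch.nn.ModuleList(ToyDecoderLayer(dim) for _ in range(num_layers))
        # Stand-in for the fused multi-layer decoder used on the decode path.
        self.multi_decoder = torch.nn.Sequential(*self.layers)

    def forward(self, hidden_states):
        seq_len = hidden_states.size(1)
        if seq_len == 1:
            # Decode: one fused call covering all layers.
            return self.multi_decoder(hidden_states)
        # Prefill: run the layers one by one.
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


m = ToyModel()
print(m(torch.randn(1, 7, 32)).shape)  # prefill path
print(m(torch.randn(1, 1, 32)).shape)  # fused decode path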

python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py

Lines changed: 6 additions & 6 deletions

@@ -276,7 +276,7 @@ def pipeline_parallel_generate(self,
    bs = inputs_tensor.shape[0]
    if model_kwargs.get("attention_mask", None) is None:
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-           inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id)
+           inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id)
    if self.config.is_encoder_decoder:
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=bs,

@@ -289,7 +289,7 @@ def pipeline_parallel_generate(self,
    else:
        input_ids = inputs_tensor if model_input_name == "input_ids" \
            else model_kwargs.pop("input_ids")
-
+
    local_rank = dist.get_rank()
    pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
    next_rank = (local_rank + 1) % self.pipeline_parallel_stages

@@ -325,7 +325,7 @@ def pipeline_parallel_generate(self,

    if _input_ids is None:
        _input_ids = input_ids
-
+
    model_inputs = self.prepare_inputs_for_generation(output_ids, **model_kwargs)

    tic = time.time()

@@ -360,8 +360,8 @@ def pipeline_parallel_generate(self,
    output_ids = torch.cat([output_ids, next_ids], dim=-1)

    model_kwargs = self._update_model_kwargs_for_generation(
-       outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
-   )
+       outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+   )

    # finished sentences should have their next token be a padding token
    next_ids = next_ids.squeeze()

@@ -602,7 +602,7 @@ def glm4_conditional_generation_forward_lowmem(
    hidden_states = transformer_outputs[0]
    if return_last_logit:
        hidden_states = hidden_states[:, -1:]
-
+
    device = hidden_states.device
    # ipex-llm change starts
    if device.type == "xpu":
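These hunks are whitespace and indentation fixes inside pipeline_parallel_generate. The context lines show the ring layout the function relies on: each rank receives hidden states from (rank - 1) % stages and sends to (rank + 1) % stages. A small self-contained illustration of that neighbour computation (plain arithmetic, no torch.distributed setup):

# Neighbour ranks in a pipeline-parallel ring, as in the context lines above:
# stage i receives from (i - 1) % stages and sends to (i + 1) % stages.
def ring_neighbours(local_rank: int, pipeline_parallel_stages: int):
    pre_rank = (local_rank - 1) % pipeline_parallel_stages
    next_rank = (local_rank + 1) % pipeline_parallel_stages
    return pre_rank, next_rank


if __name__ == "__main__":
    stages = 4
    for rank in range(stages):
        pre, nxt = ring_neighbours(rank, stages)
        print(f"rank {rank}: recv from {pre}, send to {nxt}")
    # rank 0 wraps around: recv from 3, send to 1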
