30 changes: 30 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -36,6 +36,36 @@ class QEffDynamicCache(DynamicCache):

"""

    def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
        """Write `key_states`/`value_states` into the cache for `layer_idx` without reading back."""
        if len(self.key_cache) <= layer_idx:
            # First write for this layer: create the cache entry.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Scatter the new states into the preallocated cache at `position_ids`.
            position_ids = cache_kwargs.get("position_ids")
            self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
            self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)

    def read_only(self, layer_idx, cache_kwargs):
        """Read the cached keys/values for `layer_idx` without writing new states."""
        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
        position_ids = cache_kwargs.get("position_ids")
        ctx_len = k_out.shape[2]
        ctx_indices = torch.arange(ctx_len)[None, None, ...]
        # Positions beyond the largest position id have not been written yet.
        gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
        invalid_mask = ctx_indices > gather_limit

        # During ONNX export, invalid positions are redirected to a sentinel index
        # (INT32_MAX); at runtime, index 0 is a safe placeholder because the
        # gathered values are zeroed out below.
        if torch.onnx.is_in_onnx_export():
            invalid_idx_value = torch.iinfo(torch.int32).max
        else:
            invalid_idx_value = 0

        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

        k_out = CtxGatherFunc.apply(k_out, ctx_indices)
        v_out = CtxGatherFunc.apply(v_out, ctx_indices)
        # Zero out values gathered from invalid positions.
        v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
        return k_out, v_out

    def update(
        self,
        key_states: torch.Tensor,
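A minimal usage sketch of the two new cache methods (illustrative only; the shapes and values below are assumptions, not part of the diff). `write_only` stores a layer's keys/values without returning anything, and `read_only` fetches them without writing — the split access pattern SwiftKV needs when later layers reuse keys/values computed up front:

import torch

cache = QEffDynamicCache()
position_ids = torch.arange(8).unsqueeze(0)      # (batch=1, seq_len=8)
key_states = torch.randn(1, 4, 8, 64)            # (batch, n_heads, seq_len, head_dim)
value_states = torch.randn(1, 4, 8, 64)
cache_kwargs = {"position_ids": position_ids}

# Prefill: store this layer's KV without reading anything back.
cache.write_only(key_states, value_states, layer_idx=0, cache_kwargs=cache_kwargs)

# Decode: fetch the cached KV for the layer without writing new states.
k, v = cache.read_only(layer_idx=0, cache_kwargs=cache_kwargs)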
18 changes: 18 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -5,154 +5,157 @@
#
# -----------------------------------------------------------------------------

from collections import namedtuple
from typing import Dict, Optional, Tuple, Type

import torch
import torch.nn as nn
from transformers.models.codegen.modeling_codegen import (
    CodeGenAttention,
    CodeGenBlock,
    CodeGenForCausalLM,
    CodeGenModel,
)
from transformers.models.falcon.modeling_falcon import (
    FalconAttention,
    FalconForCausalLM,
    FalconModel,
)
from transformers.models.gemma.modeling_gemma import (
    GemmaAttention,
    GemmaDecoderLayer,
    GemmaForCausalLM,
    GemmaModel,
    GemmaRMSNorm,
)
from transformers.models.gemma2.modeling_gemma2 import (
    Gemma2Attention,
    Gemma2DecoderLayer,
    Gemma2ForCausalLM,
    Gemma2Model,
    Gemma2RMSNorm,
)
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
    GPTBigCodeAttention,
    GPTBigCodeBlock,
    GPTBigCodeForCausalLM,
    GPTBigCodeModel,
)
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJForCausalLM, GPTJModel
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    LlamaRMSNorm,
)
from transformers.models.mistral.modeling_mistral import (
    MistralAttention,
    MistralDecoderLayer,
    MistralForCausalLM,
    MistralModel,
    MistralRMSNorm,
)
from transformers.models.mixtral.modeling_mixtral import (
    MixtralAttention,
    MixtralDecoderLayer,
    MixtralForCausalLM,
    MixtralModel,
    MixtralRMSNorm,
    MixtralSparseMoeBlock,
)
from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
from transformers.models.phi.modeling_phi import PhiAttention, PhiForCausalLM, PhiModel
from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, Qwen2Model, Qwen2RMSNorm
from transformers.models.starcoder2.modeling_starcoder2 import (
    Starcoder2Attention,
    Starcoder2DecoderLayer,
    Starcoder2ForCausalLM,
    Starcoder2Model,
)
from transformers.models.whisper.modeling_whisper import (
    WhisperAttention,
    WhisperDecoder,
    WhisperDecoderLayer,
    WhisperEncoder,
    WhisperForConditionalGeneration,
    WhisperModel,
    WhisperPositionalEmbedding,
)

from QEfficient.customop import CustomRMSNormAIC

from .models.codegen.modeling_codegen import (
    QEffCodeGenAttention,
    QeffCodeGenBlock,
    QEffCodeGenForCausalLM,
    QEffCodeGenModel,
)
from .models.falcon.modeling_falcon import (
    QEffFalconAttention,
    QEffFalconForCausalLM,
    QEffFalconModel,
)
from .models.gemma.modeling_gemma import QEffGemmaAttention, QEffGemmaDecoderLayer, QEffGemmaForCausalLM, QEffGemmaModel
from .models.gemma2.modeling_gemma2 import (
    QEffGemma2Attention,
    QEffGemma2DecoderLayer,
    QEffGemma2ForCausalLM,
    QEffGemma2Model,
)
from .models.gpt2.modeling_gpt2 import QEffGPT2Attention, QEffGPT2Block, QEffGPT2LMHeadModel, QEffGPT2Model
from .models.gpt_bigcode.modeling_gpt_bigcode import (
    QEffGPTBigCodeAttention,
    QEffGPTBigCodeBlock,
    QEffGPTBigCodeForCausalLM,
    QEffGPTBigCodeModel,
)
from .models.gptj.modeling_gptj import QEffGPTJAttention, QEffGPTJForCausalLM, QEffGPTJModel
from .models.llama.modeling_llama import (
    QEffLlamaAttention,
    QEffLlamaDecoderLayer,
    QEffLlamaForCausalLM,
    QEffLlamaModel,
)
from .models.mistral.modeling_mistral import (
    QEffMistralAttention,
    QEffMistralDecoderLayer,
    QEffMistralForCausalLM,
    QEffMistralModel,
)
from .models.mixtral_moe.modeling_mixtral import (
    QEffMixtralAttention,
    QeffMixtralDecoderLayer,
    QEffMixtralForCausalLM,
    QEffMixtralModel,
    QEffMixtralSparseMoeBlock,
)
from .models.mpt.modeling_mpt import QEffMptAttention, QEffMptBlock, QEffMptForCausalLM, QEFfMptModel
from .models.phi.modeling_phi import QEffPhiAttention, QEffPhiForCausalLM, QEffPhiModel
from .models.phi3.modeling_phi3 import QEffPhi3Attention, QEffPhi3ForCausalLM, QEffPhi3Model
from .models.qwen2.modeling_qwen2 import QEffQwen2Attention, QEffQwen2ForCausalLM, QEffQwen2Model
from .models.starcoder2.modeling_starcoder2 import (
    QEffStarcoder2Attention,
    QEFFStarcoder2DecoderLayer,
    QEffStarcoder2ForCausalLM,
    QEffStarcoder2Model,
)
from .models.whisper.modeling_whisper import (
    QEffWhisperAttention,
    QEffWhisperDecoder,
    QEffWhisperDecoderLayer,
    QEffWhisperEncoder,
    QEffWhisperForConditionalGeneration,
    QEffWhisperModel,
    QEffWhisperPositionalEmbedding,
)

from QEfficient.transformers.models.llama_swiftkv.config_llama_swiftkv import LlamaSwiftKVConfig
from QEfficient.transformers.models.llama_swiftkv.modeling_llama_swiftkv import LlamaSwiftKVForCausalLM

[CI annotation — GitHub Actions / lint — Ruff (I001): QEfficient/transformers/modeling_utils.py:8:1: Import block is un-sorted or un-formatted (check failure on line 157)]

# Define a named tuple for ModelArchitectures
# Required for the Automation tool
ModelArchitectures = namedtuple("ModelArchitectures", ["architectures"])
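
For illustration only (not part of the diff), the named tuple simply carries the list of architecture class names that the automation tool consumes:

archs = ModelArchitectures(architectures=["LlamaForCausalLM", "MistralForCausalLM"])
print(archs.architectures)  # ["LlamaForCausalLM", "MistralForCausalLM"]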
@@ -362,3 +365,18 @@
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask


# Define a SwiftKV model card name to model type dictionary
# While onboarding new models, make sure to add the new SwiftKV model card names to this dictionary.
SwiftKVModelCardNameToSwiftKVModelTypeDict: Dict[str, str] = {
    # LlamaSwiftKV Model
    "Snowflake/Llama-3.1-SwiftKV-8B-Instruct": "llama_swiftkv"
}

# Define a SwiftKV model type to config class and model architecture class dictionary
# While onboarding new models, make sure to add the new SwiftKV model types to this dictionary.
SwiftKVModelTypeToConfigClassAndModelArchClassDict = {
    # LlamaSwiftKV Model
    "llama_swiftkv": [LlamaSwiftKVConfig, LlamaSwiftKVForCausalLM]
}
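
A hedged sketch of how the two tables compose (lookup flow only; not part of the diff): a model card name resolves to a model type, which then yields the config and model classes:

model_card = "Snowflake/Llama-3.1-SwiftKV-8B-Instruct"
model_type = SwiftKVModelCardNameToSwiftKVModelTypeDict[model_card]  # "llama_swiftkv"
config_cls, model_cls = SwiftKVModelTypeToConfigClassAndModelArchClassDict[model_type]
# config_cls -> LlamaSwiftKVConfig, model_cls -> LlamaSwiftKVForCausalLM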
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/llama_swiftkv/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
41 changes: 41 additions & 0 deletions QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py
@@ -0,0 +1,41 @@
# -----------------------------------------------------------------------------
[Review comment, Contributor Author: Move this piece of code into the modelling file.]
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
# The Modules are updated as required by Cloud AI 100 HW requirements.


"""Inference-only LLaMA model compatible with HuggingFace weights."""

from typing import Optional
from transformers import LlamaConfig

[CI annotation — GitHub Actions / lint — Ruff (I001): QEfficient/transformers/models/llama_swiftkv/config_llama_swiftkv.py:12:1: Import block is un-sorted or un-formatted (check failure on line 13)]


class LlamaSwiftKVConfig(LlamaConfig):
    """
    Args:
        swiftkv (bool, optional):
            Whether SwiftKV key-value projection skipping is enabled.
            Defaults to False.
        num_key_value_layers (int, optional):
            The number of layers, from the first layer, that compute their
            own keys and values. If None, all layers have keys and values.
        key_value_group_size (int, optional):
            The number of consecutive later layers that share one set of
            keys and values. Defaults to 1.
    """

    model_type = "llama_swiftkv"

    def __init__(
        self,
        swiftkv: bool = False,
        num_key_value_layers: Optional[int] = None,
        key_value_group_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.swiftkv = swiftkv
        self.num_key_value_layers = num_key_value_layers or self.num_hidden_layers
        self.key_value_group_size = key_value_group_size or 1
        # The layers without their own KV must divide evenly into KV groups.
        assert (self.num_hidden_layers - self.num_key_value_layers) % self.key_value_group_size == 0
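
A minimal construction sketch under assumed values (none of these numbers come from the diff): with 32 hidden layers, 16 key-value layers, and a group size of 2, the last 16 layers share 8 sets of keys/values and the divisibility assert passes:

cfg = LlamaSwiftKVConfig(
    swiftkv=True,
    num_hidden_layers=32,
    num_key_value_layers=16,
    key_value_group_size=2,
)
# (32 - 16) % 2 == 0, so __init__'s assert holds.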