1 change: 0 additions & 1 deletion QEfficient/__init__.py
@@ -18,7 +18,6 @@
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Placeholder for all non-transformer models registered in QEfficient


# Custom warning for a better logging experience
warnings.formatwarning = custom_format_warning

78 changes: 63 additions & 15 deletions QEfficient/base/pytorch_transforms.py
@@ -120,61 +120,109 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:

class SplitGateUpWeightsTransform(PytorchTransform):
"""
Split fused Gate+Up weights and copy into the model.
Handles both standard MoE models and GptOss models.
For every transformer layer inside `model`:
• expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
• copies halves into
<PREFIX>.experts.gate_proj <-- Gate [E,H,I]
<PREFIX>.experts.up_proj <-- Up [E,H,I]
Handles both interleaved weights (GptOss) and concatenated weights (standard MoE).
Also handles bias terms when present.
"""

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
model_class = model.__class__.__name__

if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS:
return model, transformed

model_tmp = model.language_model if hasattr(model, "language_model") else model

num_layers = len(model_tmp.model.layers)
delete_fused_key = True
sd = model_tmp.state_dict()

for layer_idx in range(num_layers):
# Determine if this is a GptOss model or standard MoE model
is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp")

# ---- build the textual prefix once per layer ----------
prefix = f"model.layers.{layer_idx}.feed_forward.experts."
if is_gpt_oss:
prefix = f"model.layers.{layer_idx}.mlp.experts."
experts = model_tmp.model.layers[layer_idx].mlp.experts
else:
prefix = f"model.layers.{layer_idx}.feed_forward.experts."
experts = model_tmp.model.layers[layer_idx].feed_forward.experts

fused_key = prefix + "gate_up_proj"
gate_key = prefix + "gate_proj"
up_key = prefix + "up_proj"

# Check if we have bias terms (GptOss case)
has_bias = fused_key + "_bias" in sd
if has_bias:
fused_bias_key = fused_key + "_bias"
gate_bias_key = gate_key + "_bias"
up_bias_key = up_key + "_bias"

# ---- split weights based on model type ----------------------
fused = sd[fused_key] # [E, H, 2I]
E, H, two_I = fused.shape

if is_gpt_oss:
# For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...]
gate = fused[..., ::2] # [E, H, I] - even indices
up = fused[..., 1::2] # [E, H, I] - odd indices
else:
# For standard MoE, gate/up are concatenated: [gate, up]
ffn_dim = two_I // 2
gate, up = fused.split(ffn_dim, dim=-1) # views – no copy

# Copy weights to model
experts.gate_proj.data.copy_(gate)
experts.up_proj.data.copy_(up)

# Handle bias if present
if has_bias:
fused_bias = sd[fused_bias_key] # [E, 2I]

if is_gpt_oss:
gate_bias = fused_bias[..., ::2] # [E, I] - even indices
up_bias = fused_bias[..., 1::2] # [E, I] - odd indices
else:
ffn_dim = fused_bias.shape[-1] // 2
gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1)

experts.gate_proj_bias.data.copy_(gate_bias)
experts.up_proj_bias.data.copy_(up_bias)

# ---- update the state-dict so load_state_dict sees the right keys
sd[gate_key] = gate
sd[up_key] = up

if has_bias:
sd[gate_bias_key] = gate_bias
sd[up_bias_key] = up_bias

# Delete fused keys
if delete_fused_key:
del sd[fused_key]
if has_bias:
del sd[fused_bias_key]

logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
transformed = True

if hasattr(model, "language_model"):
model.language_model = model_tmp
else:
model = model_tmp

return model, transformed


VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"}
# Model classes whose fused gate_up_proj weights are split by SplitGateUpWeightsTransform
VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"}
99 changes: 99 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -537,3 +537,102 @@ def update(
ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
return k_out, v_out


# This is a temporary workaround until this code is merged with the HybridCache class.
# We don't need to inherit from the transformers cache classes: those are built to work with PyTorch,
# whereas ours are built to work with AIC.
class QEffHybridCacheForGPTOSS:
def __init__(self, config, batch_size, max_cache_len, sliding_window_len):
self.max_cache_len = max_cache_len
self.batch_size = batch_size
self.sliding_window_len = sliding_window_len
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []

@classmethod
def from_legacy_cache(
cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "HybridCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""
cache = cls(
config,
batch_size=past_key_values[0][0].shape[0],
max_cache_len=past_key_values[1][0].shape[2],
sliding_window_len=past_key_values[0][0].shape[2],
)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache

def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
is_empty_layer = (
len(self.key_cache) == 0 # no cache in any layer
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
return layer_seq_length

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
backward compatibility."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
k_out, v_out = key_states, value_states
else:
position_ids = cache_kwargs.get("position_ids")
is_sliding_layer = cache_kwargs.get("is_sliding")
sliding_window = cache_kwargs.get("sliding_window")

if is_sliding_layer:
kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
else:
kv_position_ids = position_ids

self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Original Gather
ctx_len = self.key_cache[layer_idx].shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

k_out = CtxGatherFunc.apply(k_out, ctx_indices)
v_out = CtxGatherFunc.apply(v_out, ctx_indices)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out
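
As a rough, standalone illustration of the position handling in `update` above (toy values and shapes chosen for illustration; not part of the diff): sliding-window layers wrap write positions modulo `sliding_window`, padded positions (-1) pass through unchanged, and the read path masks cache slots beyond the largest valid position.

import torch

sliding_window = 4
position_ids = torch.tensor([[2, 3, 4, -1]])  # toy positions; -1 marks padding

# Write side: wrap positions into the fixed-size sliding window
kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
print(kv_position_ids)  # tensor([[ 2,  3,  0, -1]])

# Read side: mask out cache slots beyond the largest valid position
ctx_len = 8
ctx_indices = torch.arange(ctx_len)[None, None, :]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
print(invalid_mask.squeeze())  # slots 5, 6, 7 are masked (True)

The diff additionally swaps the invalid gather index to torch.iinfo(torch.int32).max during ONNX export; that detail is omitted from this sketch.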
1 change: 1 addition & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -185,6 +185,7 @@
]
)

# Supports a different seq_len per layer, for sliding-window attention, chunked attention, etc.
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

# Define a transformers layers to QEff layers dictionary
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/gpt_oss/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------