Merged
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
@@ -159,7 +159,7 @@ Text Generation
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.
- ✅︎
- ✅︎
* - :code:`MiniCPM3ForCausalLM`
24 changes: 24 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -13,6 +13,30 @@
from vllm.model_executor.utils import set_weight_attrs


class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.

    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2


class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
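For intuition, the new FatreluAndMul op reduces to a few lines of stock PyTorch. The following sketch is not part of the PR; the tensor shape and threshold are arbitrary, and it only mirrors the forward pass shown above.

import torch
import torch.nn.functional as F

x = torch.randn(4, 16)                 # (num_tokens, 2 * d) with d = 8
threshold = 0.05

# Split into the gate half and the value half, exactly as the op does.
x1, x2 = x[..., :8], x[..., 8:]

# FATReLU keeps values strictly above the threshold and zeroes the rest;
# the result then gates the second half elementwise.
out = torch.where(x1 > threshold, x1, torch.zeros_like(x1)) * x2

# F.threshold(x1, threshold, 0.0) in the diff is the same operation.
assert torch.allclose(out, F.threshold(x1, threshold, 0.0) * x2)
print(out.shape)                       # torch.Size([4, 8])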
13 changes: 9 additions & 4 deletions vllm/model_executor/models/minicpm.py
@@ -33,7 +33,7 @@
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -152,6 +152,7 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str,
hidden_act_param: float,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -163,10 +164,13 @@ def __init__(
hidden_size,
bias=False,
quant_config=quant_config)
-        if hidden_act != "silu":
+        if hidden_act == "silu":
+            self.act_fn = SiluAndMul()
+        elif hidden_act == "fatrelu":
+            self.act_fn = FatreluAndMul(threshold=hidden_act_param)
+        else:
             raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
+                             "Only silu and fatrelu are supported for now.")
Comment on lines -166 to +173
Member
Nit: I think it would be nice to use get_act_fn() and register this as an activation function

_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
    "relu2": ReLUSquaredActivation(),
    "quick_gelu": QuickGELU(),
}
- however I recognize that this function requires the threshold to be piped through so it might not be worth it

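A rough sketch of what that registration could look like, purely as an illustration and not the actual vLLM API (the registry and get_act_and_mul_fn below are invented names), with the threshold passed through a small factory instead of a branch in the model file:

from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul

# Hypothetical registry of gated activations. Each entry is a factory that
# accepts the threshold, so "fatrelu" can consume it and "silu" can ignore it.
_GATED_ACT_REGISTRY = {
    "silu": lambda threshold: SiluAndMul(),
    "fatrelu": lambda threshold: FatreluAndMul(threshold=threshold),
}

def get_act_and_mul_fn(name: str, threshold: float = 0.0):
    if name not in _GATED_ACT_REGISTRY:
        raise ValueError(f"Unsupported activation: {name}.")
    return _GATED_ACT_REGISTRY[name](threshold)

# MiniCPMMLP.__init__ would then collapse to a single line:
# self.act_fn = get_act_and_mul_fn(hidden_act, threshold=hidden_act_param)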

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
@@ -304,6 +308,7 @@ def _init_ffn_block(self):
hidden_size=self.hidden_size,
intermediate_size=self.config.intermediate_size,
hidden_act=self.config.hidden_act,
hidden_act_param=getattr(self.config, "hidden_act_param", 0.),
quant_config=self.quant_config,
)
else:
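With the branch and the config plumbing above in place, a quick way to exercise the fatrelu path end to end is to load the checkpoint named in the docs change. This is a sketch rather than part of the PR; the prompt and sampling settings are arbitrary, and MiniCPM checkpoints generally need trust_remote_code.

from vllm import LLM, SamplingParams

# openbmb/MiniCPM-S-1B-sft uses the FATReLU activation, so this takes the new
# FatreluAndMul branch; the threshold is read from the checkpoint's
# hidden_act_param (falling back to 0.0 if the field is absent).
llm = LLM(model="openbmb/MiniCPM-S-1B-sft", trust_remote_code=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)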