Commit 5b8a1fd

[Model][Bugfix] Add FATReLU activation and support for openbmb/MiniCPM-S-1B-sft (#9396)
1 parent fb60ae9 commit 5b8a1fd

3 files changed, +37 -5 lines changed

docs/source/models/supported_models.rst

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ Text Generation
     -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
-    - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
+    - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.
     - ✅︎
     - ✅︎
   * - :code:`MiniCPM3ForCausalLM`

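With the new table entry, the sparse checkpoint loads through the same path as the other MiniCPM variants. A minimal usage sketch, not part of this commit (prompt and sampling settings are illustrative):

from vllm import LLM, SamplingParams

# MiniCPM checkpoints ship custom modeling code on the Hub, so
# trust_remote_code is typically required.
llm = LLM(model="openbmb/MiniCPM-S-1B-sft", trust_remote_code=True)

outputs = llm.generate(["What does activation sparsity mean?"],
                       SamplingParams(temperature=0.8, max_tokens=64))
print(outputs[0].outputs[0].text)
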
vllm/model_executor/layers/activation.py

Lines changed: 27 additions & 0 deletions
@@ -13,6 +13,33 @@
 from vllm.model_executor.utils import set_weight_attrs


+class FatreluAndMul(CustomOp):
+    """An activation function for FATReLU.
+
+    The function computes x -> FATReLU(x[:d]) * x[d:] where
+    d = x.shape[-1] // 2.
+    This is used in openbmb/MiniCPM-S-1B-sft.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    def __init__(self, threshold: float = 0.):
+        super().__init__()
+        self.threshold = threshold
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        x1 = x[..., :d]
+        x2 = x[..., d:]
+        x1 = F.threshold(x1, self.threshold, 0.0)
+        return x1 * x2
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_native(x)
+
+
 class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.

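For reference, the FATReLU gate above is a hard-threshold ReLU: F.threshold(x1, self.threshold, 0.0) keeps values strictly greater than the threshold and zeroes the rest, and the result scales the second half of the fused gate/up projection, mirroring SiluAndMul. A standalone sketch of the same computation in plain PyTorch (the function name is illustrative):

import torch
import torch.nn.functional as F

def fatrelu_and_mul(x: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Split the fused projection in half, hard-threshold the gate half
    # (x if x > threshold else 0), then scale the up half with it.
    d = x.shape[-1] // 2
    gate = F.threshold(x[..., :d], threshold, 0.0)
    return gate * x[..., d:]

x = torch.randn(4, 8)
ref = torch.where(x[..., :4] > 0.1, x[..., :4], torch.zeros_like(x[..., :4])) * x[..., 4:]
assert torch.allclose(fatrelu_and_mul(x, threshold=0.1), ref)

With threshold=0.0 this reduces to an ordinary ReLU gate; a positive threshold zeroes more of the gate, which is presumably why the sparse MiniCPM-S checkpoint carries a nonzero threshold in its config.
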
vllm/model_executor/models/minicpm.py

Lines changed: 9 additions & 4 deletions
@@ -33,7 +33,7 @@
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
-from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -152,6 +152,7 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
+        hidden_act_param: float,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -163,10 +164,13 @@ def __init__(
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config)
-        if hidden_act != "silu":
+        if hidden_act == "silu":
+            self.act_fn = SiluAndMul()
+        elif hidden_act == "fatrelu":
+            self.act_fn = FatreluAndMul(threshold=hidden_act_param)
+        else:
             raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
+                             "Only silu and fatrelu are supported for now.")

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
@@ -304,6 +308,7 @@ def _init_ffn_block(self):
             hidden_size=self.hidden_size,
             intermediate_size=self.config.intermediate_size,
             hidden_act=self.config.hidden_act,
+            hidden_act_param=getattr(self.config, "hidden_act_param", 0.),
             quant_config=self.quant_config,
         )
     else:

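Taken together, the MLP now picks its activation from the HF config: hidden_act selects the op and hidden_act_param (defaulting to 0.0 when the field is absent) supplies the FATReLU threshold. A condensed sketch of that plumbing, where build_act_fn is a hypothetical helper and config stands in for the loaded HF config object:

from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul

def build_act_fn(config):
    # Mirrors the selection added above: "silu" keeps the existing SwiGLU
    # path, while "fatrelu" (which the MiniCPM-S checkpoints are expected
    # to set) uses the new thresholded gate.
    threshold = getattr(config, "hidden_act_param", 0.)
    if config.hidden_act == "silu":
        return SiluAndMul()
    if config.hidden_act == "fatrelu":
        return FatreluAndMul(threshold=threshold)
    raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                     "Only silu and fatrelu are supported for now.")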