Commit ded2bb6

tlrmchlsmth authored and LeiWang1999 committed
[Kernel] Change interface to Mamba causal_conv1d_update for continuous batching (vllm-project#8012)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent 8521cf1 commit ded2bb6
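
In short: causal_conv1d_update gains an optional conv_state_indices argument. When it is supplied, the kernel looks up each sequence's convolution state at an arbitrary row of a larger state tensor instead of assuming that row i belongs to batch element i, so under continuous batching the Mamba conv states for the current batch no longer need to form a contiguous tensor.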

File tree

7 files changed: +114, -16 lines


csrc/mamba/causal_conv1d/causal_conv1d.cu

Lines changed: 27 additions & 3 deletions
@@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x,
                      const at::Tensor &conv_state,
                      const at::Tensor &weight,
                      const c10::optional<at::Tensor> &bias_,
-                     bool silu_activation) {
+                     bool silu_activation,
+                     const c10::optional<at::Tensor> &conv_state_indices_) {
     auto input_type = x.scalar_type();
     auto weight_type = weight.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x,
     const int width = weight.size(-1);
 
     CHECK_SHAPE(x, batch_size, dim);
-    CHECK_SHAPE(conv_state, batch_size, dim, width);
     CHECK_SHAPE(weight, dim, width);
 
     TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
@@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x,
     params.conv_state_c_stride = conv_state.stride(1);
     params.conv_state_l_stride = conv_state.stride(2);
 
+    if (conv_state_indices_.has_value()) {
+        auto conv_state_indices = conv_state_indices_.value();
+        TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32);
+        TORCH_CHECK(conv_state_indices.is_cuda());
+        TORCH_CHECK(conv_state_indices.stride(0) == 1);
+        CHECK_SHAPE(conv_state_indices, batch_size);
+
+        int conv_state_entries = conv_state.size(0);
+        CHECK_SHAPE(conv_state, conv_state_entries, dim, width);
+
+        params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
+    } else {
+        CHECK_SHAPE(conv_state, batch_size, dim, width);
+        params.conv_state_indices_ptr = nullptr;
+    }
+
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)x.get_device()};
@@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
     const int channel_id = blockIdx.y * kNThreads + tidx;
     input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
         + channel_id * params.x_c_stride;
-    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
+
+    // If params.conv_state_indices_ptr is set, then the conv state is gathered from the conv state tensor
+    // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
+    const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
+        ? batch_id
+        : params.conv_state_indices_ptr[batch_id];
+    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+        + conv_state_batch_coord * params.conv_state_batch_stride
         + channel_id * params.conv_state_c_stride;
+
     weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
     input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
         + channel_id * params.out_c_stride;
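
For orientation, here is a minimal pure-PyTorch sketch of what the updated kernel computes for one decode step, assuming the standard causal-conv1d single-token update semantics (shift the state window left, append the incoming token, dot with the filter taps). The helper name causal_conv1d_update_gather_ref is hypothetical and for illustration only; it is not part of this commit.

import torch
import torch.nn.functional as F

def causal_conv1d_update_gather_ref(x, conv_state, weight, bias=None,
                                    activation=None, conv_state_indices=None):
    # x: (batch, dim); conv_state: (entries, dim, width); weight: (dim, width)
    # conv_state_indices: (batch,) int32, or None for the identity mapping
    idx = (torch.arange(x.shape[0], device=x.device)
           if conv_state_indices is None else conv_state_indices.long())
    state = conv_state[idx]                        # gather along the batch axis
    state = torch.roll(state, shifts=-1, dims=-1)  # drop the oldest tap
    state[..., -1] = x                             # append the new token
    conv_state[idx] = state                        # write the state back in place
    out = torch.einsum("bdw,dw->bd", state.float(), weight.float())
    if bias is not None:
        out = out + bias.float()
    out = F.silu(out) if activation in ("silu", "swish") else out
    return out.to(x.dtype)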

csrc/mamba/causal_conv1d/causal_conv1d.h

Lines changed: 4 additions & 0 deletions
@@ -36,6 +36,10 @@ struct ConvParamsBase {
 
     void *__restrict__ conv_state_ptr;
 
+    // For the continuous batching case. Makes it so that the mamba state for
+    // the current batch doesn't need to be a contiguous tensor.
+    int32_t *__restrict__ conv_state_indices_ptr;
+
     void *__restrict__ seq_idx_ptr;
 
     // No __restrict__ since initial_states could be the same as final_states.

csrc/ops.h

Lines changed: 4 additions & 5 deletions
@@ -222,11 +222,10 @@ std::vector<torch::Tensor> selective_scan_fwd(
     const c10::optional<torch::Tensor>& index_,
     const c10::optional<torch::Tensor>& x);
 
-at::Tensor causal_conv1d_update(const at::Tensor& x,
-                                const at::Tensor& conv_state,
-                                const at::Tensor& weight,
-                                const c10::optional<at::Tensor>& bias_,
-                                bool silu_activation);
+at::Tensor causal_conv1d_update(
+    const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
+    const c10::optional<at::Tensor>& bias, bool silu_activation,
+    const c10::optional<at::Tensor>& conv_state_indices);
 
 at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                              const c10::optional<at::Tensor>& bias_,

csrc/torch_bindings.cpp

Lines changed: 3 additions & 2 deletions
@@ -279,8 +279,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "causal_conv1d_update(Tensor! x,"
       "Tensor! conv_state,"
      "Tensor! weight,"
-      "Tensor? bias_,"
-      "bool silu_activation) -> Tensor");
+      "Tensor? bias,"
+      "bool silu_activation,"
+      "Tensor? conv_state_indices) -> Tensor");
   ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
 
   ops.def(
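
For reference, Tensor? in a library schema string marks an argument that may be None from the Python side. The new conv_state_indices slot has no default value in the schema, so the Python wrapper in vllm/_custom_ops.py below is updated to forward it explicitly (callers pass None when no gather is needed).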

tests/kernels/test_causal_conv1d.py

Lines changed: 58 additions & 0 deletions
@@ -203,3 +203,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
 
     assert torch.equal(conv_state, conv_state_ref)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("silu_activation", [False, True])
+@pytest.mark.parametrize("has_bias", [False, True])
+@pytest.mark.parametrize("seqlen", [1, 4, 5])
+@pytest.mark.parametrize("width", [2, 3, 4])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
+                                                silu_activation, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
+    if itype == torch.bfloat16:
+        rtol, atol = 1e-2, 5e-2
+
+    # set seed
+    torch.random.manual_seed(0)
+    batch = 64
+
+    x = torch.randn(batch, dim, device=device, dtype=itype)
+
+    total_entries = 10 * batch
+    conv_state = torch.randn(total_entries,
+                             dim,
+                             width,
+                             device=device,
+                             dtype=itype)
+    conv_state_indices = torch.randperm(total_entries)[:batch].to(
+        dtype=torch.int32, device=device)
+
+    weight = torch.randn(dim,
+                         width,
+                         device=device,
+                         dtype=itype,
+                         requires_grad=True)
+    if has_bias:
+        bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
+    else:
+        bias = None
+    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
+    activation = None if not silu_activation else "silu"
+    out = causal_conv1d_update(x,
+                               conv_state,
+                               weight,
+                               bias,
+                               activation=activation,
+                               conv_state_indices=conv_state_indices)
+    out_ref = causal_conv1d_update_ref(x,
+                                       conv_state_ref,
+                                       weight,
+                                       bias,
+                                       activation=activation)
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
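
The new test mirrors the intended continuous-batching usage: the conv states for all schedulable sequences live in one pool of total_entries = 10 * batch rows, a random non-contiguous subset of rows is selected through conv_state_indices, and both the kernel output and the in-place state update are checked against a reference run on the gathered states. Note that x is a single token per sequence, shaped (batch, dim); the seqlen parametrization does not change the input shape here.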

vllm/_custom_ops.py

Lines changed: 10 additions & 4 deletions
@@ -768,11 +768,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                                            silu_activation)
 
 
-def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
-                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
-                         silu_activation: bool) -> torch.Tensor:
+def causal_conv1d_update(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias_: Optional[torch.Tensor],
+    silu_activation: bool,
+    conv_state_indices: Optional[torch.Tensor],
+) -> torch.Tensor:
     return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
-                                             silu_activation)
+                                             silu_activation,
+                                             conv_state_indices)
 
 
 def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,

vllm/model_executor/layers/mamba/ops/causal_conv1d.py

Lines changed: 8 additions & 2 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) 2024, Tri Dao.
+# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
 
 from typing import Optional
 
@@ -70,17 +71,22 @@ def causal_conv1d_update(x: torch.Tensor,
                          conv_state: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None,
-                         activation: Optional[str] = None):
+                         activation: Optional[str] = None,
+                         conv_state_indices: Optional[torch.Tensor] = None):
     """
     x: (batch, dim)
     conv_state: (batch, dim, width)
     weight: (dim, width)
     bias: (dim,)
+    conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+        and we are selecting the batch coords specified by conv_state_indices.
+        Useful for a continuous batching scenario.
 
     out: (batch, dim)
     """
     if activation not in [None, "silu", "swish"]:
         raise NotImplementedError("activation must be None, silu, or swish")
     activation_bool = activation in ["silu", "swish"]
     return ops.causal_conv1d_update(x, conv_state, weight, bias,
-                                    activation_bool)
+                                    activation_bool, conv_state_indices)
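
A hypothetical end-to-end call using the new keyword, modeled on the test above; the pool size, slot numbers, and dimensions are illustrative assumptions, not values taken from this commit:

import torch
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update

device, dtype = "cuda", torch.float16
dim, width, max_seqs = 4096, 4, 256

# One persistent conv-state slot per schedulable sequence.
conv_state_pool = torch.zeros(max_seqs, dim, width, device=device, dtype=dtype)
weight = torch.randn(dim, width, device=device, dtype=dtype)

# One decode step for three running sequences that occupy non-contiguous slots.
slots = torch.tensor([17, 3, 201], dtype=torch.int32, device=device)
x = torch.randn(slots.numel(), dim, device=device, dtype=dtype)
out = causal_conv1d_update(x, conv_state_pool, weight,
                           activation="silu", conv_state_indices=slots)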
