@@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
5555 const at::Tensor out,
5656 const c10::optional<at::Tensor>& bias,
5757 bool silu_activation,
58+ int64_t pad_slot_id,
5859 const c10::optional<at::Tensor>& query_start_loc = std::nullopt ,
5960 const c10::optional<at::Tensor>& cache_indices = std::nullopt ,
6061 const c10::optional<at::Tensor>& has_initial_state = std::nullopt ) {
@@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
6667 params.dim = dim;
6768 params.seqlen = seqlen;
6869 params.width = width;
70+ params.pad_slot_id = pad_slot_id;
6971
7072 params.silu_activation = silu_activation;
7173
@@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase &params,
9092}
9193
9294
93- at::Tensor
94- causal_conv1d_fwd (const at::Tensor &x, const at::Tensor &weight,
95+ void causal_conv1d_fwd (const at::Tensor &x, const at::Tensor &weight,
9596 const c10::optional<at::Tensor> &bias_,
9697 const c10::optional<at::Tensor> &conv_states,
9798 const c10::optional<at::Tensor> &query_start_loc,
9899 const c10::optional<at::Tensor> &cache_indices,
99100 const c10::optional<at::Tensor> &has_initial_state,
100- bool silu_activation) {
101+ bool silu_activation,
102+ // used to identify padding entries if cache_indices provided
103+ // in case of padding, the kernel will return early
104+ int64_t pad_slot_id) {
101105 auto input_type = x.scalar_type ();
102106 auto weight_type = weight.scalar_type ();
103107 TORCH_CHECK (input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
153157 CHECK_SHAPE (cache_indices_, batch_size);
154158 }
155159
156- at::Tensor out = torch::empty_like (x) ;
160+ at::Tensor out = x ;
157161
158162 ConvParamsBase params;
159163 set_conv_params_fwd (params, batch_size, dim, seqlen, width, x, weight, out,
160164 bias_,
161165 silu_activation,
166+ pad_slot_id,
162167 query_start_loc,
163168 cache_indices,
164169 has_initial_state
@@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
183188 DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16 (x.scalar_type (), " causal_conv1d_fwd" , [&] {
184189 causal_conv1d_fwd_cuda<input_t , weight_t >(params, stream);
185190 });
186- return out;
187191}
188192
189193
190- at::Tensor
191- causal_conv1d_update (const at::Tensor &x,
194+ void causal_conv1d_update (const at::Tensor &x,
192195 const at::Tensor &conv_state,
193196 const at::Tensor &weight,
194197 const c10::optional<at::Tensor> &bias_,
195198 bool silu_activation,
196199 const c10::optional<at::Tensor> &cache_seqlens_,
197- const c10::optional<at::Tensor> &conv_state_indices_) {
200+ const c10::optional<at::Tensor> &conv_state_indices_,
201+ // used to identify padding entries if cache_indices provided
202+ // in case of padding, the kernel will return early
203+ int64_t pad_slot_id) {
198204 auto input_type = x.scalar_type ();
199205 auto weight_type = weight.scalar_type ();
200206 TORCH_CHECK (input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x,
227233 CHECK_SHAPE (bias, dim);
228234 }
229235
230- at::Tensor out = torch::empty_like (x) ;
236+ at::Tensor out = x ;
231237
232238 ConvParamsBase params;
233239 set_conv_params_fwd (params, batch_size, dim, seqlen, width, x, weight, out,
234240 bias_,
235- silu_activation);
241+ silu_activation,
242+ pad_slot_id);
236243 params.conv_state_ptr = conv_state.data_ptr ();
237244 params.conv_state_len = conv_state_len;
238245 // All stride are in elements, not bytes.
@@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
274281 DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16 (x.scalar_type (), " causal_conv1d_update" , [&] {
275282 causal_conv1d_update_cuda<input_t , weight_t >(params, stream);
276283 });
277- return out;
278284}
279285
280286template <int kNThreads_ , int kWidth_ , bool kIsVecLoad_ , typename input_t_, typename weight_t_>
@@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
340346 int * cache_indices = params.cache_indices_ptr == nullptr ? nullptr
341347 : reinterpret_cast <int *>(params.cache_indices_ptr );
342348 int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
343-
349+ // cache_index == params.pad_slot_id is defined as padding, so we exit early
350+ if (cache_index == params.pad_slot_id ){
351+ return ;
352+ }
344353 input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
345354 : reinterpret_cast <input_t *>(params.conv_states_ptr ) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride ;
346355
@@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
528537 const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
529538 ? batch_id
530539 : params.conv_state_indices_ptr [batch_id];
540+ // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
541+ if (conv_state_batch_coord == params.pad_slot_id ){
542+ return ;
543+ }
531544 input_t *conv_state = reinterpret_cast <input_t *>(params.conv_state_ptr )
532545 + conv_state_batch_coord * params.conv_state_batch_stride
533546 + channel_id * params.conv_state_c_stride ;
0 commit comments