@@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x,
                      const at::Tensor &conv_state,
                      const at::Tensor &weight,
                      const c10::optional<at::Tensor> &bias_,
-                     bool silu_activation) {
+                     bool silu_activation,
+                     const c10::optional<at::Tensor> &conv_state_indices_) {
     auto input_type = x.scalar_type();
     auto weight_type = weight.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x,
     const int width = weight.size(-1);
 
     CHECK_SHAPE(x, batch_size, dim);
-    CHECK_SHAPE(conv_state, batch_size, dim, width);
     CHECK_SHAPE(weight, dim, width);
 
     TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
@@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x,
     params.conv_state_c_stride = conv_state.stride(1);
     params.conv_state_l_stride = conv_state.stride(2);
 
+    if (conv_state_indices_.has_value()) {
+        auto conv_state_indices = conv_state_indices_.value();
+        TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32);
+        TORCH_CHECK(conv_state_indices.is_cuda());
+        TORCH_CHECK(conv_state_indices.stride(0) == 1);
+        CHECK_SHAPE(conv_state_indices, batch_size);
+
+        int conv_state_entries = conv_state.size(0);
+        CHECK_SHAPE(conv_state, conv_state_entries, dim, width);
+
+        params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
+    } else {
+        CHECK_SHAPE(conv_state, batch_size, dim, width);
+        params.conv_state_indices_ptr = nullptr;
+    }
+
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)x.get_device()};
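
On the host side, the new optional argument lets a caller keep a conv state cache with more entries than the live batch and pick a slot per row, while passing no indices preserves the original behavior (conv_state checked against batch_size, dim, width). A minimal caller-side sketch, where max_entries, the literal index values, and the bias tensor are illustrative assumptions, not names from this diff:

    // Hypothetical setup: a cache of max_entries slots; conv_state_indices maps
    // batch row i to the cache slot it reads and updates in place. Here rows
    // 0..2 use slots 4, 0, and 2 (so batch_size == 3 in this example).
    at::Tensor conv_state = torch::zeros({max_entries, dim, width}, x.options());
    at::Tensor conv_state_indices =
        torch::tensor({4, 0, 2}, torch::dtype(torch::kInt32).device(x.device()));
    causal_conv1d_update(x, conv_state, weight, bias,
                         /*silu_activation=*/true, conv_state_indices);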
@@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
     const int channel_id = blockIdx.y * kNThreads + tidx;
     input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
         + channel_id * params.x_c_stride;
-    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
+
+    // If params.conv_state_indices_ptr is set, the conv state is gathered from the conv
+    // state tensor along the batch axis at that index; otherwise it is indexed by batch_id.
+    const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
+        ? batch_id
+        : params.conv_state_indices_ptr[batch_id];
+    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+        + conv_state_batch_coord * params.conv_state_batch_stride
         + channel_id * params.conv_state_c_stride;
+
     weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
     input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
         + channel_id * params.out_c_stride;
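
Inside the kernel, the whole change reduces to one coordinate lookup per batch row before the usual stride arithmetic. A standalone restatement of that selection logic, as a sketch for clarity rather than code from the diff:

    #include <cstdint>

    // Identity mapping when no index tensor was supplied, so callers that leave
    // conv_state_indices_ptr as nullptr keep the old batch_id-indexed behavior.
    inline int conv_state_batch_coord(const int32_t *indices, int batch_id) {
        return indices == nullptr ? batch_id : indices[batch_id];
    }

Note that x and out are still addressed by batch_id; only the conv state pointer goes through the gather.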