11#include < c10/cuda/CUDAStream.h>
22#include < libtorchaudio/rnnt/gpu/gpu_transducer.h>
33#include < torch/csrc/inductor/aoti_torch/c/shim.h>
4- #include < torch/csrc/inductor/aoti_runtime/utils.h>
54#include < torch/csrc/stable/library.h>
65#include < torch/types.h>
7- #include < torch/csrc/inductor/aoti_torch/utils.h>
86
97namespace torchaudio {
108namespace rnnt {
@@ -115,7 +113,7 @@ std::tuple<Tensor, Tensor> compute(
115113 options.fusedLogSmax_ = fused_log_softmax;
116114
117115 int32_t logits_device_index;
118- aoti_torch_get_device_index (logits.get (), &logits_device_index);
116+ TORCH_ERROR_CODE_CHECK ( aoti_torch_get_device_index (logits.get (), &logits_device_index) );
119117
120118 TORCH_CHECK_EQ (logits_device, aoti_torch_device_type_cuda ());
121119
@@ -144,22 +142,16 @@ std::tuple<Tensor, Tensor> compute(
144142 TORCH_ERROR_CODE_CHECK (
145143 aoti_torch_empty_strided (1 , int_sizes, strides, aoti_torch_dtype_int32 (), logits_device, logits_device_index, &int_workspace));
146144
147- // torch::Tensor int_workspace = torch::empty(
148- // IntWorkspace::ComputeSizeFromOptions(options),
149- // torch::TensorOptions()
150- // .device(torch::aot_inductor::tensor_handle_to_tensor_pointer(logits.get())->device())
151- // .dtype(torch::ScalarType::Int));
152-
153145 AtenTensorHandle float_workspace;
154146 int64_t float_sizes[1 ] = {DtypeWorkspace<float >::ComputeSizeFromOptions (options)};
155147 TORCH_ERROR_CODE_CHECK (
156148 aoti_torch_empty_strided (1 , float_sizes, strides, aoti_torch_dtype_float32 (), logits_device, logits_device_index, &float_workspace));
157149
158150 int64_t float_numel;
159151 aoti_torch_get_numel (float_workspace, &float_numel);
160- // void *int_workspace_ptr;
161- // TORCH_ERROR_CODE_CHECK(
162- // aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr));
152+ void *int_workspace_ptr;
153+ TORCH_ERROR_CODE_CHECK (
154+ aoti_torch_get_data_ptr (int_workspace, &int_workspace_ptr));
163155 void *float_workspace_ptr;
164156 TORCH_ERROR_CODE_CHECK (
165157 aoti_torch_get_data_ptr (float_workspace, &float_workspace_ptr));
@@ -176,43 +168,43 @@ std::tuple<Tensor, Tensor> compute(
176168 /* int_size=*/ int_numel
177169 );
178170 at::cuda::stream_synchronize (options.stream_ );
179- // void *logit_ptr;
180- // aoti_torch_get_data_ptr(logits.get(), &logit_ptr);
181-
182- // void *target_ptr;
183- // aoti_torch_get_data_ptr(targets.get(), &target_ptr);
184-
185- // void *logit_len_ptr;
186- // aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr);
187-
188- // void *target_len_ptr;
189- // aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr);
190-
191- // void *costs_ptr;
192- // aoti_torch_get_data_ptr(costs, &costs_ptr);
193-
194- // void *grads_ptr;
195- // aoti_torch_get_data_ptr(gradients, &grads_ptr);
196-
197- // if (logits_dtype == aoti_torch_dtype_float32()) {
198- // Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
199- // /*workspace=*/workspace,
200- // /*logits=*/(float*)logit_ptr,
201- // /*targets=*/(int*)target_ptr,
202- // /*logit_lengths=*/(int*)logit_len_ptr,
203- // /*target_lengths=*/(int*)target_len_ptr,
204- // /*costs=*/(float*)costs_ptr,
205- // /*gradients=*/(float*)grads_ptr);
206- // } else {
207- // Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
208- // /*workspace=*/workspace,
209- // /*logits=*/(c10::Half*)logit_ptr,
210- // /*targets=*/(int*)target_ptr,
211- // /*logit_lengths=*/(int*)logit_len_ptr,
212- // /*target_lengths=*/(int*)target_len_ptr,
213- // /*costs=*/(c10::Half*)costs_ptr,
214- // /*gradients=*/(c10::Half*)grads_ptr);
215- // }
171+ void *logit_ptr;
172+ aoti_torch_get_data_ptr (logits.get (), &logit_ptr);
173+
174+ void *target_ptr;
175+ aoti_torch_get_data_ptr (targets.get (), &target_ptr);
176+
177+ void *logit_len_ptr;
178+ aoti_torch_get_data_ptr (logit_lengths.get (), &logit_len_ptr);
179+
180+ void *target_len_ptr;
181+ aoti_torch_get_data_ptr (target_lengths.get (), &target_len_ptr);
182+
183+ void *costs_ptr;
184+ aoti_torch_get_data_ptr (costs, &costs_ptr);
185+
186+ void *grads_ptr;
187+ aoti_torch_get_data_ptr (gradients, &grads_ptr);
188+
189+ if (logits_dtype == aoti_torch_dtype_float32 ()) {
190+ Compute</* DTYPE=*/ float , /* CAST_DTYPE=*/ float >(
191+ /* workspace=*/ workspace,
192+ /* logits=*/ (float *)logit_ptr,
193+ /* targets=*/ (int *)target_ptr,
194+ /* logit_lengths=*/ (int *)logit_len_ptr,
195+ /* target_lengths=*/ (int *)target_len_ptr,
196+ /* costs=*/ (float *)costs_ptr,
197+ /* gradients=*/ (float *)grads_ptr);
198+ } else {
199+ Compute</* DTYPE=*/ c10::Half, /* CAST_DTYPE=*/ float >(
200+ /* workspace=*/ workspace,
201+ /* logits=*/ (c10::Half*)logit_ptr,
202+ /* targets=*/ (int *)target_ptr,
203+ /* logit_lengths=*/ (int *)logit_len_ptr,
204+ /* target_lengths=*/ (int *)target_len_ptr,
205+ /* costs=*/ (c10::Half*)costs_ptr,
206+ /* gradients=*/ (c10::Half*)grads_ptr);
207+ }
216208
217209 return std::make_tuple (Tensor (costs), Tensor (gradients));
218210}
0 commit comments