diff --git a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py index aa49435ded16..90c2dd545ec5 100644 --- a/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py @@ -27,6 +27,7 @@ from nemo.collections.asr.parts.utils import rnnt_utils from nemo.core.utils.cuda_python_utils import ( check_cuda_python_cuda_graphs_conditional_nodes_supported, + checked_graph, cu_call, run_nvrtc, with_conditional_node, @@ -174,7 +175,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"), + checked_graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"), ): # This is failing... self.f = torch.zeros( diff --git a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py index c0783c301c44..7dfb4a020c83 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py @@ -25,6 +25,7 @@ from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.core.utils.cuda_python_utils import ( check_cuda_python_cuda_graphs_conditional_nodes_supported, + checked_graph, cu_call, run_nvrtc, with_conditional_node, @@ -630,7 +631,7 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph( + checked_graph( self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local" ), ): @@ -639,7 +640,7 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - 
torch.cuda.graph( + checked_graph( self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local" ), ): @@ -649,7 +650,7 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph( + checked_graph( self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local" ), ): @@ -658,7 +659,7 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph( + checked_graph( self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local" ), ): @@ -672,7 +673,7 @@ def _full_graph_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"), + checked_graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"), ): self._before_outer_loop() diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py index 4132c453d570..0255b6f2d831 100644 --- a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py @@ -26,6 +26,7 @@ from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs from nemo.core.utils.cuda_python_utils import ( check_cuda_python_cuda_graphs_conditional_nodes_supported, + checked_graph, cu_call, run_nvrtc, with_conditional_node, @@ -691,7 +692,7 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph( + checked_graph( self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local" ), ): @@ -700,7 +701,7 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - 
torch.cuda.graph( + checked_graph( self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local" ), ): @@ -710,7 +711,7 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph( + checked_graph( self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local" ), ): @@ -719,7 +720,7 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph( + checked_graph( self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local" ), ): @@ -734,7 +735,7 @@ def _full_graph_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"), + checked_graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"), ): self._before_outer_loop() diff --git a/nemo/core/utils/cuda_python_utils.py b/nemo/core/utils/cuda_python_utils.py index eb8897df0797..253b69b19436 100644 --- a/nemo/core/utils/cuda_python_utils.py +++ b/nemo/core/utils/cuda_python_utils.py @@ -95,6 +95,28 @@ def cu_call(f_call_out): return tuple(others) +def cuda_python_conditional_node_cooperative_kernels_supported(): + """ + Returns true if cuda-python is installed and CUDA driver 12.6 or newer is + installed. Before this CUDA driver version, cooperative kernels could not run + within cuda graph conditional nodes. 
+ """ + try: + check_cuda_python_cuda_graphs_conditional_nodes_supported() + except: + return False + else: + from cuda import cuda + + error, driver_version = cuda.cuDriverGetVersion() + if error != cuda.CUresult.CUDA_SUCCESS: + raise ImportError(f"cuDriverGetVersion() returned {cuda.cuGetErrorString(error)}") + driver_version_major = driver_version // 1000 + driver_version_minor = (driver_version % 1000) // 10 + driver_version = (driver_version_major, driver_version_minor) + return driver_version >= (12, 6) + + @contextlib.contextmanager def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device): """ @@ -219,3 +241,19 @@ def run_nvrtc(kernel_string: str, kernel_name: bytes, program_name: bytes): assert_drv(err) return kernel + + +@contextlib.contextmanager +def checked_graph(*args, **kwargs): + """ + Wrapper around torch.cuda.graph that checks for common errors that are too vague for an end user to diagnose based on the error message. + """ + try: + with torch.cuda.graph(*args, **kwargs): + yield + except RuntimeError as err: + if "CUDA error: invalid argument" in str(err): + raise RuntimeError( + "CUDA Graph capture failed. It is likely that you are calling a cooperative kernel in your RNN-T or TDT prediction network. Cooperative kernels are not allowed inside the bodies of CUDA Graph conditional nodes until CUDA 12.6. Please update to CUDA 12.6. File an issue if that still does not work." 
+ ) from err + raise diff --git a/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py b/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py index 31fe822573ce..696cc8f2aa41 100644 --- a/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py +++ b/tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py @@ -20,7 +20,10 @@ from omegaconf import open_dict from nemo.collections.asr.models import ASRModel -from nemo.core.utils.cuda_python_utils import skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported +from nemo.core.utils.cuda_python_utils import ( + cuda_python_conditional_node_cooperative_kernels_supported, + skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported, +) @pytest.fixture(scope="module") @@ -53,9 +56,10 @@ def stt_en_fastconformer_transducer_large(): 8, True, marks=pytest.mark.xfail( - reason="""Cannot instantiate the -body cuda graph of a conditional node with a persistent kernel (in this case, -a persistent LSTM), which is triggered in cudnn by using a batch size of 8.""" + not cuda_python_conditional_node_cooperative_kernels_supported(), + reason="""Cannot instantiate the +body cuda graph of a conditional node with a persistent kernel (in this case, +a persistent LSTM), which is triggered in cudnn by using a batch size of 8.""", ), ), ],