12
12
IS_WINDOWS ,
13
13
is_suppressed_dll_file ,
14
14
)
15
- from cuda .pathfinder ._utils .find_sub_dirs import find_sub_dirs_all_sitepackages
15
+ from cuda .pathfinder ._utils .find_sub_dirs import find_sub_dirs , find_sub_dirs_all_sitepackages
16
16
17
17
18
18
def _no_such_file_in_sub_dirs (
@@ -28,18 +28,21 @@ def _no_such_file_in_sub_dirs(
28
28
def _find_so_using_nvidia_lib_dirs (
29
29
libname : str , so_basename : str , error_messages : list [str ], attachments : list [str ]
30
30
) -> Optional [str ]:
31
- nvidia_sub_dirs = ("nvidia" , "*" , "nvvm" , "lib64" ) if libname == "nvvm" else ("nvidia" , "*" , "lib" )
32
31
file_wild = so_basename + "*"
33
- for lib_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
34
- # First look for an exact match
35
- so_name = os .path .join (lib_dir , so_basename )
36
- if os .path .isfile (so_name ):
37
- return so_name
38
- # Look for a versioned library
39
- # Using sort here mainly to make the result deterministic.
40
- for so_name in sorted (glob .glob (os .path .join (lib_dir , file_wild ))):
32
+ nvidia_sub_dirs_list : list [tuple [str , ...]] = [("nvidia" , "*" , "lib" )] # works also for CTK 13 nvvm
33
+ if libname == "nvvm" :
34
+ nvidia_sub_dirs_list .append (("nvidia" , "*" , "nvvm" , "lib64" )) # CTK 12
35
+ for nvidia_sub_dirs in nvidia_sub_dirs_list :
36
+ for lib_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
37
+ # First look for an exact match
38
+ so_name = os .path .join (lib_dir , so_basename )
41
39
if os .path .isfile (so_name ):
42
40
return so_name
41
+ # Look for a versioned library
42
+ # Using sort here mainly to make the result deterministic.
43
+ for so_name in sorted (glob .glob (os .path .join (lib_dir , file_wild ))):
44
+ if os .path .isfile (so_name ):
45
+ return so_name
43
46
_no_such_file_in_sub_dirs (nvidia_sub_dirs , file_wild , error_messages , attachments )
44
47
return None
45
48
@@ -56,11 +59,17 @@ def _find_dll_under_dir(dirpath: str, file_wild: str) -> Optional[str]:
56
59
def _find_dll_using_nvidia_bin_dirs (
57
60
libname : str , lib_searched_for : str , error_messages : list [str ], attachments : list [str ]
58
61
) -> Optional [str ]:
59
- nvidia_sub_dirs = ("nvidia" , "*" , "nvvm" , "bin" ) if libname == "nvvm" else ("nvidia" , "*" , "bin" )
60
- for bin_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
61
- dll_name = _find_dll_under_dir (bin_dir , lib_searched_for )
62
- if dll_name is not None :
63
- return dll_name
62
+ nvidia_sub_dirs_list : list [tuple [str , ...]] = [
63
+ ("nvidia" , "*" , "bin" ), # CTK 12
64
+ ("nvidia" , "*" , "bin" , "*" ), # CTK 13, e.g. site-packages\nvidia\cu13\bin\x86_64\
65
+ ]
66
+ if libname == "nvvm" :
67
+ nvidia_sub_dirs_list .append (("nvidia" , "*" , "nvvm" , "bin" )) # Only for CTK 12
68
+ for nvidia_sub_dirs in nvidia_sub_dirs_list :
69
+ for bin_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
70
+ dll_name = _find_dll_under_dir (bin_dir , lib_searched_for )
71
+ if dll_name is not None :
72
+ return dll_name
64
73
_no_such_file_in_sub_dirs (nvidia_sub_dirs , lib_searched_for , error_messages , attachments )
65
74
return None
66
75
@@ -76,21 +85,29 @@ def _find_lib_dir_using_cuda_home(libname: str) -> Optional[str]:
76
85
cuda_home = _get_cuda_home ()
77
86
if cuda_home is None :
78
87
return None
79
- subdirs : tuple [str , ...]
88
+ subdirs_list : tuple [tuple [ str , ...] , ...]
80
89
if IS_WINDOWS :
81
- subdirs = (os .path .join ("nvvm" , "bin" ),) if libname == "nvvm" else ("bin" ,)
90
+ if libname == "nvvm" : # noqa: SIM108
91
+ subdirs_list = (
92
+ ("nvvm" , "bin" , "*" ), # CTK 13
93
+ ("nvvm" , "bin" ), # CTK 12
94
+ )
95
+ else :
96
+ subdirs_list = (
97
+ ("bin" , "x64" ), # CTK 13
98
+ ("bin" ,), # CTK 12
99
+ )
82
100
else :
83
- subdirs = (
84
- ( os . path . join ("nvvm" , "lib64" ),)
85
- if libname == "nvvm"
86
- else (
87
- "lib64" , # CTK
88
- "lib" , # Conda
101
+ if libname == "nvvm" : # noqa: SIM108
102
+ subdirs_list = ( ("nvvm" , "lib64" ),)
103
+ else :
104
+ subdirs_list = (
105
+ ( "lib64" ,) , # CTK
106
+ ( "lib" ,) , # Conda
89
107
)
90
- )
91
- for subdir in subdirs :
92
- dirname = os .path .join (cuda_home , subdir )
93
- if os .path .isdir (dirname ):
108
+ for sub_dirs in subdirs_list :
109
+ dirname : str # work around bug in mypy
110
+ for dirname in find_sub_dirs ((cuda_home ,), sub_dirs ):
94
111
return dirname
95
112
return None
96
113
0 commit comments