55import subprocess
66from setuptools import setup , find_packages
77import torch
8- from torch .utils .cpp_extension import BuildExtension , CUDAExtension , include_paths , CppExtension
8+ from torch .utils .cpp_extension import (
9+ BuildExtension ,
10+ CUDAExtension ,
11+ include_paths ,
12+ CppExtension ,
13+ )
914import os
1015import sys
1116
12- is_windows = sys .platform == ' win32'
17+ is_windows = sys .platform == " win32"
1318
1419try :
1520 version = (
1621 subprocess .check_output (["git" , "describe" , "--abbrev=0" , "--tags" ])
1722 .strip ()
1823 .decode ("utf-8" )
1924 )
20- except :
25+ except Exception :
2126 print ("Failed to retrieve the current version, defaulting to 0" )
2227 version = "0"
23- # If CPU_ONLY is defined
24- force_cpu_only = os .environ .get ("CPU_ONLY" , None ) is not None
25- use_cuda = torch .cuda ._is_compiled () if not force_cpu_only else False
28+
29+ # If WITH_CUDA is defined
30+ if os .environ .get ("WITH_CUDA" , "0" ) == "1" :
31+ use_cuda = True
32+ else :
33+ use_cuda = torch .cuda ._is_compiled ()
34+
35+
2636def set_torch_cuda_arch_list ():
27- """ Set the CUDA arch list according to the architectures the current torch installation was compiled for.
37+ """Set the CUDA arch list according to the architectures the current torch installation was compiled for.
2838 This function is a no-op if the environment variable TORCH_CUDA_ARCH_LIST is already set or if torch was not compiled with CUDA support.
2939 """
3040 if not os .environ .get ("TORCH_CUDA_ARCH_LIST" ):
@@ -35,20 +45,24 @@ def set_torch_cuda_arch_list():
3545 formatted_versions += "+PTX"
3646 os .environ ["TORCH_CUDA_ARCH_LIST" ] = formatted_versions
3747
48+
3849set_torch_cuda_arch_list ()
3950
40- extension_root = os .path .join ("torchmdnet" , "extensions" )
41- neighbor_sources = ["neighbors_cpu.cpp" ]
51+ extension_root = os .path .join ("torchmdnet" , "extensions" )
52+ neighbor_sources = ["neighbors_cpu.cpp" ]
4253if use_cuda :
4354 neighbor_sources .append ("neighbors_cuda.cu" )
44- neighbor_sources = [os .path .join (extension_root , "neighbors" , source ) for source in neighbor_sources ]
55+ neighbor_sources = [
56+ os .path .join (extension_root , "neighbors" , source ) for source in neighbor_sources
57+ ]
4558
4659ExtensionType = CppExtension if not use_cuda else CUDAExtension
4760extensions = ExtensionType (
48- name = 'torchmdnet.extensions.torchmdnet_extensions' ,
49- sources = [os .path .join (extension_root , "torchmdnet_extensions.cpp" )] + neighbor_sources ,
61+ name = "torchmdnet.extensions.torchmdnet_extensions" ,
62+ sources = [os .path .join (extension_root , "torchmdnet_extensions.cpp" )]
63+ + neighbor_sources ,
5064 include_dirs = include_paths (),
51- define_macros = [(' WITH_CUDA' , 1 )] if use_cuda else [],
65+ define_macros = [(" WITH_CUDA" , 1 )] if use_cuda else [],
5266)
5367
5468if __name__ == "__main__" :
@@ -58,8 +72,19 @@ def set_torch_cuda_arch_list():
5872 packages = find_packages (),
5973 ext_modules = [extensions ],
6074 cmdclass = {
61- 'build_ext' : BuildExtension .with_options (no_python_abi_suffix = True , use_ninja = False )},
75+ "build_ext" : BuildExtension .with_options (
76+ no_python_abi_suffix = True , use_ninja = False
77+ )
78+ },
6279 include_package_data = True ,
63- entry_points = {"console_scripts" : ["torchmd-train = torchmdnet.scripts.train:main" ]},
64- package_data = {"torchmdnet" : ["extensions/torchmdnet_extensions.so" ] if not is_windows else ["extensions/torchmdnet_extensions.dll" ]},
80+ entry_points = {
81+ "console_scripts" : ["torchmd-train = torchmdnet.scripts.train:main" ]
82+ },
83+ package_data = {
84+ "torchmdnet" : (
85+ ["extensions/torchmdnet_extensions.so" ]
86+ if not is_windows
87+ else ["extensions/torchmdnet_extensions.dll" ]
88+ )
89+ },
6590 )
0 commit comments