Make build from source easier on unsupported hardware #632

@gau-nernst

Description

Compiling the CUDA extensions on unsupported hardware (i.e., GPUs whose compute capability is too low) fails with the errors below. Reported by @danielpatrickhug on the CUDA-MODE Discord.

Error logs
  ⚡ main ~/ao python setup.py develop
  running develop
  /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/command/develop.py:42: EasyInstallDeprecationWarning: easy_install command is deprecated.
  !!
  
          ********************************************************************************
          Please avoid running ``setup.py`` and ``easy_install``.
          Instead, use pypa/build, pypa/installer or other
          standards-based tools.
  
          See https://github.com/pypa/setuptools/issues/917 for details.
          ********************************************************************************
  
  !!
    easy_install.initialize_options(self)
  /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
  !!
  
          ********************************************************************************
          Please avoid running ``setup.py`` directly.
          Instead, use pypa/build, pypa/installer or other
          standards-based tools.
  
          See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
          ********************************************************************************
  
  !!
    self.initialize_options()
  running egg_info
  writing torchao.egg-info/PKG-INFO
  writing dependency_links to torchao.egg-info/dependency_links.txt
  writing requirements to torchao.egg-info/requires.txt
  writing top-level names to torchao.egg-info/top_level.txt
  reading manifest file 'torchao.egg-info/SOURCES.txt'
  adding license file 'LICENSE'
  writing manifest file 'torchao.egg-info/SOURCES.txt'
  running build_ext
  /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/cpp_extension.py:424: UserWarning: There are no g++ version bounds defined for CUDA version 12.1
    warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')
  building 'torchao._C' extension
  creating /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312
  creating /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao
  creating /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc
  creating /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda
  creating /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/fp6_llm
  creating /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/tensor_core_tiled_layout
  Emitting ninja build file /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/build.ninja...
  Compiling objects...
  Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
  [1/5] c++ -MMD -MF /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/tensor_core_tiled_layout.o.d -pthread -B /home/zeus/miniconda3/envs/cloudspace/compiler_compat -fno-strict-overflow -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/zeus/miniconda3/envs/cloudspace/include -fPIC -O2 -isystem /home/zeus/miniconda3/envs/cloudspace/include -fPIC -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/TH -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zeus/miniconda3/envs/cloudspace/include/python3.12 -c -c /teamspace/studios/this_studio/ao/torchao/csrc/tensor_core_tiled_layout.cpp -o /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/tensor_core_tiled_layout.o -O3 -fdiagnostics-color=always -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17
  [2/5] c++ -MMD -MF /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/fp6_llm.o.d -pthread -B /home/zeus/miniconda3/envs/cloudspace/compiler_compat -fno-strict-overflow -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/zeus/miniconda3/envs/cloudspace/include -fPIC -O2 -isystem /home/zeus/miniconda3/envs/cloudspace/include -fPIC -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/TH -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zeus/miniconda3/envs/cloudspace/include/python3.12 -c -c /teamspace/studios/this_studio/ao/torchao/csrc/fp6_llm.cpp -o /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/fp6_llm.o -O3 -fdiagnostics-color=always -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17
  [3/5] c++ -MMD -MF /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/init.o.d -pthread -B /home/zeus/miniconda3/envs/cloudspace/compiler_compat -fno-strict-overflow -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/zeus/miniconda3/envs/cloudspace/include -fPIC -O2 -isystem /home/zeus/miniconda3/envs/cloudspace/include -fPIC -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/TH -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zeus/miniconda3/envs/cloudspace/include/python3.12 -c -c /teamspace/studios/this_studio/ao/torchao/csrc/init.cpp -o /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/init.o -O3 -fdiagnostics-color=always -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17
  [4/5] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.o.d -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/TH -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zeus/miniconda3/envs/cloudspace/include/python3.12 -c -c /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu -o /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -t=0 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_90,code=sm_90 -std=c++17
  FAILED: /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.o 
  /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.o.d -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/TH -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zeus/miniconda3/envs/cloudspace/include/python3.12 -c -c /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu -o /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -t=0 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_90,code=sm_90 -std=c++17
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(127): error: identifier "__bfloat162bfloat162" is undefined
          __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
                                  ^
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(132): error: no suitable user-defined conversion from "__nv_bfloat162" to "__half2" exists
            reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
                                                                         ^
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(132): error: no suitable user-defined conversion from "__nv_bfloat162" to "__half2" exists
            reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
                                                                                             ^
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(132): error: no suitable user-defined conversion from "__nv_bfloat162" to "__half2" exists
            reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
                                                                                                     ^
  
  4 errors detected in the compilation of "/teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu".
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(127): error: identifier "__bfloat162bfloat162" is undefined
          __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
                                  ^
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(132): error: identifier "__hfma2" is undefined
            reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
                                                                 ^
  
  2 errors detected in the compilation of "/teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu".
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(127): error: identifier "__bfloat162bfloat162" is undefined
          __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
                                  ^
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(132): error: no suitable user-defined conversion from "__nv_bfloat162" to "__half2" exists
            reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
                                                                         ^
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(132): error: no suitable user-defined conversion from "__nv_bfloat162" to "__half2" exists
            reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
                                                                                             ^
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(132): error: no suitable user-defined conversion from "__nv_bfloat162" to "__half2" exists
            reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
                                                                                                     ^
  
  4 errors detected in the compilation of "/teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu".
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(127): error: identifier "__bfloat162bfloat162" is undefined
          __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
                                  ^
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(132): error: no suitable user-defined conversion from "__nv_bfloat162" to "__half2" exists
            reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
                                                                         ^
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(132): error: no suitable user-defined conversion from "__nv_bfloat162" to "__half2" exists
            reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
                                                                                             ^
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu(132): error: no suitable user-defined conversion from "__nv_bfloat162" to "__half2" exists
            reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2);
                                                                                                     ^
  
  4 errors detected in the compilation of "/teamspace/studios/this_studio/ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu".
  [5/5] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/fp6_llm/fp6_linear.o.d -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/TH -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zeus/miniconda3/envs/cloudspace/include/python3.12 -c -c /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu -o /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/fp6_llm/fp6_linear.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -t=0 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_90,code=sm_90 -std=c++17
  FAILED: /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/fp6_llm/fp6_linear.o 
  /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/fp6_llm/fp6_linear.o.d -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/TH -I/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zeus/miniconda3/envs/cloudspace/include/python3.12 -c -c /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu -o /teamspace/studios/this_studio/ao/build/temp.linux-x86_64-cpython-312/torchao/csrc/cuda/fp6_llm/fp6_linear.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -t=0 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_90,code=sm_90 -std=c++17
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh(56): error: identifier "__hmul" is undefined
        output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale);
                                     ^
            detected during:
              instantiation of "void Dequant_32FP6_4Way<EXPONENT,MANTISSA>(uint32_t (*)[4], uint32_t *, uint32_t *, uint32_t *, uint32_t *) [with EXPONENT=3, MANTISSA=2]" at line 60 of /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/utils_core.cuh
              instantiation of "void initialize_mma_slice<TilingConfig,EXPONENT,MANTISSA>(uint32_t (*)[4], uint32_t (*)[4], uint32_t *, uint32_t *, uint32_t *, half (*)[72], uint32_t *) [with TilingConfig=TilingConfig<4, 1, 1>, EXPONENT=3, MANTISSA=2]" at line 150 of /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
              instantiation of "void QUANT_GEMM_Kernel<TilingConfig,OutputDataType,EXPONENT,MANTISSA>(const uint4 *, const half *, const half *, OutputDataType *, size_t, size_t, size_t, int) [with TilingConfig=TilingConfig<4, 1, 1>, OutputDataType=half, EXPONENT=3, MANTISSA=2]" at line 41 of /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
              instantiation of "void Kernel_Ex<TilingConfig,OutputDataType,EXPONENT,MANTISSA>(cudaStream_t, const uint4 *, const half *, const half *, OutputDataType *, size_t, size_t, size_t, int) [with TilingConfig=TilingConfig<4, 1, 1>, OutputDataType=half, EXPONENT=3, MANTISSA=2]" at line 83 of /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
              instantiation of "cudaError_t fpx_linear_kernel<EXPONENT,MANTISSA>(cudaStream_t, const uint4 *, const half *, const half *, half *, size_t, size_t, size_t, float *, int) [with EXPONENT=3, MANTISSA=2]" at line 176 of /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
  
  /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh(56): error: identifier "__hmul" is undefined
        output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale);
                             ^
            detected during:
              instantiation of "void Dequant_32FP6_4Way<EXPONENT,MANTISSA>(uint32_t (*)[4], uint32_t *, uint32_t *, uint32_t *, uint32_t *) [with EXPONENT=3, MANTISSA=2]" at line 60 of /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/utils_core.cuh
              instantiation of "void initialize_mma_slice<TilingConfig,EXPONENT,MANTISSA>(uint32_t (*)[4], uint32_t (*)[4], uint32_t *, uint32_t *, uint32_t *, half (*)[72], uint32_t *) [with TilingConfig=TilingConfig<4, 1, 1>, EXPONENT=3, MANTISSA=2]" at line 150 of /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
              instantiation of "void QUANT_GEMM_Kernel<TilingConfig,OutputDataType,EXPONENT,MANTISSA>(const uint4 *, const half *, const half *, OutputDataType *, size_t, size_t, size_t, int) [with TilingConfig=TilingConfig<4, 1, 1>, OutputDataType=half, EXPONENT=3, MANTISSA=2]" at line 41 of /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
              instantiation of "void Kernel_Ex<TilingConfig,OutputDataType,EXPONENT,MANTISSA>(cudaStream_t, const uint4 *, const half *, const half *, OutputDataType *, size_t, size_t, size_t, int) [with TilingConfig=TilingConfig<4, 1, 1>, OutputDataType=half, EXPONENT=3, MANTISSA=2]" at line 83 of /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
              instantiation of "cudaError_t fpx_linear_kernel<EXPONENT,MANTISSA>(cudaStream_t, const uint4 *, const half *, const half *, half *, size_t, size_t, size_t, float *, int) [with EXPONENT=3, MANTISSA=2]" at line 176 of /teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu
  
  2 errors detected in the compilation of "/teamspace/studios/this_studio/ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu".
  ninja: build stopped: subcommand failed.
  Traceback (most recent call last):
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2105, in _run_ninja_build
      subprocess.run(
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/subprocess.py", line 571, in run
      raise CalledProcessError(retcode, process.args,
  subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
  
  The above exception was the direct cause of the following exception:
  
  Traceback (most recent call last):
    File "/teamspace/studios/this_studio/ao/setup.py", line 126, in <module>
      setup(
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/__init__.py", line 108, in setup
      return distutils.core.setup(**attrs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/core.py", line 184, in setup
      return run_commands(dist)
             ^^^^^^^^^^^^^^^^^^
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/core.py", line 200, in run_commands
      dist.run_commands()
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/dist.py", line 970, in run_commands
      self.run_command(cmd)
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/dist.py", line 945, in run_command
      super().run_command(command)
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/dist.py", line 989, in run_command
      cmd_obj.run()
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/command/develop.py", line 36, in run
      self.install_for_development()
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/command/develop.py", line 113, in install_for_development
      self.run_command('build_ext')
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/cmd.py", line 316, in run_command
      self.distribution.run_command(command)
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/dist.py", line 945, in run_command
      super().run_command(command)
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/dist.py", line 989, in run_command
      cmd_obj.run()
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/command/build_ext.py", line 93, in run
      _build_ext.run(self)
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/command/build_ext.py", line 359, in run
      self.build_extensions()
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 866, in build_extensions
      build_ext.build_extensions(self)
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/command/build_ext.py", line 479, in build_extensions
      self._build_extensions_serial()
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/command/build_ext.py", line 505, in _build_extensions_serial
      self.build_extension(ext)
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/command/build_ext.py", line 254, in build_extension
      _build_ext.build_extension(self, ext)
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/setuptools/_distutils/command/build_ext.py", line 560, in build_extension
      objects = self.compiler.compile(
                ^^^^^^^^^^^^^^^^^^^^^^
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 679, in unix_wrap_ninja_compile
      _write_ninja_file_and_compile_objects(
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 1785, in _write_ninja_file_and_compile_objects
      _run_ninja_build(
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/cpp_extension.py", line 2121, in _run_ninja_build
      raise RuntimeError(message) from e
  RuntimeError: Error compiling objects for extension

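Both tensor_core_tiled_layout.cu and fp6_llm/quant_llm require at least Ampere (CC 8.0+). The errors above come from fp16/bf16 intrinsics (__bfloat162bfloat162, __hfma2 on __nv_bfloat162, __hmul) that are unavailable for some of the pre-Ampere targets in the -gencode list. We can add a compile guard around these kernels so that the rest of the extension still builds on older GPUs.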
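One coarse, build-side variant of such a guard is sketched below, assuming setup.py knows the target architectures as a TORCH_CUDA_ARCH_LIST-style string. The names AMPERE_ONLY_SOURCES, parse_arch_list and select_sources are hypothetical, not part of torchao; the source paths are taken from the log above. It simply drops the Ampere-only .cu files when any target is below CC 8.0.

```python
# Hypothetical build-side guard for setup.py: skip Ampere-only CUDA sources
# when the build targets any GPU architecture below compute capability 8.0.

# Source paths taken from the error log above; the set itself is an assumption.
AMPERE_ONLY_SOURCES = {
    "torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu",
    "torchao/csrc/cuda/fp6_llm/fp6_linear.cu",
}


def parse_arch_list(arch_list: str) -> list[tuple[int, int]]:
    """Parse a TORCH_CUDA_ARCH_LIST-style string such as "7.5;8.0+PTX".

    Assumes numeric "major.minor" entries only (no named archs like "Ampere",
    no suffixed archs like "9.0a").
    """
    archs = []
    for token in arch_list.replace(";", " ").split():
        major, minor = token.removesuffix("+PTX").split(".")
        archs.append((int(major), int(minor)))
    return archs


def select_sources(sources: list[str], arch_list: str) -> list[str]:
    """Drop Ampere-only sources if any target architecture is older than CC 8.0."""
    archs = parse_arch_list(arch_list)
    if archs and min(archs) < (8, 0):
        return [s for s in sources if s not in AMPERE_ONLY_SOURCES]
    return list(sources)


sources = [
    "torchao/csrc/init.cpp",
    "torchao/csrc/cuda/fp6_llm/fp6_linear.cu",
    "torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu",
]
print(select_sources(sources, "7.5"))  # ['torchao/csrc/init.cpp']
print(len(select_sources(sources, "8.0;8.6")))  # 3
```

This is deliberately blunt: with a mixed list such as "7.5;8.6" it discards the kernels for the 8.6 target as well, and any .cpp files that register these kernels would presumably need to be excluded too, which the sketch does not handle. A finer-grained option is a guard inside the kernels themselves (e.g. on __CUDA_ARCH__), so each -gencode pass compiles only what its architecture supports.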

For testing, we can use a T4 (Turing, CC 7.5) from Google Colab or Lightning Studio.
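Before building on a test machine, it may help to confirm that the GPU really is pre-Ampere. The minimal sketch below uses torch.cuda.get_device_capability, which returns a (major, minor) tuple; the helper is_pre_ampere is hypothetical, and the live check is skipped when torch or a GPU is unavailable.

```python
# Sketch: confirm the test GPU (e.g. a Colab/Lightning T4) is below CC 8.0,
# so a build there exercises the guarded, pre-Ampere path.


def is_pre_ampere(capability: tuple[int, int]) -> bool:
    """True if a (major, minor) compute capability is older than Ampere (8.0)."""
    return capability < (8, 0)


try:
    import torch
except ImportError:  # torch not installed: skip the live device check
    torch = None

if torch is not None and torch.cuda.is_available():
    cap = torch.cuda.get_device_capability(0)  # a T4 should report (7, 5)
    print(torch.cuda.get_device_name(0), cap, "pre-Ampere:", is_pre_ampere(cap))
```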

Related: #288
