
Commit 650fea1

fxmarty (Felix Marty) and OlivierDehaene authored
GPTQ support on ROCm (#1489)
Tested with

```
CUDA_VISIBLE_DEVICES=0 text-generation-launcher --model-id TheBloke/Llama-2-7B-Chat-GPTQ --quantize gptq
EXLLAMA_VERSION=1 CUDA_VISIBLE_DEVICES=0 text-generation-launcher --model-id TheBloke/Llama-2-7B-Chat-GPTQ --quantize gptq
CUDA_VISIBLE_DEVICES="0,1" text-generation-launcher --model-id TheBloke/Llama-2-7B-Chat-GPTQ --quantize gptq
```

all with good and identical results on MI210.

Co-authored-by: Felix Marty <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>
Co-authored-by: OlivierDehaene <[email protected]>
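For readers reproducing the test, a minimal client-side sanity check after running one of the commands above might look like the sketch below. It is not part of the commit: the port is an assumption (the launcher default is assumed to be 3000 here; adjust to whatever `--port` you pass), and it uses TGI's `/generate` endpoint, which is unchanged by this change.

```python
# Sketch: query the running text-generation-launcher instance to confirm that
# GPTQ generation works on the ROCm host. Port 3000 is an assumption.
import requests

response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "What is ROCm?",
        "parameters": {"max_new_tokens": 32},
    },
    timeout=60,
)
response.raise_for_status()
print(response.json()["generated_text"])
```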
1 parent ebecc06 commit 650fea1

File tree

10 files changed: +80 −22 lines changed


.gitignore

Lines changed: 10 additions & 0 deletions
```diff
@@ -2,3 +2,13 @@
 target
 router/tokenizer.json
 *__pycache__*
+
+# ROCm auto-generated files
+*.hip
+server/exllamav2_kernels/exllamav2_kernels/hip/
+server/exllama_kernels/exllama_kernels/hip/
+server/exllama_kernels/exllama_kernels/hip_func/
+*_hip.cuh
+server/exllama_kernels/exllama_kernels/hip_buffers.cuh
+server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
+
```

Dockerfile_amd

Lines changed: 22 additions & 2 deletions
```diff
@@ -75,8 +75,8 @@ RUN chmod +x ~/mambaforge.sh && \
     mamba init && \
     rm ~/mambaforge.sh
 
-# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
-RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
+# Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
+RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
 
 FROM base AS kernel-builder
 
@@ -104,6 +104,20 @@ WORKDIR /usr/src
 COPY server/custom_kernels/ .
 RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
 
+# Build exllama kernels
+FROM kernel-builder as exllama-kernels-builder
+WORKDIR /usr/src
+COPY server/exllama_kernels/ .
+
+RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+
+# Build exllama v2 kernels
+FROM kernel-builder as exllamav2-kernels-builder
+WORKDIR /usr/src
+COPY server/exllamav2_kernels/ .
+
+RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+
 FROM base as base-copy
 
 # Text Generation Inference base env
@@ -120,6 +134,12 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86
 # Copy build artifacts from custom kernels builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
+# Copy build artifacts from exllama kernels builder
+COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
+# Copy build artifacts from exllamav2 kernels builder
+COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
```
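The two new builder stages just run the standard `setup.py build` for each kernel package with the MI210/MI250 architecture pinned. Outside of Docker, the same step could be reproduced roughly as in the sketch below; the paths and the `gfx90a` value come from the diff, everything else is illustrative.

```python
# Rough equivalent of the exllama / exllamav2 builder stages above, run from
# the repository root. PYTORCH_ROCM_ARCH=gfx90a targets MI210/MI250 as in the
# Dockerfile; this helper script is not part of the commit.
import os
import subprocess

env = {**os.environ, "PYTORCH_ROCM_ARCH": "gfx90a"}
for kernel_dir in ("server/exllama_kernels", "server/exllamav2_kernels"):
    subprocess.run(["python", "setup.py", "build"], cwd=kernel_dir, env=env, check=True)
```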

docs/source/supported_models.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -43,8 +43,8 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
 
 TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
 
-TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
-* Quantization (GPTQ, AWQ, etc.)
+TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention, GPTQ quantization, flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
+* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
 * Flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm)
 * Kernel for slinding window attention (Mistral)
 
```

server/exllama_kernels/exllama_kernels/cuda_compat.cuh renamed to server/exllama_kernels/exllama_kernels/cu_compat.cuh

Lines changed: 3 additions & 3 deletions
```diff
@@ -43,12 +43,12 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 
 //
 
-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ < 700
+#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
 
 __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
 
-#if __CUDA_ARCH__ < 600
+#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
 __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif
 
```

server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu

Lines changed: 5 additions & 2 deletions
```diff
@@ -2,8 +2,11 @@
 #include "column_remap.cuh"
 #include "../util.cuh"
 #include "../matrix.cuh"
-#include "../cuda_compat.cuh"
+#include "../cu_compat.cuh"
 #include "../cuda_buffers.cuh"
+#if defined(USE_ROCM)
+#include "../hip_compat.cuh"
+#endif
 
 const int THREADS_X = 32; // Block size and thread count along columns in w and out
 const int THREADS_Y = 1;  // Block size and thread count along rows in x and out
@@ -128,7 +131,7 @@ __global__ void q4_matmul_kernel
 
     if constexpr (use_half2)
     {
-        half result = __hadd(acc.x, acc.y);
+        half result = __hadd(__low2half(acc), __high2half(acc));
         atomicAdd(out_.item_ptr(x_row, w_column), result);
     }
     else
```

server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh renamed to server/exllama_kernels/exllama_kernels/hip_compat.cuh

Lines changed: 21 additions & 8 deletions
```diff
@@ -1,12 +1,23 @@
-#ifndef _compat_gemm_cuh
-#define _compat_gemm_cuh
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
 
-#if defined(USE_ROCM)
+#ifndef _hip_compat_cuh
+#define _hip_compat_cuh
 
-// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
-// for symbols as hipblasHalf.
-#include <hipblas/hipblas.h>
+// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.
+__device__ __forceinline__ __half __compat_hrcp(__half x) {
+    return __half_raw{
+        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
+}
+
+__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
+    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
+                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
+}
+
+#define hrcp __compat_hrcp
+#define h2rcp __compat_h2rcp
 
+// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
                                                                hipblasOperation_t transA,
                                                                hipblasOperation_t transB,
@@ -31,8 +42,10 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t
 #define hipblasHgemm __compat_hipblasHgemm
 
 // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
+#define rocblas_handle hipblasHandle_t
 #define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_get_stream hipblasGetStream
+#define rocblas_set_stream hipblasSetStream
 #define rocblas_hgemm __compat_hipblasHgemm
-#endif
 
-#endif
+#endif
```

server/exllama_kernels/exllama_kernels/util.cuh

Lines changed: 4 additions & 0 deletions
```diff
@@ -8,7 +8,11 @@
 #include <cstdint>
 #include <cstdio>
 
+#if defined(USE_ROCM)
+#define cudaUnspecified hipErrorUnknown
+#else
 #define cudaUnspecified cudaErrorApiFailureBase
+#endif
 
 // React to failure on return code != cudaSuccess
```

server/exllamav2_kernels/setup.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -1,5 +1,15 @@
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+
+extra_cuda_cflags = ["-lineinfo", "-O3"]
+
+if torch.version.hip:
+    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
+
+extra_compile_args = {
+    "nvcc": extra_cuda_cflags,
+}
 
 setup(
     name="exllamav2_kernels",
@@ -11,6 +21,7 @@
             "exllamav2_kernels/cuda/q_matrix.cu",
             "exllamav2_kernels/cuda/q_gemm.cu",
         ],
+        extra_compile_args=extra_compile_args,
     )
     ],
     cmdclass={"build_ext": BuildExtension},
```
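The branch keys off `torch.version.hip`, which is `None` on CUDA builds of PyTorch and a version string on ROCm builds, so `-DHIPBLAS_USE_HIP_HALF` is only added when the extension is compiled under ROCm. A quick way to see which branch applies on a given machine (a sketch, not part of the commit):

```python
# Check whether the installed PyTorch is a ROCm (HIP) or CUDA build, mirroring
# the torch.version.hip test that setup.py now uses to pick compiler flags.
import torch

if torch.version.hip:
    print(f"ROCm build detected (HIP {torch.version.hip}): "
          "exllamav2 kernels get -DHIPBLAS_USE_HIP_HALF")
else:
    print(f"CUDA build detected (CUDA {torch.version.cuda}): default nvcc flags only")
```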

server/text_generation_server/utils/gptq/exllamav2.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -1,12 +1,9 @@
 # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
 
-from logging import getLogger
-
 import torch
 import torch.nn as nn
-import math
 
-logger = getLogger(__name__)
+from loguru import logger
 
 try:
     from exllamav2_kernels import make_q_matrix, gemm_half_q_half
```
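The only functional change here is that the module now logs through loguru, like the rest of the server, instead of a module-local `logging.getLogger`. The resulting import pattern looks roughly like the sketch below; only the two imports appear in the diff, so the failure handling shown is an assumption.

```python
# Sketch of the import pattern after the change: loguru replaces the stdlib
# logger, and the compiled extension import stays wrapped so a missing build
# can be reported clearly. The except body is illustrative, not from the diff.
from loguru import logger

try:
    from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError as exc:
    logger.error("exllamav2_kernels are not installed: {}", exc)
    raise
```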

server/text_generation_server/utils/layers.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -33,7 +33,7 @@
     major = 1
 
 HAS_EXLLAMA = False
-CAN_EXLLAMA = major >= 8
+CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 # if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
 #     V2 = False
```
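With `CAN_EXLLAMA` now true on ROCm systems as well, the `EXLLAMA_VERSION` environment variable is what selects between the exllama and exllamav2 GPTQ kernels, which is why the test matrix in the commit message includes an `EXLLAMA_VERSION=1` run. A condensed sketch of the gating follows; `IS_ROCM_SYSTEM` and `major` are detected elsewhere in layers.py, so the hard-coded values below are placeholders only.

```python
# Condensed sketch of the exllama gating after this commit. IS_ROCM_SYSTEM and
# `major` (CUDA compute capability) are set elsewhere in layers.py; the values
# here are placeholders for illustration.
import os

IS_ROCM_SYSTEM = True  # placeholder: True on an MI210/MI250 host
major = 1              # placeholder: stays 1 when CUDA capability can't be read

CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"

if CAN_EXLLAMA:
    print("GPTQ will use the", "exllamav2" if V2 else "exllama", "kernels")
```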
