 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
-  static const float i8_min =
+  static constexpr auto i8_min =
       static_cast<float>(std::numeric_limits<int8_t>::min());
-  static const float i8_max =
+  static constexpr auto i8_max =
       static_cast<float>(std::numeric_limits<int8_t>::max());
-  // round
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
   float dst = std::nearbyint(x);
+
   // saturate
   dst = std::clamp(dst, i8_min, i8_max);
   return static_cast<int8_t>(dst);
@@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
 #endif
 }

+static inline __device__ int32_t float_to_int32_rn(float x) {
+#ifdef USE_ROCM
+  // int32_max is not exactly representable as float.
+  // Therefore, we need to be careful and manually return int32_max on overflow.
+  // For symmetry, we also do the same for int32_min, even though it is exactly
+  // representable as float and the conversion should be exact.
+  static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
+  static constexpr auto i32_min_f = static_cast<float>(i32_min);
+  static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
+  static constexpr auto i32_max_f = static_cast<float>(i32_max);
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
+  float dst = std::nearbyint(x);
+
+  // saturate on the higher end.
+  if (dst >= i32_max_f) {
+    return i32_max;
+  }
+  // saturate on the lower end.
+  if (dst <= i32_min_f) {
+    return i32_min;
+  }
+
+  return static_cast<int32_t>(dst);
+#else
+  // CUDA path
+  uint32_t dst;
+  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
+  return reinterpret_cast<const int32_t&>(dst);
+#endif
+}
+
+static inline __device__ int8_t int32_to_int8(int32_t x) {
+#ifdef USE_ROCM
+  static constexpr auto i8_min =
+      static_cast<int32_t>(std::numeric_limits<int8_t>::min());
+  static constexpr auto i8_max =
+      static_cast<int32_t>(std::numeric_limits<int8_t>::max());
+
+  // saturate
+  int32_t dst = std::clamp(x, i8_min, i8_max);
+  return static_cast<int8_t>(dst);
+#else
+  // CUDA path
+  uint32_t dst;
+  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
+  return reinterpret_cast<const int8_t&>(dst);
+#endif
+}
+
 namespace vllm {

 template <typename scalar_t, typename scale_type>
@@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel(
   }
 }

+template <typename scalar_t, typename scale_type, typename azp_type>
+__global__ void static_scaled_int8_azp_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type const* scale_ptr, azp_type const* azp_ptr,
+    const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  scale_type const scale = *scale_ptr;
+  azp_type const azp = *azp_ptr;
+
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
+    out[token_idx * hidden_size + i] = quant_val;
+  }
+}
+
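As a quick worked example with made-up numbers: for scale = 1.0 and azp = 3, an input of 2.5 maps to int32_to_int8(float_to_int32_rn(2.5) + 3) = int32_to_int8(2 + 3) = 5; the 2 comes from round-half-to-even, matching torch.round and CUDA's cvt.rni.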
 template <typename scalar_t, typename scale_type>
 __global__ void dynamic_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
@@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   }
 }

+template <typename scalar_t, typename scale_type, typename azp_type>
+__global__ void dynamic_scaled_int8_azp_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type* scale, azp_type* azp, const int hidden_size) {
+  int const token_idx = blockIdx.x;
+
+  // Scan for the min and max value for this token
+  float max_val = std::numeric_limits<float>::min();
+  float min_val = std::numeric_limits<float>::max();
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    auto val = static_cast<float>(input[token_idx * hidden_size + i]);
+    max_val = std::max(max_val, val);
+    min_val = std::min(min_val, val);
+  }
+
+  // Reduce the max and min values across the block
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
+  __syncthreads();  // Make sure min doesn't mess with max shared memory
+  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
+
+  __shared__ scale_type scale_sh;
+  __shared__ azp_type azp_sh;
+
+  // Compute the scale and zero point and store them, only on the first thread
+  if (threadIdx.x == 0) {
+    float const scale_val = (max_val - min_val) / 255.0f;
+    // Use rounding to even (same as torch.round)
+    auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
+    auto const azp_val = static_cast<azp_type>(azp_float);
+
+    // Store the scale and azp into shared and global
+    scale[token_idx] = scale_sh = scale_val;
+    azp[token_idx] = azp_sh = azp_val;
+  }
+
+  // Wait for the scale and azp to be computed
+  __syncthreads();
+
+  float const scale_val = scale_sh;
+  azp_type const azp_val = azp_sh;
+
+  // Quantize the values
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const quant_val =
+        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
+    out[token_idx * hidden_size + i] = quant_val;
+  }
+}
+
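To make the per-token arithmetic above concrete, here is a hedged host-side sketch with made-up values (not part of this change): a token spanning [-1, 3] gets scale = 4/255 and azp = -64, so the quantized range lands exactly on [-128, 127].

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Host-side illustration of the per-token scale/azp math used above.
int main() {
  float min_val = -1.0f, max_val = 3.0f;           // pretend per-token extrema
  float scale = (max_val - min_val) / 255.0f;      // ~0.01569
  int32_t azp = static_cast<int32_t>(
      std::nearbyint(-128.0f - min_val / scale));  // nearbyint(-64.25) = -64

  // Quantize a few elements, mirroring int32_to_int8(float_to_int32_rn(...)).
  for (float x : {min_val, 0.5f, max_val}) {
    int32_t q32 = static_cast<int32_t>(std::nearbyint(x / scale)) + azp;
    int8_t q = static_cast<int8_t>(std::clamp(q32, -128, 127));
    std::printf("%g -> %d\n", x, q);  // -1 -> -128, 0.5 -> -32, 3 -> 127
  }
  return 0;
}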
 }  // namespace vllm

 void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
                               torch::Tensor const& input,  // [..., hidden_size]
-                              torch::Tensor const& scale) {
+                              torch::Tensor const& scale,
+                              c10::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
+  TORCH_CHECK(!azp || azp->numel() == 1);

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
@@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
-        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(),
-                                         scale.data_ptr<float>(), hidden_size);
+        if (!azp) {
+          vllm::static_scaled_int8_quant_kernel<scalar_t, float>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scale.data_ptr<float>(), hidden_size);
+        } else {
+          vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
+                  hidden_size);
+        }
       });
 }

 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     torch::Tensor const& input,  // [..., hidden_size]
-    torch::Tensor& scales) {
+    torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(scales.is_contiguous());
+  TORCH_CHECK(!azp || azp->is_contiguous());

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
@@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
-        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(),
-                                         scales.data_ptr<float>(), hidden_size);
+        if (!azp) {
+          vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scales.data_ptr<float>(), hidden_size);
+        } else {
+          vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
+                  hidden_size);
+        }
       });
 }
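Finally, a hedged caller-side sketch of the extended entry points (assumes libtorch headers and linking against the extension that defines them; the shapes and the 0.02f / 7 constants are made-up illustration values, not part of this change):

#include <torch/torch.h>

// Normally provided by the project's ops header; repeated here to match the
// signatures in this change.
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale,
                              c10::optional<torch::Tensor> const& azp);
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scales,
                               c10::optional<torch::Tensor> const& azp);

void example() {
  auto input = torch::randn(
      {16, 4096}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
  auto out = torch::empty_like(input, input.options().dtype(torch::kInt8));

  // Static, symmetric (no azp): out = saturate(round(input / scale)).
  auto scale = torch::full({1}, 0.02f, input.options());
  static_scaled_int8_quant(out, input, scale, c10::nullopt);

  // Static, asymmetric: out = saturate(round(input / scale) + azp).
  auto azp = torch::full({1}, 7, input.options().dtype(torch::kInt32));
  static_scaled_int8_quant(out, input, scale, azp);

  // Dynamic, asymmetric: the kernel writes per-token scales and azps.
  auto scales = torch::empty({16, 1}, input.options());
  auto azps = torch::empty({16, 1}, input.options().dtype(torch::kInt32));
  dynamic_scaled_int8_quant(out, input, scales, azps);
}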