@@ -278,31 +278,29 @@ async def serve_inner(
278278
279279 if quantize == "gptq" and deployment_framework == "tgis_native" :
280280 from text_generation_server .utils .layers import HAS_GPTQ_CUDA , EXLLAMA_VERSION
281- if HAS_GPTQ_CUDA :
282- if EXLLAMA_VERSION is not None :
283- try :
284- # When using GPTQ, Exllama kernels need some global kernels
285- # For which we have the final shapes only after the model has loaded
286- # This will allocate those buffers.
287-
288- if EXLLAMA_VERSION == "1" :
289- from text_generation_server .utils .gptq .exllama import (
290- create_exllama_buffers , set_device ,
291- )
292- set_device (device )
293- create_exllama_buffers (max_sequence_length )
294- else :
295- assert EXLLAMA_VERSION == "2"
296- from text_generation_server .utils .gptq .exllamav2 import (
297- set_device , Ex4bitLinearV2 ,
298- )
299- set_device (device )
300- for _ , submodule in model .model .named_modules ():
301- if isinstance (submodule , Ex4bitLinearV2 ):
302- submodule .post_init () # make q matrix and set scratch space
303-
304- except ImportError :
305- print ("WARN: Error setting up GPTQ exllama buffers" )
281+ if HAS_GPTQ_CUDA and EXLLAMA_VERSION is not None :
282+ try :
283+ # When using GPTQ, Exllama kernels need some global kernels
284+ # For which we have the final shapes only after the model has loaded
285+ # This will allocate those buffers.
286+ if EXLLAMA_VERSION == "1" :
287+ from text_generation_server .utils .gptq .exllama import (
288+ create_exllama_buffers , set_device ,
289+ )
290+ set_device (device )
291+ create_exllama_buffers (max_sequence_length )
292+ elif EXLLAMA_VERSION == "2" :
293+ from text_generation_server .utils .gptq .exllamav2 import (
294+ set_device , Ex4bitLinearV2 ,
295+ )
296+ set_device (device )
297+ for _ , submodule in model .model .named_modules ():
298+ if isinstance (submodule , Ex4bitLinearV2 ):
299+ submodule .post_init () # make q matrix and set scratch space
300+ else :
301+ raise ValueError (f"Unsupported { EXLLAMA_VERSION = } " )
302+ except ImportError :
303+ print ("WARN: Error setting up GPTQ exllama buffers" )
306304
307305 if local_rank == 0 and device .type == "cuda" :
308306 # Log GPU memory stats at startup
0 commit comments