cyang49 (Contributor) commented Mar 8, 2024

Motivation

This PR enables the use of the Marlin kernel for GPTQ checkpoints. Marlin has been shown to outperform Exllamav2 on Nvidia GPUs, especially at larger batch sizes.

Modifications

The code changes largely mirror the exllamav2 path, except that they use the Marlin kernel code and binding from the AutoGPTQ package instead of pulling in a separate marlin package. I adapted the QuantLinear implementation from AutoGPTQ, removing code we don't need. Note that my changes also enable Marlin support for checkpoints that use activation reordering (`desc_act=True`).
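As a rough illustration, here is a minimal sketch of the weight-repacking step this adaptation relies on, assuming AutoGPTQ's compiled Marlin extension (`autogptq_marlin_cuda`) is available; the wrapper function name is hypothetical, and only the `gptq_repack` call is taken from the adapted code.

```
# Sketch only (not the PR's actual QuantLinear): repack a GPTQ int32-packed
# weight into Marlin's layout using AutoGPTQ's Marlin CUDA binding.
import torch
import autogptq_marlin_cuda  # requires AutoGPTQ built with COMPILE_MARLIN=1

def repack_gptq_weight_for_marlin(qweight: torch.Tensor) -> torch.Tensor:
    # qweight is the packed quantized weight tensor loaded from a GPTQ checkpoint
    return autogptq_marlin_cuda.gptq_repack(qweight)
```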

Marlin can be turned on by setting the environment variable `GPTQ_CUDA_TYPE=marlin`.

Note that the Marlin kernel only works on Nvidia GPUs with compute capability >= 8.0.
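For example, kernel selection could be gated roughly as follows (a hedged sketch, not the PR's actual code; the environment variable is the one described above and the capability check uses standard PyTorch):

```
# Sketch only: enable Marlin when GPTQ_CUDA_TYPE=marlin and the GPU is
# compute capability 8.0 (Ampere) or newer.
import os
import torch

def marlin_enabled() -> bool:
    if os.environ.get("GPTQ_CUDA_TYPE", "").lower() != "marlin":
        return False
    major, _minor = torch.cuda.get_device_capability()  # e.g. (8, 0) on A100
    return major >= 8
```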

Result

```
[Llama-70B-4bit-128g]
Single A100 80GB, 1k context, 512 output tokens, batch size = 16

Marlin:     Prefill: 12.2s, Inference time: 38.57s
Exllamav2:  Prefill: 9.68s, Inference time: 79.7s
```

- Investigation is needed, as Marlin prefill appears slower.

The code needs more thorough testing, for both performance and correctness, in the following scenarios:

- Should not break fp16 logic
- Should work correctly for `desc_act=False` GPTQ checkpoints, with optimal performance
- Should work correctly for `desc_act=True` GPTQ checkpoints, with slightly worse performance than the `desc_act=False` case
- Should not break TP use, although TP performance still needs further optimization
- Memory management needs extensive review

Related Issues

njhill (Contributor) left a comment

Thanks @cyang49, this is awesome! I think we can merge it, leaving exllama as the default, while we perform the tests that you enumerated.

As mentioned before, we see some differences in outputs when using the transformers auto-gptq path with Mixtral after updating to the latest auto-gptq, so we also want to check that.

Dockerfile (outdated)

```
 # numpy is required to run auto-gptq's setup.py
 RUN pip install numpy
-RUN DISABLE_QIGEN=1 pip wheel git+https://github.com/AutoGPTQ/AutoGPTQ@${AUTO_GPTQ_REF} --no-cache-dir --no-deps --verbose
+RUN BUILD_CUDA_EXT=1 COMPILE_MARLIN=1 DISABLE_QIGEN=1 pip wheel git+https://github.com/AutoGPTQ/AutoGPTQ@${AUTO_GPTQ_REF} --no-cache-dir --no-deps --verbose
```
njhill (Contributor):

@cyang49 do you know if we would get the Marlin kernel if we pip install from the auto-gptq wheel (e.g. latest version 0.7.1)?

cyang49 (Contributor Author):

I'm not sure. I did a quick test in a local Python env and got an error while trying to import marlin:

```
>>> import autogptq_marlin_cuda
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: /net/storage149/autofs/css22/ccyang/miniconda3-netsres/envs/vllm/lib/python3.11/site-packages/autogptq_marlin_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops19empty_memory_format4callEN3c108ArrayRefINS2_6SymIntEEESt8optionalINS2_10ScalarTypeEES6_INS2_6LayoutEES6_INS2_6DeviceEES6_IbES6_INS2_12MemoryFormatEE
```

I can rebuild the image locally and see if it works in the TGIS container.

cyang49 (Contributor Author):

According to the AutoGPTQ readme, the package depends on torch 2.2.1+cu121. I see that TGIS uses 2.2.0. Will that be a problem?

njhill (Contributor):

We can update torch too. We may also need to update the flash-attention version, but that's fine.

cyang49 (Contributor Author):

I tried rebuilding the image with the prebuilt auto-gptq 0.7.1 wheel, but it doesn't work:

```
Shard 0:
Shard 0:   File "/opt/tgis/lib/python3.11/site-packages/text_generation_server/utils/layers.py", line 192, in get_linear
Shard 0:     linear = (QuantLinear if not use_gptq_cuda else GPTQ_CUDA_LINEAR)(
Shard 0:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Shard 0:
Shard 0:   File "/opt/tgis/lib/python3.11/site-packages/text_generation_server/utils/gptq/marlin.py", line 145, in __init__
Shard 0:     self.B = autogptq_marlin_cuda.gptq_repack(qweight)
Shard 0:              ^^^^^^^^^^^^^^^^^^^^
Shard 0:
Shard 0: AttributeError: 'function' object has no attribute 'gptq_repack'
```
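A quick way to check whether an installed auto-gptq wheel actually ships the compiled Marlin extension (a sketch; the module and attribute names are the ones appearing in the tracebacks above):

```
# Sketch: verify the Marlin CUDA extension is importable and exposes gptq_repack.
import importlib

try:
    marlin = importlib.import_module("autogptq_marlin_cuda")
    print("gptq_repack available:", hasattr(marlin, "gptq_repack"))
except ImportError as exc:
    print("Marlin extension not available:", exc)
```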

cyang49 (Contributor Author):

I updated the Dockerfile to pull v0.7.1 for the build-from-source.

cyang49 (Contributor Author) commented Mar 18, 2024

BTW @njhill do we want to remove exllamav1? Is it used at all?

njhill (Contributor) commented Mar 19, 2024

> BTW @njhill do we want to remove exllamav1? Is it used at all?

@cyang49 we actually already removed the explicitly bundled kernel itself in #59, since it's already included in auto-gptq.

Or do you mean exllama.py? If so, we could, but I guess it's not doing any harm either.

njhill mentioned this pull request Mar 19, 2024
njhill (Contributor) commented Mar 19, 2024

@cyang49 FYI, @joerunde is working on updating auto-gptq, pytorch, and cuda in #62. I guess we can merge this one once that's in.

cyang49 closed this Mar 22, 2024
njhill added a commit that referenced this pull request Mar 25, 2024
Resubmitting Marlin PR due to accidental removal

#### Related Issues

#51

---------

Signed-off-by: Chih-Chieh-Yang <[email protected]>
Signed-off-by: cyang49 <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Xaenalt pushed a commit to Xaenalt/text-generation-inference that referenced this pull request Jul 26, 2024
[pull] main from IBM:main
Xaenalt pushed a commit to Xaenalt/text-generation-inference that referenced this pull request Jul 30, 2024
Upgrade UBI version to 9.4-1181