Description
We recently landed #939 to support tensor parallelism for int8 weight-only quantization (another example: #785).
Now we can support tensor parallelism for other types of quantization as well:
- float8 weight-only - @jainapurva - Add Float8 support for AQT tensor parallel #1003
- float8 dynamic activation - @jainapurva - Add AQT tensor parallel for float8_dynamic_quant #1078
- uintx weight-only - @melvinebenezer
- int4 weight-only - @jerryzh168 - Add tensor parallelism support for int4_weight_only quantization #1120
- int8 dynamic activation + int8 weight - @jainapurva - Add int8 dynamic activation + int8 weight only test to TensorParallel #1657
- fpx -
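To make concrete what tensor parallelism means for a quantized linear weight, here is a toy, pure-Python sketch (no torch, no real communication). It assumes symmetric int8 weight-only quantization with one scale per output row, and all helper names are hypothetical rather than torchao APIs. Column-wise sharding splits the output rows, so each rank needs the matching slice of scales and the partial outputs are concatenated; row-wise sharding splits the input dimension, so each rank keeps all scales and the partial outputs are summed. In a real run the shards would live on different ranks (e.g. as DTensor shards) and the concatenation and sum would be collectives.

```python
# Toy sketch (pure Python): tensor-parallel sharding of a linear layer whose
# weight is int8 weight-only quantized with one scale per output row.
# All helper names are hypothetical; this is not torchao's implementation.
# Assumes both weight dimensions are divisible by world_size.


def quantize_per_row(w):
    """Symmetric int8 quantization with one scale per output row."""
    q, scales = [], []
    for row in w:
        amax = max(abs(v) for v in row) or 1.0  # avoid a zero scale
        s = amax / 127.0
        q.append([max(-128, min(127, round(v / s))) for v in row])
        scales.append(s)
    return q, scales


def qlinear(x, q, scales):
    """One output per weight row: y[j] = scale[j] * sum_k x[k] * q[j][k]."""
    return [s * sum(xk * qk for xk, qk in zip(x, row)) for row, s in zip(q, scales)]


def colwise_tp(x, q, scales, world_size):
    """Shard weight rows (output dim); each 'rank' takes its rows AND scales."""
    n = len(q) // world_size
    out = []
    for r in range(world_size):
        rows = slice(r * n, (r + 1) * n)
        out += qlinear(x, q[rows], scales[rows])  # concatenation ~ all-gather
    return out


def rowwise_tp(x, q, scales, world_size):
    """Shard weight columns (input dim); every 'rank' keeps all row scales."""
    k = len(x) // world_size
    total = [0.0] * len(q)
    for r in range(world_size):
        cols = slice(r * k, (r + 1) * k)
        part = qlinear(x[cols], [row[cols] for row in q], scales)
        total = [a + b for a, b in zip(total, part)]  # sum ~ all-reduce
    return total
```

Both sharded variants should reproduce the unsharded quantized output (up to floating-point rounding), which is the property a tensor-parallel test for a new quantization type needs to hold.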
Steps
1. Create a test
Since we don't have many tests today, we can optimize for readability for now: copy-paste the test cases into https://github.com/pytorch/ao/blob/main/test/dtypes/test_affine_quantized_tensor_parallel.py instead of inheriting from the existing test cases.
For new tests, you can follow this example from ao/test/dtypes/test_affine_quantized_tensor_parallel.py (lines 133 to 153 at commit c87cc9b):
```python
# Excerpt: the shared base class TestFloat8dqAffineQuantizedTensorParallel
# (which provides _test_tp) and the imports used here are defined earlier
# in the same test file.
class TestFloat8dqTensorAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
    # float8 dynamic activation + float8 weight, per-tensor scales
    QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
    QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
    COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]

    @common_utils.parametrize("dtype", COMMON_DTYPES)
    @with_comms
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_tp(self, dtype):
        return self._test_tp(dtype)


class TestFloat8dqRowAffineQuantizedTensorParallel(TestFloat8dqAffineQuantizedTensorParallel):
    # float8 dynamic activation + float8 weight, per-row scales (bfloat16 only)
    QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
    QUANT_METHOD_KWARGS = {"granularity": PerRow()}
    COMMON_DTYPES = [torch.bfloat16]

    @common_utils.parametrize("dtype", COMMON_DTYPES)
    @with_comms
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_tp(self, dtype):
        return self._test_tp(dtype)
```
2. Run the test
```shell
python test/dtypes/test_affine_quantized_tensor_parallel.py
```
3. Add support for missing ops until the test passes
We expect this will mostly involve adding slicing ops and similar ops to the corresponding TensorImpl tensor subclass.
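As a rough illustration of what step 3 involves, the toy sketch below models a plain int8 TensorImpl as int data plus one scale per output row and implements a `slice` that mirrors `aten.slice.Tensor` with step 1. The class and function names are hypothetical and torchao's real TensorImpls differ; the point is that slicing along dim 0 (how a column-wise shard of a linear weight is taken) must slice the scales too, while slicing along dim 1 keeps every row's scale.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyPlainInt8Impl:
    """Hypothetical TensorImpl stand-in: int8 data of shape
    [out_features][in_features] plus one scale per output row."""

    int_data: List[List[int]]
    scale: List[float]

    def dequantize(self) -> List[List[float]]:
        return [[v * s for v in row] for row, s in zip(self.int_data, self.scale)]

    def slice(self, dim: int, start: int, end: int) -> "ToyPlainInt8Impl":
        """Mirror aten.slice.Tensor(self, dim, start, end) with step=1."""
        if dim == 0:
            # Keep only some output rows: their scales must go with them,
            # otherwise data and scales would be misaligned on each rank.
            return ToyPlainInt8Impl(self.int_data[start:end], self.scale[start:end])
        if dim == 1:
            # Keep only some input columns: every row keeps its own scale.
            return ToyPlainInt8Impl(
                [row[start:end] for row in self.int_data], list(self.scale)
            )
        raise NotImplementedError(f"slice along dim={dim} is not supported")


def dense_slice(mat, dim, start, end):
    """Reference: the same slice taken on an ordinary (dequantized) matrix."""
    return mat[start:end] if dim == 0 else [row[start:end] for row in mat]
```

A correct slicing op satisfies `impl.slice(d, a, b).dequantize() == dense_slice(impl.dequantize(), d, a, b)`. For layouts with group-wise scales, such as int4 weight-only, a dim-1 slice would presumably also have to slice the scales at group granularity.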