
Commit 3a5517a

fix test_ops
1 parent: 6a4430d

2 files changed (+14, -14 lines)


README.md

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ The best example we have combining the composability of lower bit dtype with com
 
 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()` so if you love writing kernels but hate packaging them so they work all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow
 
-1. [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fp6_llm_weight_only())`
+1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
 
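For readers following the README change, the new `fpx_weight_only(3, 2)` entry point can be exercised roughly as follows. This is a minimal sketch, not part of the commit: the toy model is made up, and it assumes a CUDA build of torchao with `quantize_` and `fpx_weight_only` importable from `torchao.quantization`.

```python
# Minimal sketch of the README's updated fp6 example (not part of this commit).
# Assumptions: a CUDA build of torchao, and that quantize_ / fpx_weight_only are
# importable from torchao.quantization; the toy model below is illustrative only.
import torch
from torchao.quantization import quantize_, fpx_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# fpx_weight_only(3, 2) selects a 6-bit floating point format:
# 1 sign bit + 3 exponent bits + 2 mantissa bits, i.e. the fp6 case from the README.
quantize_(model, fpx_weight_only(3, 2))

x = torch.randn(8, 1024, dtype=torch.half, device="cuda")
with torch.no_grad():
    y = model(x)
```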
test/test_ops.py

Lines changed: 13 additions & 13 deletions
@@ -11,7 +11,7 @@
 )
 from torch.testing._internal.optests import opcheck
 from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
-from torchao.dtypes.fpx import from_scaled_tc_fpx
+from torchao.dtypes.floatx import from_scaled_tc_floatx
 from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
 import pytest
 
@@ -33,13 +33,13 @@
 
 
 class TestOps(TestCase):
-    def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
+    def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
         # Randomly initialize each byte
         nbits = 1 + ebits + mbits
-        fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
+        floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
         scale = torch.rand(OC).half() + 0.5
         fp16_act = torch.rand(BS, IC).half() + 0.5
-        return fpx_weight.to(device), scale.to(device), fp16_act.to(device)
+        return floatx_weight.to(device), scale.to(device), fp16_act.to(device)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
@@ -48,28 +48,28 @@ def test_quant_llm_linear(self, ebits, mbits):
         OC = 256
         IC = 256
         splitK = 1
-        fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")
 
         # smoke test
-        torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
+        torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)
 
         # comprehensive testing
         test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils)
+        opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, floatx_weight, scale, splitK), test_utils=test_utils)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
     def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
         # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
-        fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")
 
-        results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
+        results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)
 
-        fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half()
+        fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half()
         results_fp16 = fp16_act @ fp16_weight.T
 
-        error = (results_fpx - results_fp16).abs().mean()
+        error = (results_floatx - results_fp16).abs().mean()
         gt = results_fp16.abs().mean()
         relative_error = error / gt
         assert relative_error < 1e-3
@@ -319,7 +319,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
 MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
 
 MARLIN_TEST_PARAMS = list(itertools.product(
-    MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, 
+    MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS,
     MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
 ))
 
@@ -405,7 +405,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
     workspace_24 = marlin_24_workspace(size_n)
 
     fn_inputs = (
-        input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24, 
+        input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
         num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out,
     )
     output = torchao.ops.marlin_24_gemm(*fn_inputs)
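As a companion to the rename, the standalone sketch below mirrors what `test_quant_llm_linear_correctness` checks, using the renamed `torchao.dtypes.floatx` path and the `quant_llm_linear` op from this diff. It is illustrative only: it requires a CUDA build of torchao, and the shapes are borrowed from the smoke test rather than the parametrized correctness cases.

```python
# Standalone sketch of the renamed floatx correctness check (illustrative, not part of the commit).
# Requires a CUDA build of torchao; uses the calls that appear in the diff above.
import torch
import torchao  # exposes torchao.ops.quant_llm_linear
from torchao.dtypes.floatx import from_scaled_tc_floatx

ebits, mbits = 3, 2                  # fp6: 1 sign + 3 exponent + 2 mantissa bits
BS, OC, IC, splitK = 1, 256, 256, 1  # shapes borrowed from the smoke test
nbits = 1 + ebits + mbits

# Packed weight: every 8 input-channel elements occupy `nbits` bytes, hence IC // 8 * nbits columns.
floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8, device="cuda")
scale = (torch.rand(OC).half() + 0.5).cuda()
fp16_act = (torch.rand(BS, IC).half() + 0.5).cuda()

# Fused floatx kernel vs. dequantize-then-matmul reference.
results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)
fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half()
results_fp16 = fp16_act @ fp16_weight.T

relative_error = (results_floatx - results_fp16).abs().mean() / results_fp16.abs().mean()
assert relative_error < 1e-3
```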
