Commit 3ed1a46

Fix the impl of `to` for the int4 weight-only use case

Summary: Note that we can't do the following right now:
* initialize and quantize the model with int4_weight_only quant on cpu
* move the model to cuda

We'll enable this in a separate PR.

Test Plan: CI
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 6dd82d8 commit 3ed1a46
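
In code, the unsupported flow the summary describes is roughly the following (a hedged sketch, assuming the torchao.quantization import path for the quantize_ and int4_weight_only APIs used in the test diff below; MyModel is a hypothetical placeholder):

    import torch
    from torchao.quantization import quantize_, int4_weight_only

    m = MyModel().eval().to(torch.bfloat16)  # hypothetical model, initialized on cpu
    quantize_(m, int4_weight_only())         # int4 weight-only quantization while on cpu
    m.to("cuda")                             # moving to cuda afterwards: to be enabled in a separate PR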

2 files changed, +19 −2 lines


test/quantization/test_quant_api.py

Lines changed: 18 additions & 1 deletion
@@ -624,7 +624,7 @@ def test_quantized_tensor_subclass_save_load(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quantized_model_to_device(self):
+    def test_int8wo_quantized_model_to_device(self):
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")
@@ -637,6 +637,23 @@ def test_quantized_model_to_device(self):
         cuda_res = m(*example_inputs_cuda)
         self.assertEqual(cuda_res.cpu(), ref)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
+    def test_int4wo_quantized_model_to_device(self):
+        # TODO: change initial model to "cpu"
+        m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
+
+        quantize_(m, int4_weight_only())
+        ref = m(*example_inputs)
+
+        example_inputs_cuda = (example_inputs[0].to("cuda"),)
+        m.to(device="cuda")
+        cuda_res = m(*example_inputs_cuda)
+        self.assertEqual(cuda_res.cpu(), ref)
+
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_save_load_map_location(self):
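
The new test pins down the path the layout fix below targets: after quantize_(m, int4_weight_only()), each linear weight is an affine-quantized tensor subclass, and nn.Module.to normalizes the "cuda" argument to a torch.device object before it reaches the subclass's to(). A minimal sketch of that dispatch, assuming a CUDA build and the APIs shown in this diff (the 128x128 Linear is an arbitrary shape choice):

    import torch
    from torchao.quantization import quantize_, int4_weight_only

    m = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=False)).eval().to(torch.bfloat16).to("cuda")
    quantize_(m, int4_weight_only())  # weights become affine-quantized tensor subclasses
    # nn.Module.to hands torch.device("cuda") to each weight's to(); that device
    # object is what used to trip the layout's device check before this commit.
    m.to(device="cuda")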

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -544,7 +544,7 @@ def from_plain(
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
-        if device != "cuda" or (isinstance(device, torch.device) and device.type != "cuda"):
+        if device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda"):
             raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}")
         return self.__class__(
             self.packed_weight.to(device),
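
The one-character or/and change matters because a torch.device never compares equal to a string: for torch.device("cuda"), device != "cuda" is True, so the old predicate raised even for a valid CUDA target. A standalone sketch of both predicates (rejects_old and rejects_new are illustrative names, not from the codebase):

    import torch

    def rejects_old(device):  # predicate before this commit ("or")
        return device != "cuda" or (isinstance(device, torch.device) and device.type != "cuda")

    def rejects_new(device):  # predicate after this commit ("and")
        return device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda")

    for d in ("cuda", torch.device("cuda"), torch.device("cpu")):
        print(repr(d), rejects_old(d), rejects_new(d))
    # 'cuda'               -> old: False, new: False  (string accepted by both)
    # device(type='cuda')  -> old: True,  new: False  (old version wrongly rejected device objects)
    # device(type='cpu')   -> old: True,  new: True   (non-cuda device rejected by both)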
