
Add floating point options for autoquant and add accuracy measurement #1355

Merged: 5 commits merged on Nov 27, 2024
23 changes: 18 additions & 5 deletions examples/sam2_amg_server/server.py
@@ -338,6 +338,7 @@ def main(checkpoint_path,
baseline=False,
fast=False,
furious=False,
use_autoquant=False,
unittest=False,
benchmark=False,
profile=None,
@@ -366,13 +367,13 @@ def main(checkpoint_path,
from torchao._models.sam2.build_sam import build_sam2
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from torchao._models.sam2.utils.amg import rle_to_mask

device = "cuda"
sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)

logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

logging.info(f"Using {points_per_batch} points_per_batch")
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")

@@ -409,6 +410,18 @@ def main(checkpoint_path,
# NOTE: Not baseline feature
mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16

# Since autoquant replicates what furious mode does, don't use the two together
elif use_autoquant:
from torchao import autoquant
from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
mask_generator.predictor.model = autoquant(mask_generator.predictor.model, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)

mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16)
# NOTE: Not baseline feature
mask_generator.predictor._transforms_device = mask_generator.predictor.device
torch.set_float32_matmul_precision('high')


with open('dog.jpg', 'rb') as f:
image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))

Expand Down Expand Up @@ -487,7 +500,7 @@ async def upload_rle(image: UploadFile = File(...)):
await request_queue.put((image_tensor, response_future))
masks = await response_future
return masks_to_rle_dict(masks)

@app.post("/upload")
async def upload_image(image: UploadFile = File(...)):
image_tensor = file_bytes_to_image_tensor(bytearray(await image.read()))
@@ -505,7 +518,7 @@ async def upload_image(image: UploadFile = File(...)):
plt.savefig(buf, format='png')
buf.seek(0)
return StreamingResponse(buf, media_type="image/png")


# uvicorn.run(app, host=host, port=port, log_level="info")
uvicorn.run(app, host=host, port=port)
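
For context, here is a minimal standalone sketch of the autoquant call this hunk adds. The model and input are hypothetical stand-ins for the SAM2 predictor model; the flow assumes torchao's autoquant accepts the qtensor_class_list and min_sqnr arguments exactly as used above:

    import torch
    import torchao
    from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST

    # Hypothetical model and input, standing in for the SAM2 predictor model.
    model = torch.nn.Sequential(torch.nn.Linear(128, 128)).to("cuda", torch.bfloat16)
    example_input = torch.randn(1, 128, device="cuda", dtype=torch.bfloat16)

    # Restrict autoquant to floating-point options and reject any per-layer
    # choice whose SQNR against the unquantized output falls below 40 dB.
    model = torchao.autoquant(
        model,
        qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
        min_sqnr=40,
    )
    out = model(example_input)  # the first call records shapes and benchmarks options

Here min_sqnr acts as the accuracy gate the PR title mentions: as the test below suggests, per-layer options that fall under the threshold are excluded before the speed comparison.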
17 changes: 17 additions & 0 deletions test/integration/test_integration.py
@@ -1514,6 +1514,23 @@ def forward(self, x):
assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight)
model(x_in)

@parameterized.expand(list(itertools.product(["cuda"], COMMON_DTYPES)))
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_autoquant_min_sqnr(self, device, dtype):
m, k, n = 128, 128, 128
example_input = torch.randn(m, k, device=device, dtype=dtype)
model = torch.nn.Sequential(
torch.nn.ReLU(),
torch.nn.Linear(k, n),
torch.nn.ReLU(),
).to(device).to(dtype)
out = model(example_input)
torchao.autoquant(model, min_sqnr=60)
out2 = model(example_input)
sqnr = SQNR(out, out2)
# without setting min_sqnr to 60, we get around 45-50 final SQNR
# setting min_sqnr to 60 for each individual linear lets us achieve >= 50 final SQNR
self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}")
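
The SQNR helper the test calls is not part of this hunk; as a reference, signal-to-quantization-noise ratio is conventionally computed as below. This mirrors torchao's compute_error utility, though the exact helper bound to SQNR in this test file is an assumption:

    import torch

    def sqnr(signal: torch.Tensor, noisy: torch.Tensor) -> torch.Tensor:
        # SQNR in dB: 20 * log10(||signal|| / ||signal - noisy||).
        # Higher is better; identical tensors give +inf.
        return 20 * torch.log10(
            torch.linalg.norm(signal) / torch.linalg.norm(signal - noisy)
        )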



2 changes: 2 additions & 0 deletions torchao/_models/llama/generate.py
@@ -379,6 +379,8 @@ def main(
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
elif "autoquant-float8" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
if "autoquant-fp" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs)
else:
model = autoquant(model, manual=True, example_input=inputs)
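
For context on the manual=True flag used throughout this block: in torchao's manual autoquant mode, kernel selection is deferred until the caller finalizes it, roughly as sketched below (the model and input here are hypothetical):

    import torch
    import torchao
    from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST

    model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to("cuda", torch.bfloat16)
    inputs = torch.randn(2, 64, device="cuda", dtype=torch.bfloat16)

    # manual=True defers the decision: forward passes only record shapes and
    # benchmark candidates; finalize_autoquant() commits the winners.
    model = torchao.autoquant(model, manual=True, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST)
    model(inputs)               # calibration / benchmarking pass
    model.finalize_autoquant()  # pick the best option per layer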

2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -11,6 +11,7 @@

from .autoquant import (
DEFAULT_AUTOQUANT_CLASS_LIST,
DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
OTHER_AUTOQUANT_CLASS_LIST,
autoquant,
@@ -89,6 +90,7 @@
"autoquant",
"DEFAULT_AUTOQUANT_CLASS_LIST",
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
"OTHER_AUTOQUANT_CLASS_LIST",
# top level API - manual
"quantize_",