
Commit dc1a233

Merge remote-tracking branch 'origin/main' into fp8_linear_quantize
2 parents: a73180d + ed76e9c


45 files changed: +1472 -280 lines

.github/workflows/regression_test.yml

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ jobs:
             torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.1"
+
           - name: CPU 2.3
             runs-on: linux.4xlarge
             torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'

README.md

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ If you find the torchao library useful, please cite it in your work as below.
 @software{torchao,
   title = {torchao: PyTorch native quantization and sparsity for training and inference},
   author = {torchao maintainers and contributors},
-  url = {https//github.com/pytorch/torchao},
+  url = {https://github.com/pytorch/torchao},
   license = {BSD-3-Clause},
   month = oct,
   year = {2024}

benchmarks/benchmark_low_bit_adam.py

Lines changed: 63 additions & 25 deletions
@@ -4,7 +4,7 @@
 # - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git
 # - DeepSpeed (ZeRO-Offload):
 #     sudo apt install libopenmpi-dev
-#     LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p
+#     LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py
 #     DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir
 #
 # To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core
@@ -31,11 +31,15 @@
 import torch.nn.functional as F
 import wandb
 from torch.utils.data import DataLoader
+from torchao.utils import get_available_devices
 from torchvision.transforms import v2
 from tqdm import tqdm
 
 from torchao.prototype import low_bit_optim
 
+_DEVICE = get_available_devices()[-1]
+assert _DEVICE in ["cuda", "xpu"], "Benchmark currently only supports CUDA & XPU(BF16)"
+
 OPTIM_MAP = dict(
     AdamW=partial(torch.optim.AdamW, fused=True),
     AdamW8bitBnb=bnb.optim.AdamW8bit,
@@ -49,7 +53,9 @@
 
     OPTIM_MAP.update(
         AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True),
-        AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")),
+        AdamW4bitRank1Lpmm=partial(
+            lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")
+        ),
     )
 
 except ImportError:
@@ -67,8 +73,12 @@ def get_lr(self, step: int) -> float:
         if step < self.warmup_steps:
             return self.lr * step / self.warmup_steps
         if step < self.total_steps:
-            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
-            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
+            progress = (step - self.warmup_steps) / (
+                self.total_steps - self.warmup_steps
+            )
+            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (
+                1 + math.cos(progress * math.pi)
+            )
         return self.final_lr
 
 
@@ -92,7 +102,9 @@ def get_parser():
     parser.add_argument("--weight_decay", type=float, default=0)
     parser.add_argument("--optim_kwargs", type=json.loads, default=dict())
     parser.add_argument("--cosine_lr_scheduler", action="store_true")
-    parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])
+    parser.add_argument(
+        "--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]
+    )
 
     parser.add_argument("--project")
     parser.add_argument("--run_name", default="debug")
@@ -110,11 +122,15 @@ def get_dloader(args, training: bool):
         transforms.extend([v2.Resize(256), v2.CenterCrop(224)])
 
     transforms.append(v2.ToDtype(torch.float32, scale=True))
-    transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+    transforms.append(
+        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    )
     transforms = v2.Compose(transforms)
 
     # use dataset from HF so download is fast
-    ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
+    ds = datasets.load_dataset(
+        "timm/resisc45", split="train" if training else "validation"
+    )
     ds = ds.select_columns(["image", "label"])
     ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))
 
@@ -128,9 +144,9 @@ def get_dloader(args, training: bool):
     )
 
 
-def get_amp_ctx(amp):
+def get_amp_ctx(amp, device):
     dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
-    return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")
+    return torch.autocast(device, dtype=dtype, enabled=amp != "none")
 
 
 @torch.no_grad()
@@ -148,8 +164,8 @@ def evaluate_model(model, args):
         if args.channels_last:
             batch["image"] = batch["image"].to(memory_format=torch.channels_last)
 
-        with get_amp_ctx(args.amp):
-            all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())
+        with get_amp_ctx(args.amp, _DEVICE):
+            all_preds.append(model(batch["image"].to(_DEVICE)).argmax(1).cpu())
 
     all_labels = torch.cat(all_labels, dim=0)
     all_preds = torch.cat(all_preds, dim=0)
@@ -164,8 +180,12 @@ def evaluate_model(model, args):
     if args.full_bf16:
         assert args.amp == "none", "When --full_bf16 is set, --amp must be none"
     if args.optim_cpu_offload == "deepspeed":
-        assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none"
-        assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
+        assert (
+            args.amp == "none"
+        ), "When using DeepSpeed ZeRO-Offload, --amp must be none"
+        assert (
+            args.optim == "AdamW"
+        ), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
     if args.profile:
        args.n_epochs = 1
     if args.seed is not None:
@@ -185,14 +205,16 @@ def evaluate_model(model, args):
     dloader = get_dloader(args, True)
     print(f"Train dataset: {len(dloader.dataset):,} images")
 
-    model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs)
+    model = timm.create_model(
+        args.model, pretrained=True, num_classes=45, **args.model_kwargs
+    )
     if args.checkpoint_activations:
         model.set_grad_checkpointing()
     if args.full_bf16:
         model.bfloat16()
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
-    model.cuda()  # move model to CUDA after optionally convert it to BF16
+    model.to(_DEVICE)  # move model to DEVICE after optionally convert it to BF16
     if args.compile:
         model.compile(fullgraph=True)
     print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
@@ -227,9 +249,15 @@ def evaluate_model(model, args):
     optim_cls = OPTIM_MAP[args.optim]
 
     if args.optim_cpu_offload == "ao":
-        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
+        optim_cls = partial(
+            low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls
+        )
     elif args.optim_cpu_offload == "ao_offload_grads":
-        optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)
+        optim_cls = partial(
+            low_bit_optim.CPUOffloadOptimizer,
+            optimizer_class=optim_cls,
+            offload_gradients=True,
+        )
 
     optim = optim_cls(
         model.parameters(),
@@ -239,24 +267,30 @@ def evaluate_model(model, args):
     )
 
     lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
-    grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
+    grad_scaler = torch.amp.GradScaler(_DEVICE, enabled=args.amp == "fp16")
     log_interval = 10
     t0 = time.perf_counter()
 
     step = 0
     for epoch_idx in range(args.n_epochs):
         model.train()
-        pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")
+        pbar = tqdm(
+            dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"
+        )
 
         with torch.profiler.profile() if args.profile else nullcontext() as prof:
             for batch in pbar:
                 if args.full_bf16:
                     batch["image"] = batch["image"].bfloat16()
                 if args.channels_last:
-                    batch["image"] = batch["image"].to(memory_format=torch.channels_last)
+                    batch["image"] = batch["image"].to(
+                        memory_format=torch.channels_last
+                    )
 
-                with get_amp_ctx(args.amp):
-                    loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
+                with get_amp_ctx(args.amp, _DEVICE):
+                    loss = F.cross_entropy(
+                        model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE)
+                    )
 
                 if args.optim_cpu_offload == "deepspeed":
                     model.backward(loss)
@@ -275,7 +309,9 @@ def evaluate_model(model, args):
                     log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
                     if step > 0:
                         t1 = time.perf_counter()
-                        log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0)
+                        log_dict["imgs_per_second"] = (
+                            args.batch_size * log_interval / (t1 - t0)
+                        )
                         t0 = t1
                     logger.log(log_dict, step=step)
 
@@ -296,9 +332,11 @@ def evaluate_model(model, args):
 
         else:
             val_acc = evaluate_model(model, args)
-            print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
+            print(
+                f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}"
+            )
             logger.log(dict(val_acc=val_acc), step=step)
 
-    peak_mem = torch.cuda.max_memory_allocated() / 1e9
+    peak_mem = getattr(torch, _DEVICE).max_memory_allocated() / 1e9
     print(f"Max memory used: {peak_mem:.02f} GB")
     logger.log(dict(max_memory_allocated=peak_mem))
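
The recurring change in this file replaces hard-coded "cuda" with a _DEVICE string chosen once at import time. A minimal standalone sketch of that pattern, assuming (as the [-1] indexing suggests) that torchao.utils.get_available_devices orders devices from least to most capable; the fallback implementation below is illustrative, not the library's actual code:

    # Sketch of the device-agnostic pattern adopted above.
    # Assumption: get_available_devices() returns something like
    # ["cpu", "cuda"] or ["cpu", "xpu"], so [-1] picks the best
    # available accelerator.
    import torch

    def get_available_devices():
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            devices.append("xpu")
        return devices

    _DEVICE = get_available_devices()[-1]

    # Former .cuda() call sites become .to(_DEVICE), and backend
    # modules (torch.cuda / torch.xpu) are looked up dynamically:
    x = torch.randn(4, 4).to(_DEVICE)
    if _DEVICE != "cpu":
        peak_gb = getattr(torch, _DEVICE).max_memory_allocated() / 1e9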

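The --optim_cpu_offload branches above wrap the selected optimizer in torchao's prototype CPUOffloadOptimizer. A hedged usage sketch of both modes; the toy model and learning rate are placeholders, and a CUDA device is assumed:

    # Sketch of the two CPU-offload modes exercised by the benchmark.
    import torch
    from functools import partial
    from torchao.prototype import low_bit_optim

    model = torch.nn.Linear(8, 8).cuda()  # params must live on GPU
    base_cls = partial(torch.optim.AdamW, fused=True)

    # "ao": optimizer state lives on CPU; gradients stay on GPU.
    optim = low_bit_optim.CPUOffloadOptimizer(
        model.parameters(), optimizer_class=base_cls, lr=1e-3
    )

    # "ao_offload_grads": gradients are offloaded to CPU as well,
    # trading PCIe traffic for lower GPU memory.
    optim_grads = low_bit_optim.CPUOffloadOptimizer(
        model.parameters(),
        optimizer_class=base_cls,
        offload_gradients=True,
        lr=1e-3,
    )
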
examples/sam2_amg_server/cli.py

Lines changed: 31 additions & 5 deletions
@@ -5,6 +5,10 @@
 from server import show_anns
 from server import model_type_to_paths
 from server import MODEL_TYPES_TO_MODEL
+from server import set_fast
+from server import set_aot_fast
+from server import load_aot_fast
+from server import set_furious
 from torchao._models.sam2.build_sam import build_sam2
 from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from torchao._models.sam2.utils.amg import rle_to_mask
@@ -19,19 +23,28 @@ def main_docstring():
         output_path (str): Path to output image
     """
 
-def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False):
+
+def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
     device = "cuda"
     sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
     if verbose:
         print(f"Loading model {sam2_checkpoint} with config {model_cfg}")
     sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
     mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
-    image_tensor = file_bytes_to_image_tensor(bytearray(open(input_path, 'rb').read()))
+    if furious:
+        set_furious(mask_generator)
+    if load_fast:
+        load_aot_fast(mask_generator, load_fast)
+    if fast:
+        set_fast(mask_generator, load_fast)
+
+    image_tensor = file_bytes_to_image_tensor(input_bytes)
     if verbose:
         print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.")
     masks = mask_generator.generate(image_tensor)
-
-    # Save an example
+
+    if verbose:
+        print("Generating mask annotations for input image.")
     plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100)
     plt.imshow(image_tensor)
     show_anns(masks, rle_to_mask)
@@ -40,8 +53,21 @@ def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=
     buf = BytesIO()
     plt.savefig(buf, format=output_format)
     buf.seek(0)
+    return buf.getvalue()
+
+def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
+    input_bytes = bytearray(open(input_path, 'rb').read())
+    output_bytes = main_headless(checkpoint_path,
+                                 model_type,
+                                 input_bytes,
+                                 points_per_batch=points_per_batch,
+                                 output_format=output_format,
+                                 verbose=verbose,
+                                 fast=fast,
+                                 furious=furious,
+                                 load_fast=load_fast)
     with open(output_path, "wb") as file:
-        file.write(buf.getvalue())
+        file.write(output_bytes)
 
 main.__doc__ = main_docstring()
 if __name__ == "__main__":
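
The refactor splits file I/O out of main so that mask generation can be driven with raw bytes (main_headless), which suits the companion server. A hedged sketch of calling both entry points; the checkpoint directory, model type, and file names are placeholders, and it assumes cli.py is importable from the working directory:

    # Hypothetical invocation; paths and model_type are placeholders.
    from cli import main, main_headless

    # File-based path: reads input_path, writes the annotated image.
    main("checkpoints/", "large", "dog.jpg", "dog_masks.png", fast=True)

    # Headless path: bytes in, encoded image bytes out. No filesystem
    # round-trip, which is what a server endpoint wants.
    with open("dog.jpg", "rb") as f:
        png_bytes = main_headless(
            "checkpoints/", "large", bytearray(f.read()), furious=True
        )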
