diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index a54d902f1f..9a23a49520 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -4,27 +4,24 @@ import torch from torch import nn - -from torchao.sparsity import ( - apply_fake_sparsity, - sparsify_, - semi_sparse_weight, -) +from torch.testing._internal import common_utils from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType from torchao.quantization.quant_api import ( + int4_weight_only, int8_dynamic_activation_int8_weight, quantize_, - int4_weight_only, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 -from torch.testing._internal.common_utils import TestCase + +from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ +from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4 logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) -class TestSemiStructuredSparse(TestCase): + +class TestSemiStructuredSparse(common_utils.TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @@ -37,6 +34,7 @@ def test_sparse(self): ) .half() .cuda() + .eval() ) apply_fake_sparsity(model) @@ -45,13 +43,17 @@ def test_sparse(self): sparsify_(model, semi_sparse_weight()) sparse_result = model(input) - assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) -class TestQuantSemiSparse(TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature") +class TestQuantSemiSparse(common_utils.TestCase): + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_quant_semi_sparse(self): + @common_utils.parametrize("compile", [True, False]) + def test_quant_semi_sparse(self, compile): + torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False + input = torch.rand((128, 128)).half().cuda() model = ( nn.Sequential( @@ -60,19 +62,27 @@ def test_quant_semi_sparse(self): ) .half() .cuda() + .eval() ) apply_fake_sparsity(model) model_copy = copy.deepcopy(model) quantize_(model_copy, int8_dynamic_activation_int8_weight()) dense_result = model_copy(input) - quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())) + quantize_( + model, + int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), + ) + if compile: + model = torch.compile(model) sparse_result = model(input) - assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_sparse_marlin(self): + @common_utils.parametrize("compile", [True, False]) + def test_sparse_marlin(self, compile): input = torch.rand((256, 256)).half().cuda() model = ( nn.Sequential( @@ -81,6 +91,7 @@ def test_sparse_marlin(self): ) .half() .cuda() + .eval() ) apply_fake_sparsity(model) @@ -92,9 +103,101 @@ def test_sparse_marlin(self): # Sparse + quantized quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) + if compile: + model = torch.compile(model) sparse_result = 
model(input) - assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" + torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1) + + +class TestBlockSparseWeight(common_utils.TestCase): + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @common_utils.parametrize("compile", [True, False]) + def test_sparse(self, compile): + input = torch.rand((1024, 1024)).half().cuda() + model = ( + nn.Sequential( + nn.Linear(1024, 2048), + nn.Linear(2048, 1024), + ) + .half() + .cuda() + .eval() + ) + + from torchao.sparsity.utils import create_block_sparse_tensor + + M, N = model[0].weight.shape + model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16) + M, N = model[1].weight.shape + model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16) + dense_result = model(input) + + from torchao.sparsity.prototype.superblock.blocksparse import ( + block_sparse_weight, + ) + + sparsify_(model, block_sparse_weight(blocksize=64)) + # if compile: + # model = torch.compile(model) + sparse_result = model(input) + + torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + + +class TestQuantBlockSparseWeight(common_utils.TestCase): + @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "pytorch 2.6+ feature") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @common_utils.parametrize("compile", [True, False]) + def test_sparse(self, compile): + input = torch.rand((256, 128)).to(torch.bfloat16).cuda() + model = ( + nn.Sequential( + nn.Linear(128, 256), + nn.Linear(256, 128), + ) + .to(torch.bfloat16) + .cuda() + .eval() + ) + from torchao.sparsity.prototype.superblock.blocksparse import ( + blocksparse_int_addmm, + ) + from torchao.sparsity.utils import create_block_sparse_tensor + + M, N = model[0].weight.shape + model[0].weight.data = ( + create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16) + * torch.rand(M, N, dtype=torch.bfloat16).cuda() + ) + M, N = model[1].weight.shape + model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16) + + model_copy = copy.deepcopy(model) + + quantize_(model_copy, int8_dynamic_activation_int8_weight()) + reference = model_copy(input) + + from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType + + quantize_( + model, + int8_dynamic_activation_int8_weight( + layout_type=BlockSparseLayoutType(blocksize=64) + ), + ) + if compile: + model = torch.compile(model) + sparse_result = model(input) + + torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1) + + +common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse) +common_utils.instantiate_parametrized_tests(TestQuantSemiSparse) +common_utils.instantiate_parametrized_tests(TestBlockSparseWeight) +common_utils.instantiate_parametrized_tests(TestQuantBlockSparseWeight) if __name__ == "__main__": unittest.main() diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e00576263f..43ee82ffaa 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -47,7 +47,6 @@ from torchao.float8.inference import Float8MMConfig aten = torch.ops.aten - ############################### # Base Layout Tensor Subclass # ############################### @@ -473,6 +472,11 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: 
        return temp


+@dataclass(frozen=True)
+class BlockSparseLayoutType(LayoutType):
+    blocksize: int = 64
+
+
 @dataclass(frozen=True)
 class TensorCoreTiledLayoutType(LayoutType):
     inner_k_tiles: int = 8
@@ -669,6 +673,145 @@ def from_plain(
         int_data_compressed = torch._cslt_compress(int_data)
         return cls(int_data_compressed, scale, zero_point, layout_type)
 
+@register_layout_cls(BlockSparseLayoutType)
+class BlockSparseAQTLayout(PlainAQTLayout):
+    bsr_crow_indices: Optional[torch.Tensor]
+    bsr_col_indices: Optional[torch.Tensor]
+    bsr_values: Optional[torch.Tensor]
+    scale: Optional[torch.Tensor]
+    zero_point: Optional[torch.Tensor]
+
+    __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"]
+
+    @staticmethod
+    def __new__(  # noqa: PYI034
+        cls,
+        shape: torch.Size,
+        bsr_crow_indices: Optional[torch.Tensor],
+        bsr_col_indices: Optional[torch.Tensor],
+        bsr_values: Optional[torch.Tensor],
+        scale: Optional[torch.Tensor],
+        zero_point: Optional[torch.Tensor],
+        layout_type: LayoutType,
+        requires_grad: bool = False,
+    ):
+        if bsr_values is None:
+            raise ValueError("bsr values must be provided!")
+        else:
+            previous_tensor = bsr_values
+
+        kwargs = {
+            "device": previous_tensor.device,
+            "dtype": previous_tensor.dtype,
+            "layout": previous_tensor.layout,
+            "requires_grad": requires_grad,
+        }
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(  # noqa: PYI034
+        self,
+        shape: torch.Size,
+        bsr_crow_indices: Optional[torch.Tensor],
+        bsr_col_indices: Optional[torch.Tensor],
+        bsr_values: Optional[torch.Tensor],
+        scale: Optional[torch.Tensor],
+        zero_point: Optional[torch.Tensor],
+        layout_type: LayoutType,
+        requires_grad: bool = False,
+    ):
+        self.bsr_crow_indices = bsr_crow_indices
+        self.bsr_col_indices = bsr_col_indices
+        self.bsr_values = bsr_values
+        self.scale = scale
+        self.zero_point = zero_point
+        self.layout_type = layout_type
+
+    def __tensor_flatten__(self):
+        inner_tensors = list(
+            filter(lambda x: getattr(self, x) is not None, self.__slots__)
+        )
+        tensor_meta = (self.shape, self.layout_type, self.requires_grad)
+        return inner_tensors, tensor_meta
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls,
+        inner_tensors,
+        tensor_meta: Tuple[torch.Size, LayoutType, bool],
+        outer_size,
+        outer_stride,
+    ) -> torch.Tensor:
+        shape, layout_type, requires_grad = tensor_meta
+        return cls(
+            shape=shape,
+            bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
+            bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
+            bsr_values=inner_tensors.get("bsr_values", None),
+            scale=inner_tensors.get("scale", None),
+            zero_point=inner_tensors.get("zero_point", None),
+            layout_type=layout_type,
+            requires_grad=requires_grad,
+        )
+
+    @classmethod
+    def from_plain(cls, int_data, scale, zero_point, layout_type):
+        bsr_tensor = int_data.to_sparse_bsr(layout_type.blocksize)
+        return cls(
+            shape=int_data.shape,
+            bsr_crow_indices=bsr_tensor.crow_indices(),
+            bsr_col_indices=bsr_tensor.col_indices(),
+            bsr_values=bsr_tensor.values(),
+            scale=scale,
+            zero_point=zero_point,
+            layout_type=layout_type,
+            requires_grad=False,
+        )
+
+    def get_plain(self):
+        int_data_expanded = torch.ops.blocksparse.bsr_to_dense(
+            self.crow_indices(),
+            self.col_indices(),
+            self.values(),
+            self.shape[0],
+            self.shape[1],
+        )
+        return int_data_expanded, self.scale, self.zero_point
+
+    def _apply_fn_to_data(self, func):
+        return self.__class__(
+            shape=self.shape,
+            bsr_crow_indices=func(self.bsr_crow_indices),
bsr_col_indices=func(self.bsr_col_indices), + bsr_values=func(self.bsr_values), + scale=self.scale, + zero_point=self.zero_point, + layout_type=self.layout_type, + requires_grad=self.requires_grad, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + # Need the following for bsr specific functions + if func is aten.crow_indices.default: + return args[0].bsr_crow_indices.detach() + + if func is aten.col_indices.default: + return args[0].bsr_col_indices.detach() + + if func is aten.values.default: + return args[0].bsr_values.detach() + + if func is aten._nnz.default: + return args[0].bsr_values.shape[0] + + raise NotImplementedError( + f"BlockSparseAQTLayout dispatch: attempting to run {func}, this is not supported" + ) @register_layout_cls(MarlinSparseLayoutType) class MarlinSparseAQTLayout(AQTLayout): @@ -1221,6 +1364,43 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh y += bias return y +def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) and + _aqt_is_int8_reduced_range(input_tensor) and + isinstance(weight_tensor, AffineQuantizedTensor) and + weight_tensor.is_cuda and + input_tensor.dtype == weight_tensor.dtype and + isinstance(input_tensor.layout_type, PlainLayoutType) and + isinstance(weight_tensor.layout_type, BlockSparseLayoutType) + ) + + +def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): + x_vals_int8 = input_tensor.layout_tensor.int_data + x_scales = input_tensor.layout_tensor.scale + w_vals = weight_tensor.layout_tensor + w_scales = weight_tensor.layout_tensor.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + tmp_t = tmp.t() + + y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(), + w_vals.col_indices(), + w_vals.values(), + tmp_t, + w_scales, + x_scales.reshape(-1)) + y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) + y = y.reshape(*y_shape) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y + + def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): return ( # input is native bfloat16 tensor @@ -1473,6 +1653,7 @@ def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), + (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl), (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), diff --git a/torchao/sparsity/prototype/superblock/README.md b/torchao/sparsity/prototype/superblock/README.md index 54a6964b17..6fea1a0e3a 100644 --- a/torchao/sparsity/prototype/superblock/README.md +++ b/torchao/sparsity/prototype/superblock/README.md @@ -36,76 +36,33 @@ At least one GPU: conda create -n superblock conda activate 
superblock
 ```
 
-* Install PyTorch. For best performance, we recommend `2.3.0.dev20240305+cu121` nightly
+* Install PyTorch. For best performance, we recommend the PyTorch nightlies.
 ```
-    pip install --pre torch==2.3.0.dev20240305+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121
-    pip install --pre torchvision==0.18.0 --no-deps
+    pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
 ```
+  We ran our experiments with torch==2.6.0.dev20240924+cu121.
 
-## Benchmarking
-Baseline:
-```
-python benchmark.py \
-    --model vit_h_14 \
-    --batch-size 256 \
-```
-Result:
-```
-532.1160546875 ms
-```
+## Results
+### Benchmarking
+For all our benchmarking results, you can run `benchmark.sh`.
+These benchmarks were run on an NVIDIA A100 (80GB) with cuSPARSELt v0.5.2.
 
-80% sparsity, block size 64 (random weights):
-```
-python benchmark.py \
-    --model vit_h_14 \
-    --batch-size 256 \
-    --sparsity-linear 0.8 \
-    --sp-linear-tile-size 64 \
-    --bsr 64 \
-    --sparsity bsr
-```
-Result:
-```
-393.864453125 ms
-```
-Semi-structured sparsity
+### Evaluation
+
+To reproduce our accuracy results, you can run `evaluate.sh`.
+You will need to set the following environment variables first to run the script:
+
 ```
-python benchmark.py \
-    --model vit_h_14 \
-    --batch-size 256 \
-    --sparsity semi_structured
+IMAGENET_PATH=
+NGPUS=1 # put number of available GPUs here
 ```
-
 ## Training
 Please refer to [TRAINING.md](TRAINING.md) for training from scratch. We use [Torchvision](https://github.com/pytorch/vision/tree/main/references/classification) as our framework for training. Supermask can be applied during training.
-To apply supermask, we have the following arguments at our disposal,
-
-* Apply Supermask to linear layers:
-    ```
-    --sparsity-linear
-    --sp-linear-tile-size
-    ```
-* Apply Supermask to conv1x1 layers:
-    ```
-    --sparsity-conv1x1
-    --sp-conv1x1-tile-size
-    ```
-* Apply Supermask to all other convolutional layers:
-    ```
-    --sparsity-conv
-    --sp-conv-tile-size
-    ```
-* Skip the first transformer layer and/or last linear layer (ViT only):
-    ```
-    --skip-last-layer-sparsity
-    --skip-first-transformer-sparsity
-    ```
-
 For example, if you would like to train a `vit_b_16` from scratch using Supermask, you can use the respective torchvision command found in [TRAINING.md](TRAINING.md) and append the supermask arguments:
 ```
 torchrun --nproc_per_node=8 train.py\
@@ -119,59 +76,6 @@ Through this command, we are training a `vit_b_16` with 90% sparsity to linear l
 
 Please run `python train.py --help` for a full list of available arguments.
 
-## Evaluation
-
-To run an evaluation of a Supermask-trained model, you can use [evaluate.py](evaluate.py). Our current version has signficant speedup with float32 only and not float16, hence, to illustrate speedup, we don't pass `--amp` in the example commands below.
-
-```
-MODEL_PATH=
-IMAGENET_PATH=
-NGPUS=1 # put number of available GPUS here
-```
-
-* Offline sparsification with BSR:
-    ```
-    python evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear 0.9 --sp-linear-tile-size 32 --weights-path ${MODEL_PATH} --data-path ${IMAGENET_PATH} --sparsity bsr --bsr 64
-    ```
-    This command applies 90% sparsity to linear layers using 32x32 tiles, loads the model weights from ${MODEL_PATH}, loads the ImageNet validation set located at the specified path, applies offline sparsification to the weights, and converts the sparse weights to BSR format with a block size of 32.
It is recommended to set `--bsr` the same as tile size. - -* Online sparsification without BSR: - ``` - torchrun --nproc_per_node=${NGPUS} evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear 0.9 --sp-linear-tile-size 32 --weights-path ${MODEL_PATH} --data-path ${IMAGENET_PATH} - ``` - This is similar to the previous command, but it does not apply offline sparsification or BSR conversion. Instead, the sparsity is applied on-the-fly during evaluation. - -* Semi-structured sparsity - ``` - python evaluate.py --model vit_b_16 --batch-size 256 --data-path $IMAGENET_PATH --weights-path checkpoints/2x4_sparse_ft_1_epoch.pth --sparsity semi_structured --skip-last-layer-sparsity - ``` - -Please run `python evaluate.py --help` for a full list of available arguments. - -Results (1x A100): -* Baseline - ``` - Test: Total time: 0:02:11 - Test: Acc@1 78.392 Acc@5 93.592 - ``` - -* Sparsity= 0.9, Tile Size = 32, Online Sparsification, BSR = None - ``` - Test: Total time: 0:01:52 - Test: Acc@1 76.092 Acc@5 92.656 - ``` - -* Sparsity= 0.9, Tile Size = 32, Offline Sparsification, BSR = None - ``` - Test: Total time: 0:01:54 - Test: Acc@1 76.092 Acc@5 92.656 - ``` - -* Sparsity= 0.9, Tile Size = 32, Offline Sparsification, BSR = 32 - ``` - Test: Total time: 0:01:25 - Test: Acc@1 76.092 Acc@5 92.656 - ``` ## Pretrained Weights @@ -189,43 +93,5 @@ wget https://huggingface.co/facebook/superblock-vit-b-16/resolve/main/checkpoint # For sparsified checkpoints, wget https://huggingface.co/facebook/superblock-vit-b-16/resolve/main/checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth -P checkpoints/ ``` - -### Benchmark: -``` -python benchmark.py --model vit_b_16 \ - --batch-size 256 \ - --sparsity-linear ${SPARSITY} \ - --sp-linear-tile-size ${BLOCK_SIZE} \ - --sparsity bsr\ - --bsr ${BLOCK_SIZE} \ - --weights-path ./checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth \ - > /dev/null -``` -Result: -``` -530.342578125 ms -``` - -### Evaluate: -8 x A100 GPUs: -``` -torchrun --nproc_per_node=8 evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear ${SPARSITY} --sp-linear-tile-size ${BLOCK_SIZE} --bsr ${BLOCK_SIZE} --sparsity bsr --weights-path checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth --data-path ${IMAGENET_PATH} -``` -Result: -``` -Test: Total time: 0:01:01 -Test: Acc@1 77.644 Acc@5 93.554 -``` - -1 x A100 GPUs: -``` -torchrun --nproc_per_node=1 evaluate.py --model vit_b_16 --batch-size 256 --sparsity-linear ${SPARSITY} --sp-linear-tile-size ${BLOCK_SIZE} --bsr ${BLOCK_SIZE} --sparsity bsr--weights-path checkpoints/sp${SPARSITY}-ts${BLOCK_SIZE}.pth --data-path ${IMAGENET_PATH} -``` -Result: -``` -Test: Total time: 0:01:51 -Test: Acc@1 77.644 Acc@5 93.554 -``` - ## License SuperBlock is released under the [MIT license](https://github.com/pytorch-labs/superblock?tab=MIT-1-ov-file#readme). diff --git a/torchao/sparsity/prototype/superblock/benchmark.py b/torchao/sparsity/prototype/superblock/benchmark.py index 18de0bc2d5..a0fb27022c 100644 --- a/torchao/sparsity/prototype/superblock/benchmark.py +++ b/torchao/sparsity/prototype/superblock/benchmark.py @@ -1,26 +1,25 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
-import os
-import time
-import sys
-import warnings
-import hashlib
+import torch
 import torchvision
-import presets
-import torch
-import torch.utils.data
-import utils
-from torch import nn
-from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm
-from torchao.sparsity.prototype.superblock.utils import accelerate_with_sparsity, simulate_sparsity
+from torch.sparse._triton_ops_meta import (
+    dump as store_tuned_kernel_params,
+    optimize_bsr_dense_addmm,
+)
+from torchao.sparsity.prototype.superblock.utils import (
+    accelerate_with_sparsity,
+    get_args_parser,
+    simulate_sparsity,
+)
 from torchao.utils import benchmark_model, profiler_runner
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
+torch.backends.mha.set_fastpath_enabled(False)
+
 
 @torch.inference_mode
 def main(args):
-    print(args)
     device = torch.device(args.device)
 
     # We disable the cudnn benchmarking because it can noticeably affect the accuracy
@@ -29,90 +28,114 @@ def main(args):
     num_classes = 1000
     dtype = getattr(torch, args.dtype)
-    print(f"Using dtype: {dtype}")
 
     # BSR kernel tuning
    if args.bsr and args.tune_kernel_params:
-        print("Tuning kernel params")
+        kwargs = dict(
+            dtype=torch.int8 if args.quantization else dtype,
+            sparsity=args.sparsity_linear,
+            verbose=True,
+            # per blocksparse_int_addmm:
+            alpha=1,
+            beta=0,
+            use_left_alpha=True,
+            use_right_alpha=True,
+            # force tuning because the existing tuning parameters were
+            # computed for use_left/right_alpha=False; however, it turns
+            # out that re-tuning with use_left/right_alpha=True leads to
+            # the same set of tuning parameters:
+            # force=True
+        )
        if args.model == "vit_b_16":
-            optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
-            optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+            optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, **kwargs)
+            optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, **kwargs)
        elif args.model == "vit_h_14":
-            optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
-            optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+            optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, **kwargs)
+            optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, **kwargs)
        else:
-            raise NotImplementedError("Tuning kernel params for this model is not supported yet.")
+            raise NotImplementedError(
+                "Tuning kernel params for this model is not supported yet."
+            )
+    # Warning: the following call will overwrite the source code
+    # of torch.sparse._triton_ops_meta (hence it is commented out
+    # by default), but when used it enables reusing the tuned
+    # parameters in subsequent runs of this script:
+    # store_tuned_kernel_params()
 
-    print("Creating model")
-    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
-
-    # Fake sparsity necessary for BSR
-    simulate_sparsity(model, args)
+    model = torchvision.models.get_model(
+        args.model, weights=args.weights, num_classes=num_classes
+    ).eval()
+
+    # Fake sparsity necessary for BSR, since we sparsify based on SuperBlock
+    sparsifier_or_none = simulate_sparsity(model, args)
+    if sparsifier_or_none is not None:
+        sparsifier_or_none.squash_mask()
 
     if args.weights_path:
         try:
             checkpoint = torch.load(args.weights_path, map_location="cpu")
             model.load_state_dict(checkpoint["model"])
-            print(f"Loaded checkpoint successfully from: {args.weights_path}")
         except FileNotFoundError:
             raise FileNotFoundError(f"No checkpoint found at {args.weights_path}.")
 
     model.to(device).to(dtype)
 
-    # Fake sparsity necessary for BSR
+    # With quantization, we must use cuSPARSELt to fuse one of the scalar matmuls.
+    # Otherwise, we observe the CUTLASS kernels to be faster, so we use those instead.
     accelerate_with_sparsity(model, args)
 
     # compile
-    model = torch.compile(model, mode='max-autotune', fullgraph=True)
+    model = torch.compile(model, mode="max-autotune", fullgraph=True)
 
     # define image
-    image = torch.randn(args.batch_size, 3, args.val_crop_size, args.val_crop_size, dtype=dtype, device=device)
+    image = torch.randn(
+        args.batch_size,
+        3,
+        args.val_crop_size,
+        args.val_crop_size,
+        dtype=dtype,
+        device=device,
+    )
 
     # warmup
-    benchmark_model(model, 10, args=(image,))
+    benchmark_model(model, 10, args=(image,))
     if args.profile:
-        return profiler_runner("test.json.gz", benchmark_model, model, 10, (image,))
+        return profiler_runner("test.json.gz", benchmark_model, model, 10, (image,))
     else:
-        return benchmark_model(model, 100, args=(image,))
-
-
-
-def get_args_parser(add_help=True):
-    import argparse
-
-    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
-    parser.add_argument("--model", default="resnet18", type=str, help="model name")
-    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
-    parser.add_argument(
-        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
-    )
-    parser.add_argument(
-        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
-    )
-    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
-    parser.add_argument("--weights-path", type=str, help="path of pretrained weights to load")
-    # NOTE: sparsity args
-    parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply')
-    parser.add_argument("--sparsity-linear", type=float, default=0.0)
-    parser.add_argument("--sp-linear-tile-size", type=int, default=1)
-    parser.add_argument("--sparsity-conv1x1", type=float, default=0.0)
-    parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1)
-    parser.add_argument("--sparsity-conv", type=float, default=0.0)
-    parser.add_argument("--sp-conv-tile-size", type=int, default=1)
-    parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)")
-    parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer 
(for vit only)") - parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') - parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], help="data type", default="bfloat16") - parser.add_argument("--float16", action="store_true", help="Use float16") - parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params") - parser.add_argument("--profile", action="store_true", help="Profile the run and dump Prefetto trace") - parser.add_argument("--quantization", action="store_true", help="Profile the run and dump Prefetto trace") - - return parser + return benchmark_model(model, 100, args=(image,)) if __name__ == "__main__": - args = get_args_parser().parse_args() + args = get_args_parser(benchmark=True).parse_args() result = main(args) - print(f"{result:.3f} ms", file=sys.stderr) - print(f"{1000/result:.3f} img/s") + header = [ + "model", + "batch_size", + "dtype", + "sparsity", + "bsr", + "sparsity_level", + "quantization", + "tune_kernel_params", + "latency", + "img/s", + ] + result_string = ",".join( + str(_) + for _ in [ + args.model, + args.batch_size, + args.dtype, + args.sparsity, + args.bsr, + args.sparsity_linear, + args.quantization, + args.tune_kernel_params, + result, + 1000 / result, + ] + ) + with open("benchmark_results.txt", "a") as f: + if args.header: + f.write(",".join(header) + "\n") + f.write(result_string + "\n") + print(result_string) diff --git a/torchao/sparsity/prototype/superblock/benchmark.sh b/torchao/sparsity/prototype/superblock/benchmark.sh new file mode 100644 index 0000000000..3fc2a9869b --- /dev/null +++ b/torchao/sparsity/prototype/superblock/benchmark.sh @@ -0,0 +1,39 @@ +MODEL=vit_h_14 +BATCH_SIZE=256 + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --header +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --quantization + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization --tune-kernel-params +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization --tune-kernel-params +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization --tune-kernel-params + +MODEL=vit_b_16 +BATCH_SIZE=256 + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --header +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --quantization + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured 
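+# BSR block sparsity at 80/84/90% linear sparsity (block size 64)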
+python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity semi_structured --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization + +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.80 --bsr 64 --sparsity bsr --quantization --tune-kernel-params +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.84 --bsr 64 --sparsity bsr --quantization --tune-kernel-params +python benchmark.py --model $MODEL --batch-size $BATCH_SIZE --sparsity-linear 0.90 --bsr 64 --sparsity bsr --quantization --tune-kernel-params diff --git a/torchao/sparsity/prototype/superblock/benchmark_results.txt b/torchao/sparsity/prototype/superblock/benchmark_results.txt new file mode 100644 index 0000000000..3e18d9faec --- /dev/null +++ b/torchao/sparsity/prototype/superblock/benchmark_results.txt @@ -0,0 +1,30 @@ +model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,tune_kernel_params,latency,img/s +vit_h_14,256,bfloat16,None,None,0.0,False,False,489.645859375,2.0422923646825746 +vit_h_14,256,bfloat16,None,None,0.0,True,False,454.5648828125,2.1999059712064963 +vit_h_14,256,bfloat16,semi_structured,None,0.0,False,False,458.638046875,2.180368608347371 +vit_h_14,256,bfloat16,bsr,64,0.8,False,False,361.5827734375,2.765618479257699 +vit_h_14,256,bfloat16,bsr,64,0.84,False,False,343.1771484375,2.9139469354327407 +vit_h_14,256,bfloat16,bsr,64,0.9,False,False,315.37119140625,3.170866671559215 +vit_h_14,256,bfloat16,semi_structured,None,0.0,True,False,438.1652734375,2.2822438486619143 +vit_h_14,256,bfloat16,bsr,64,0.8,True,False,439.5409765625,2.2751007376392045 +vit_h_14,256,bfloat16,bsr,64,0.84,True,False,416.799375,2.3992358433838823 +vit_h_14,256,bfloat16,bsr,64,0.9,True,False,381.9370703125,2.6182323679181034 +vit_h_14,256,bfloat16,bsr,64,0.8,True,True,439.1569921875,2.277090010610706 +vit_h_14,256,bfloat16,bsr,64,0.84,True,True,416.18,2.4028064779662643 +vit_h_14,256,bfloat16,bsr,64,0.9,True,True,384.2584765625,2.6024149394069362 + +model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,tune_kernel_params,latency,img/s +vit_b_16,256,bfloat16,None,None,0.0,False,False,61.407705078125,16.284601398599175 +vit_b_16,256,bfloat16,None,None,0.0,True,False,60.934091796875,16.41117427881784 +vit_b_16,256,bfloat16,semi_structured,None,0.0,False,False,59.9600732421875,16.677764817945665 +vit_b_16,256,bfloat16,bsr,64,0.8,False,False,47.6238916015625,20.997864020990484 +vit_b_16,256,bfloat16,bsr,64,0.84,False,False,45.7176416015625,21.873394273378768 +vit_b_16,256,bfloat16,bsr,64,0.9,False,False,42.708759765625,23.414400359264707 +vit_b_16,256,bfloat16,semi_structured,None,0.0,True,False,58.783828125,17.011481420937148 +vit_b_16,256,bfloat16,bsr,64,0.8,True,False,58.1029541015625,17.210828872005806 +vit_b_16,256,bfloat16,bsr,64,0.84,True,False,55.8751025390625,17.89705887878946 
+vit_b_16,256,bfloat16,bsr,64,0.9,True,False,52.3257763671875,19.111039900921202 +vit_b_16,256,bfloat16,bsr,64,0.8,True,True,58.649736328125,17.050375033322325 +vit_b_16,256,bfloat16,bsr,64,0.84,True,True,56.46744140625,17.709320186930174 +vit_b_16,256,bfloat16,bsr,64,0.9,True,True,52.528623046875,19.037239927413086 +vit_b_16,256,bfloat16,bsr,64,0.8,True,False,57.6839794921875,17.335835856044508 diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py index 06b3548c55..69c98f6afc 100644 --- a/torchao/sparsity/prototype/superblock/blocksparse.py +++ b/torchao/sparsity/prototype/superblock/blocksparse.py @@ -1,24 +1,114 @@ from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple import torch -from typing import Optional, Tuple, List, Dict, Any, Callable +from torch.sparse._triton_ops import broadcast_batch_dims, bsr_dense_addmm, bsr_dense_mm from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TorchAOBaseTensor from torchao.quantization.quant_api import _get_linear_subclass_inserter +from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten + +# quantization support +@torch.library.custom_op("blocksparse::bsr_to_dense", mutates_args=()) +def bsr_to_dense( + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + M: int, + K: int, +) -> torch.Tensor: + return torch.sparse_bsr_tensor( + crow_indices=crow_indices, col_indices=col_indices, values=values, size=(M, K) + ).to_dense() + + +@torch.library.register_fake("blocksparse::bsr_to_dense") +def bsr_to_dense_abstract( + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + M: int, + K: int, +) -> torch.Tensor: + return torch.empty((M, K), dtype=values.dtype, device=values.device) + + +@torch.library.custom_op("blocksparse::int_addmm", mutates_args=()) +def blocksparse_int_addmm( + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + A: torch.Tensor, + left_alpha: torch.Tensor, + right_alpha: torch.Tensor, +) -> torch.Tensor: + assert values.dtype == torch.int8 + M = left_alpha.shape[-1] + K = A.shape[-2] + N = A.shape[-1] + weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) + original_batch_dims_broadcasted = broadcast_batch_dims( + blocksparse_int_addmm, weight_bsr, A + ) + out = A.new_empty(original_batch_dims_broadcasted + (M, N), dtype=torch.bfloat16) + return bsr_dense_addmm( + out, + weight_bsr, + A, + alpha=1, + beta=0, + out=out, + left_alpha=left_alpha, + right_alpha=right_alpha, + ).t() + + +@torch.library.register_fake("blocksparse::int_addmm") +def blocksparse_int_addmm_abstract( + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + A: torch.Tensor, + left_alpha: torch.Tensor, + right_alpha: torch.Tensor, +) -> torch.Tensor: + N = A.shape[-1] + M = left_alpha.shape[-1] + # to have the same strides as the transposed result + return torch.empty((M, N), dtype=torch.bfloat16, device=A.device).t() + + # bsr wrapper custom op @torch.library.custom_op("blocksparse::linear", mutates_args=()) -def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor: +def blocksparse_linear( + A: torch.Tensor, + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + M: int, + K: int, + bias: torch.Tensor, +) -> 
torch.Tensor: weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) return torch.nn.functional.linear(A, weight_bsr, bias) + @torch.library.register_fake("blocksparse::linear") -def blocksparse_linear_abstract(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K:int , bias: torch.Tensor) -> torch.Tensor: - new_shape = A.shape[:-1] + (bias.shape[0],) +def blocksparse_linear_abstract( + A: torch.Tensor, + crow_indices: torch.Tensor, + col_indices: torch.Tensor, + values: torch.Tensor, + M: int, + K: int, + bias: torch.Tensor, +) -> torch.Tensor: + new_shape = A.shape[:-1] + (M,) return torch.empty(new_shape, dtype=A.dtype, device=A.device) + # Subclass definition class BlockSparseTensor(TorchAOBaseTensor): bsr_crow_indices: Optional[torch.Tensor] @@ -37,7 +127,9 @@ def __new__( # noqa: PYI034 requires_grad: bool = False, ): if bsr_values is None: - raise ValueError("bsr values must be provided!") + raise ValueError( + "No values passed to BlockSparseTensor: bsr_values must be provided!" + ) else: previous_tensor = bsr_values @@ -72,7 +164,7 @@ def __tensor_unflatten__( outer_size, outer_stride, ) -> torch.Tensor: - shape, requires_grad = tensor_meta + shape, requires_grad = tensor_meta return cls( shape=shape, bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), @@ -94,44 +186,54 @@ def from_dense(cls, dense_tensor, blocksize): def apply_fn_to_shard(self, func): return BlockSparseTensor( - shape = self.shape, + shape=self.shape, bsr_crow_indices=func(self.bsr_crow_indices), bsr_col_indices=func(self.bsr_col_indices), bsr_values=func(self.bsr_values), requires_grad=self.requires_grad, ) + # Subclass op dispatch registration implements = BlockSparseTensor.implements + @implements(aten.detach.default) def block_sparse_detach(func, types, args, kwargs): - return return_and_correct_aliasing(func, args, kwargs, args[0].apply_fn_to_shard(torch.detach)) + return return_and_correct_aliasing( + func, args, kwargs, args[0].apply_fn_to_shard(torch.detach) + ) + @implements(aten.values.default) def block_sparse_values(func, types, args, kwargs): return args[0].bsr_values.detach() + @implements(aten.crow_indices.default) def block_sparse_crow_indices(func, types, args, kwargs): return args[0].bsr_crow_indices.detach() + @implements(aten.col_indices.default) def block_sparse_col_indices(func, types, args, kwargs): return args[0].bsr_col_indices.detach() + @implements(aten._nnz.default) def block_sparse__nnz(func, types, args, kwargs): return args[0].bsr_values.shape[0] + @implements(torch.nn.functional.linear) def block_sparse_linear(func, types, args, kwargs): x, w, bias = args - return torch.ops.blocksparse.linear(x, - w.crow_indices(), - w.col_indices(), - w.values(), - w.shape[0], w.shape[1], bias) + return torch.ops.blocksparse.linear( + x, w.crow_indices(), w.col_indices(), w.values(), w.shape[0], w.shape[1], bias + ) + def block_sparse_weight(blocksize=64): - return _get_linear_subclass_inserter(partial(BlockSparseTensor.from_dense, blocksize=blocksize)) + return _get_linear_subclass_inserter( + partial(BlockSparseTensor.from_dense, blocksize=blocksize) + ) diff --git a/torchao/sparsity/prototype/superblock/evaluate.py b/torchao/sparsity/prototype/superblock/evaluate.py index 09f34ebb64..5db9fc9e38 100644 --- a/torchao/sparsity/prototype/superblock/evaluate.py +++ b/torchao/sparsity/prototype/superblock/evaluate.py @@ -1,29 +1,23 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
- import os -import sys -import warnings -import hashlib -from functools import partial - -import presets import torch -import torch.utils.data import torchvision -import utils -from torch import nn -from torchvision.transforms.functional import InterpolationMode -from torchao.sparsity import sparsify_, semi_sparse_weight -from torchao.sparsity.prototype.superblock.supermask import apply_supermask -from torchao.sparsity.prototype.superblock.utils import apply_sparsity, verify_sparsity, mlp_only_with_args, simulate_sparsity, accelerate_with_sparsity -from torchao.sparsity.prototype.superblock.train import evaluate, _get_cache_path, load_data -from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import WeightNormSparsifier +from torchao.sparsity.prototype.superblock.train import evaluate, load_data +from torchao.sparsity.prototype.superblock.utils import ( + accelerate_with_sparsity, + apply_sparsity, + get_args_parser, + init_distributed_mode, + simulate_sparsity, +) torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False +torch.backends.mha.set_fastpath_enabled(False) + def main(args): - utils.init_distributed_mode(args) + init_distributed_mode(args) print(args) device = torch.device(args.device) @@ -35,13 +29,20 @@ def main(args): val_dir = os.path.join(args.data_path, "val") dataset_test, test_sampler = load_data(None, val_dir, args) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True + dataset_test, + batch_size=args.batch_size, + sampler=test_sampler, + num_workers=args.workers, + pin_memory=True, + drop_last=True, ) num_classes = len(dataset_test.classes) # Create Model print("Creating model") - model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) + model = torchvision.models.get_model( + args.model, weights=args.weights, num_classes=num_classes + ) sparsifier_or_none = simulate_sparsity(model, args) @@ -58,62 +59,44 @@ def main(args): if sparsifier_or_none is not None: sparsifier_or_none.squash_mask() accelerate_with_sparsity(model, args) - - criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) - evaluate(model, criterion, data_loader_test, device=device, dtype=torch.bfloat16) - - -def get_args_parser(add_help=True): - import argparse - - parser = argparse.ArgumentParser(description="Superblock evaluation", add_help=add_help) - parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417", type=str, help="dataset path") - parser.add_argument("--model", default="vit-", type=str, help="model name") - parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") - parser.add_argument( - "-b", "--batch-size", default=256, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" - ) - parser.add_argument( - "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" - ) - parser.add_argument( - "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" - ) - parser.add_argument("--print-freq", default=10, type=int, help="print frequency") - parser.add_argument( - "--cache-dataset", - dest="cache_dataset", - help="Cache the datasets for quicker initialization. 
It also serializes the transforms", - action="store_true", - ) - parser.add_argument( - "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" - ) - parser.add_argument( - "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)" - ) - parser.add_argument( - "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" - ) - parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") - parser.add_argument("--weights-path", type=str, help="path of pretrained weights to load") + model = torch.compile(model, mode="max-autotune", fullgraph=True) - # NOTE: sparsity args - parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply') - parser.add_argument("--sparsity-linear", type=float, default=0.0) - parser.add_argument("--sp-linear-tile-size", type=int, default=1) - parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) - parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1) - parser.add_argument("--sparsity-conv", type=float, default=0.0) - parser.add_argument("--sp-conv-tile-size", type=int, default=1) - parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") - parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)") - parser.add_argument('--bsr', type=int, nargs='?', default=64, help='Convert sparsified weights to BSR format with optional block size (default: 64)') - parser.add_argument('--meta', action='store_true', help='Use Meta internal imagenet structure') - - return parser + criterion = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) + return evaluate(model, criterion, data_loader_test, device=device, dtype=torch.bfloat16) if __name__ == "__main__": - args = get_args_parser().parse_args() - main(args) + args = get_args_parser(evaluate=True).parse_args() + accuracy, throughput, max_mem = main(args) + header = [ + "model", + "batch_size", + "dtype", + "sparsity", + "bsr", + "sparsity_level", + "quantization", + "top-1_acc", + "encoder img/s", + "max_mem (MB)", + ] + result_string = ",".join( + str(_) + for _ in [ + args.model, + args.batch_size, + "bfloat16", + args.sparsity, + args.bsr, + args.sparsity_linear, + args.quantization, + accuracy, + throughput, + max_mem + ] + ) + with open("evaluation_results.txt", "a") as f: + if args.header: + f.write(",".join(header) + "\n") + f.write(result_string + "\n") + print(result_string) diff --git a/torchao/sparsity/prototype/superblock/evaluate.sh b/torchao/sparsity/prototype/superblock/evaluate.sh new file mode 100644 index 0000000000..68be5175fd --- /dev/null +++ b/torchao/sparsity/prototype/superblock/evaluate.sh @@ -0,0 +1,23 @@ +MODEL=vit_b_16 +BATCH_SIZE=256 + +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_B_16_Weights.IMAGENET1K_V1 --header +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_B_16_Weights.IMAGENET1K_V1 --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_B_16_Weights.IMAGENET1K_V1 --sparsity semi_structured +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights 
ViT_B_16_Weights.IMAGENET1K_V1 --sparsity semi_structured --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.80 --bsr 64 --weights-path checkpoints/$MODEL/sp0.80-ts64.pth +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.80 --bsr 64 --weights-path checkpoints/$MODEL/sp0.80-ts64.pth --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.84 --bsr 64 --weights-path checkpoints/$MODEL/sp0.84-ts64.pth +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.84 --bsr 64 --weights-path checkpoints/$MODEL/sp0.84-ts64.pth --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.90 --bsr 64 --weights-path checkpoints/$MODEL/sp0.90-ts64.pth +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.90 --bsr 64 --weights-path checkpoints/$MODEL/sp0.90-ts64.pth --quantization + +MODEL=vit_h_14 +BATCH_SIZE=128 + +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 --header +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 --sparsity semi_structured +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --weights ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 --sparsity semi_structured --quantization +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.90 --bsr 64 --weights-path checkpoints/$MODEL/sp0.90-ts64.pth +python evaluate.py --model $MODEL --batch-size $BATCH_SIZE --data-path $IMAGENET_PATH --sparsity bsr --sparsity-linear 0.90 --bsr 64 --weights-path checkpoints/$MODEL/sp0.90-ts64.pth --quantization diff --git a/torchao/sparsity/prototype/superblock/evaluation_results.txt b/torchao/sparsity/prototype/superblock/evaluation_results.txt new file mode 100644 index 0000000000..58dcade663 --- /dev/null +++ b/torchao/sparsity/prototype/superblock/evaluation_results.txt @@ -0,0 +1,19 @@ +model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,top-1_acc,encoder img/s,max_mem (MB) +vit_b_16,256,bfloat16,None,None,0.0,False,81.97716346153847,734.904399886552,247.97265625 +vit_b_16,256,bfloat16,None,None,0.0,True,81.89503205128206,230.83627917226997,196.841796875 +vit_b_16,256,bfloat16,semi_structured,None,0.0,False,77.05729166666667,1386.7278781133518,316.40234375 +vit_b_16,256,bfloat16,semi_structured,None,0.0,True,76.74078525641026,150.53603093207843,249.25390625 +vit_b_16,256,bfloat16,bsr,64,0.8,False,77.13541666666667,1469.2705176409308,179.55322265625 +vit_b_16,256,bfloat16,bsr,64,0.8,True,77.13341346153847,87.8480561274922,158.70361328125 +vit_b_16,256,bfloat16,bsr,64,0.84,False,76.14983974358974,1752.835540513905,174.01953125 +vit_b_16,256,bfloat16,bsr,64,0.84,True,76.0556891025641,1013.7495284783578,156.630859375 +vit_b_16,256,bfloat16,bsr,64,0.9,False,62.99879807692308,1702.289195236525,164.2822265625 
+vit_b_16,256,bfloat16,bsr,64,0.9,True,62.946714743589745,987.5488468441617,152.5732421875 + +model,batch_size,dtype,sparsity,bsr,sparsity_level,quantization,top-1_acc,encoder img/s,max_mem (MB) +vit_h_14,128,bfloat16,None,None,0.0,False,89.29286858974359,81.02922135697278,1430.05615234375 +vit_h_14,128,bfloat16,None,None,0.0,True,89.3349358974359,56.076129157634355,1025.00927734375 +vit_h_14,128,bfloat16,semi_structured,None,0.0,False,82.03725961538461,75.83586253901329,1900.36279296875 +vit_h_14,128,bfloat16,semi_structured,None,0.0,True,82.06330128205128,36.36097831133589,1390.98779296875 +vit_h_14,128,bfloat16,bsr,64,0.9,False,78.21113782051282,350.91330496491446,599.6201171875 +vit_h_14,128,bfloat16,bsr,64,0.9,True,78.2051282051282,108.84048044884008,531.5810546875 diff --git a/torchao/sparsity/prototype/superblock/presets.py b/torchao/sparsity/prototype/superblock/presets.py deleted file mode 100644 index c5a242c549..0000000000 --- a/torchao/sparsity/prototype/superblock/presets.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. - -import torch -from torchvision.transforms import autoaugment, transforms -from torchvision.transforms.functional import InterpolationMode - - -class ClassificationPresetTrain: - def __init__( - self, - *, - crop_size, - mean=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225), - interpolation=InterpolationMode.BILINEAR, - hflip_prob=0.5, - auto_augment_policy=None, - ra_magnitude=9, - augmix_severity=3, - random_erase_prob=0.0, - ): - trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] - if hflip_prob > 0: - trans.append(transforms.RandomHorizontalFlip(hflip_prob)) - if auto_augment_policy is not None: - if auto_augment_policy == "ra": - trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) - elif auto_augment_policy == "ta_wide": - trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) - elif auto_augment_policy == "augmix": - trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity)) - else: - aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) - trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) - trans.extend( - [ - transforms.PILToTensor(), - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), - ] - ) - if random_erase_prob > 0: - trans.append(transforms.RandomErasing(p=random_erase_prob)) - - self.transforms = transforms.Compose(trans) - - def __call__(self, img): - return self.transforms(img) - - -class ClassificationPresetEval: - def __init__( - self, - *, - crop_size, - resize_size=256, - mean=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225), - interpolation=InterpolationMode.BILINEAR, - ): - - self.transforms = transforms.Compose( - [ - transforms.Resize(resize_size, interpolation=interpolation), - transforms.CenterCrop(crop_size), - transforms.PILToTensor(), - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), - ] - ) - - def __call__(self, img): - return self.transforms(img) diff --git a/torchao/sparsity/prototype/superblock/sampler.py b/torchao/sparsity/prototype/superblock/sampler.py deleted file mode 100644 index bf36a17954..0000000000 --- a/torchao/sparsity/prototype/superblock/sampler.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
- -import math - -import torch -import torch.distributed as dist - - -class RASampler(torch.utils.data.Sampler): - """Sampler that restricts data loading to a subset of the dataset for distributed, - with repeated augmentation. - It ensures that different each augmented version of a sample will be visible to a - different process (GPU). - Heavily based on 'torch.utils.data.DistributedSampler'. - - This is borrowed from the DeiT Repo: - https://github.com/facebookresearch/deit/blob/main/samplers.py - """ - - def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3): - if num_replicas is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available!") - num_replicas = dist.get_world_size() - if rank is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available!") - rank = dist.get_rank() - self.dataset = dataset - self.num_replicas = num_replicas - self.rank = rank - self.epoch = 0 - self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas)) - self.total_size = self.num_samples * self.num_replicas - self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) - self.shuffle = shuffle - self.seed = seed - self.repetitions = repetitions - - def __iter__(self): - if self.shuffle: - # Deterministically shuffle based on epoch - g = torch.Generator() - g.manual_seed(self.seed + self.epoch) - indices = torch.randperm(len(self.dataset), generator=g).tolist() - else: - indices = list(range(len(self.dataset))) - - # Add extra samples to make it evenly divisible - indices = [ele for ele in indices for i in range(self.repetitions)] - indices += indices[: (self.total_size - len(indices))] - assert len(indices) == self.total_size - - # Subsample - indices = indices[self.rank : self.total_size : self.num_replicas] - assert len(indices) == self.num_samples - - return iter(indices[: self.num_selected_samples]) - - def __len__(self): - return self.num_selected_samples - - def set_epoch(self, epoch): - self.epoch = epoch diff --git a/torchao/sparsity/prototype/superblock/train.py b/torchao/sparsity/prototype/superblock/train.py index 7fd1ce4d20..acfed09bc6 100644 --- a/torchao/sparsity/prototype/superblock/train.py +++ b/torchao/sparsity/prototype/superblock/train.py @@ -1,26 +1,35 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
 import datetime
-import os
 import glob
+import os
 import sys
 import time
 import warnings
 
-import presets
 import torch
 import torch.utils.data
 import torchvision
-import transforms
 import utils
-from sampler import RASampler
 from torch import nn
 from torch.utils.data.dataloader import default_collate
-from torchvision.transforms.functional import InterpolationMode
 
 from torchao.sparsity.prototype.superblock.utils import simulate_sparsity
-
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
+from torchvision.transforms.functional import InterpolationMode
+from utils import RASampler
+
+
+def train_one_epoch(
+    model,
+    criterion,
+    optimizer,
+    data_loader,
+    device,
+    epoch,
+    args,
+    model_ema=None,
+    scaler=None,
+):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -29,7 +38,9 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     header = f"Epoch: [{epoch}]"
 
     accumulation_counter = 0  # Counter for tracking accumulated gradients
-    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
+    for i, (image, target) in enumerate(
+        metric_logger.log_every(data_loader, args.print_freq, header)
+    ):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
 
@@ -65,23 +76,48 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
             acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
             batch_size = image.shape[0]
-            metric_logger.update(loss=loss.item() * args.accumulation_steps, lr=optimizer.param_groups[0]["lr"])  # Scale back up for logging
+            metric_logger.update(
+                loss=loss.item() * args.accumulation_steps,
+                lr=optimizer.param_groups[0]["lr"],
+            )  # Scale back up for logging
             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
             metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
-
-def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="", dtype=torch.float32):
+
+
+def evaluate(
+    model,
+    criterion,
+    data_loader,
+    device,
+    print_freq=100,
+    log_suffix="",
+    dtype=torch.float32,
+):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = f"Test: {log_suffix}"
-
+    encoder_time = 0
     num_processed_samples = 0
     with torch.inference_mode():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True).to(dtype)
             target = target.to(device, non_blocking=True).to(dtype)
+            # initialize encoder measurements
+            torch.cuda.reset_max_memory_allocated()
+            torch.cuda.synchronize()
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+
+            # run encoder
             output = model(image)
-            # loss = criterion(output, target)
+
+            # measure time in encoder
+            end_event.record()
+            torch.cuda.synchronize()
+            encoder_time += start_event.elapsed_time(end_event)
+            max_mem = torch.cuda.max_memory_allocated() / (1024**2)
 
             acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
             # FIXME need to take into account that the datasets
@@ -90,6 +126,7 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="
             # metric_logger.update(loss=loss.item())
             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+
metric_logger.meters["batch_time"].update(encoder_time, n=batch_size) num_processed_samples += batch_size # gather the stats from all processes @@ -97,7 +134,6 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=" if ( hasattr(data_loader.dataset, "__len__") and len(data_loader.dataset) != num_processed_samples - and torch.distributed.get_rank() == 0 ): # See FIXME above warnings.warn( @@ -109,21 +145,31 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=" metric_logger.synchronize_between_processes() - print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}") - return metric_logger.acc1.global_avg + print( + f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}" + ) + total_time = encoder_time / 1000.0 + return metric_logger.acc1.global_avg, num_processed_samples.item() / total_time, max_mem + def _get_cache_path(filepath): import hashlib h = hashlib.sha1(filepath.encode()).hexdigest() - cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt") + cache_path = os.path.join( + "~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt" + ) cache_path = os.path.expanduser(cache_path) return cache_path + def load_data(traindir, valdir, args): # Data loading code print("Loading data") - val_resize_size, val_crop_size, = ( + ( + val_resize_size, + val_crop_size, + ) = ( args.val_resize_size, args.val_crop_size, ) @@ -142,7 +188,7 @@ def load_data(traindir, valdir, args): random_erase_prob = getattr(args, "random_erase", 0.0) ra_magnitude = args.ra_magnitude augmix_severity = args.augmix_severity - preprocessing = presets.ClassificationPresetTrain( + preprocessing = utils.ClassificationPresetTrain( crop_size=train_crop_size, interpolation=interpolation, auto_augment_policy=auto_augment_policy, @@ -150,9 +196,7 @@ def load_data(traindir, valdir, args): ra_magnitude=ra_magnitude, augmix_severity=augmix_severity, ) - dataset = torchvision.datasets.ImageFolder( - traindir, - preprocessing) + dataset = torchvision.datasets.ImageFolder(traindir, preprocessing) # ) if args.meta else torchvision.datasets.ImageNet( # traindir, # split="train", @@ -166,7 +210,9 @@ def load_data(traindir, valdir, args): print(f"Number of training images: {len(dataset)}") if args.distributed: if hasattr(args, "ra_sampler") and args.ra_sampler: - train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps) + train_sampler = RASampler( + dataset, shuffle=True, repetitions=args.ra_reps + ) else: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) else: @@ -177,30 +223,38 @@ def load_data(traindir, valdir, args): if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! 
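Stepping back from the data-loading hunk for a moment: the evaluate() changes above time the encoder with CUDA events rather than wall-clock time, which is the right pattern on GPU because kernel launches are asynchronous and the host must synchronize before reading the timer or the peak-memory counter. A minimal, self-contained sketch of that same measurement pattern (the model and input here are placeholders, not part of this patch):

import torch

def timed_forward(model, x):
    # The peak-memory counter is cumulative, so reset it before the region of interest.
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = model(x)
    end.record()
    torch.cuda.synchronize()  # elapsed_time() is only valid once both events have completed
    ms = start.elapsed_time(end)  # GPU time in milliseconds
    peak_mb = torch.cuda.max_memory_allocated() / (1024**2)
    return out, ms, peak_mb

Note that because the reset happens inside the batch loop in the patch, the max_mem that evaluate() returns reflects the last batch processed rather than the whole evaluation run.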
print(f"Loading dataset_test from {cache_path}") - dataset_test, _ = torch.load(cache_path) + dataset_test, test_sampler = torch.load(cache_path) else: if args.weights: weights = torchvision.models.get_weight(args.weights) preprocessing = weights.transforms() else: - preprocessing = presets.ClassificationPresetEval( - crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation + preprocessing = utils.ClassificationPresetEval( + crop_size=val_crop_size, + resize_size=val_resize_size, + interpolation=interpolation, + ) + dataset_test = ( + torchvision.datasets.ImageFolder( + valdir, + preprocessing, + ) + if args.meta + else torchvision.datasets.ImageNet( + valdir, split="val", transform=preprocessing ) - dataset_test = torchvision.datasets.ImageFolder( - valdir, - preprocessing, - ) if args.meta else torchvision.datasets.ImageNet( - valdir, - split='val', - transform=preprocessing ) if args.cache_dataset: print(f"Saving dataset_test to {cache_path}") utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) - + print(f"Number of validation images: {len(dataset_test)}") - test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) if args.distributed else torch.utils.data.SequentialSampler(dataset_test) + test_sampler = ( + torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) + if args.distributed + else torch.utils.data.SequentialSampler(dataset_test) + ) # for evaluation if traindir is None: @@ -208,6 +262,7 @@ def load_data(traindir, valdir, args): return dataset, dataset_test, train_sampler, test_sampler + def main(args): if args.output_dir: utils.mkdir(args.output_dir) @@ -225,15 +280,21 @@ def main(args): train_dir = os.path.join(args.data_path, "train_blurred") val_dir = os.path.join(args.data_path, "val") - dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) + dataset, dataset_test, train_sampler, test_sampler = load_data( + train_dir, val_dir, args + ) collate_fn = None num_classes = len(dataset.classes) mixup_transforms = [] if args.mixup_alpha > 0.0: - mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) + mixup_transforms.append( + utils.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha) + ) if args.cutmix_alpha > 0.0: - mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) + mixup_transforms.append( + utils.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha) + ) if mixup_transforms: mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) @@ -249,11 +310,17 @@ def collate_fn(batch): collate_fn=collate_fn, ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True + dataset_test, + batch_size=args.batch_size, + sampler=test_sampler, + num_workers=args.workers, + pin_memory=True, ) print("Creating model") - model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) + model = torchvision.models.get_model( + args.model, weights=args.weights, num_classes=num_classes + ) if args.weights_path is not None: sd = torch.load(args.weights_path, map_location="cpu") @@ -262,7 +329,7 @@ def collate_fn(batch): model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - + sparsifier = simulate_sparsity(model, args) criterion = 
nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) @@ -270,13 +337,19 @@ def collate_fn(batch): if args.bias_weight_decay is not None: custom_keys_weight_decay.append(("bias", args.bias_weight_decay)) if args.transformer_embedding_decay is not None: - for key in ["class_token", "position_embedding", "relative_position_bias_table"]: + for key in [ + "class_token", + "position_embedding", + "relative_position_bias_table", + ]: custom_keys_weight_decay.append((key, args.transformer_embedding_decay)) parameters = utils.set_weight_decay( model, args.weight_decay, norm_weight_decay=args.norm_weight_decay, - custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None, + custom_keys_weight_decay=( + custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None + ), ) opt_name = args.opt.lower() @@ -290,24 +363,37 @@ def collate_fn(batch): ) elif opt_name == "rmsprop": optimizer = torch.optim.RMSprop( - parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9 + parameters, + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + eps=0.0316, + alpha=0.9, ) elif opt_name == "adamw": - optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay) + optimizer = torch.optim.AdamW( + parameters, lr=args.lr, weight_decay=args.weight_decay + ) else: - raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.") + raise RuntimeError( + f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported." + ) scaler = torch.cuda.amp.GradScaler() if args.amp else None args.lr_scheduler = args.lr_scheduler.lower() if args.lr_scheduler == "steplr": - main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) + main_lr_scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma + ) elif args.lr_scheduler == "cosineannealinglr": main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min ) elif args.lr_scheduler == "exponentiallr": - main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) + main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer, gamma=args.lr_gamma + ) else: raise RuntimeError( f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR " @@ -317,25 +403,33 @@ def collate_fn(batch): if args.lr_warmup_epochs > 0: if args.lr_warmup_method == "linear": warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR( - optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs + optimizer, + start_factor=args.lr_warmup_decay, + total_iters=args.lr_warmup_epochs, ) elif args.lr_warmup_method == "constant": warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR( - optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs + optimizer, + factor=args.lr_warmup_decay, + total_iters=args.lr_warmup_epochs, ) else: raise RuntimeError( f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported." 
) lr_scheduler = torch.optim.lr_scheduler.SequentialLR( - optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs] + optimizer, + schedulers=[warmup_lr_scheduler, main_lr_scheduler], + milestones=[args.lr_warmup_epochs], ) else: lr_scheduler = main_lr_scheduler model_without_ddp = model if args.distributed: - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], find_unused_parameters=True + ) model_without_ddp = model.module model_ema = None @@ -349,16 +443,20 @@ def collate_fn(batch): adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs alpha = 1.0 - args.model_ema_decay alpha = min(1.0, alpha * adjust) - model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha) + model_ema = utils.ExponentialMovingAverage( + model_without_ddp, device=device, decay=1.0 - alpha + ) - #TODO: need to test resume functionality + # TODO: need to test resume functionality if args.resume: checkpoint_pattern = os.path.join(args.output_dir, "model_*.pth") checkpoint_files = glob.glob(checkpoint_pattern) - epochs = [int(f.split('_')[-1].split('.')[0]) for f in checkpoint_files] + epochs = [int(f.split("_")[-1].split(".")[0]) for f in checkpoint_files] if epochs: latest_epoch = max(epochs) - latest_checkpoint = os.path.join(args.output_dir, f"model_{latest_epoch}.pth") + latest_checkpoint = os.path.join( + args.output_dir, f"model_{latest_epoch}.pth" + ) try: checkpoint = torch.load(latest_checkpoint, map_location="cpu") model_without_ddp.load_state_dict(checkpoint["model"]) @@ -371,7 +469,9 @@ def collate_fn(batch): scaler.load_state_dict(checkpoint["scaler"]) print(f"Resumed training from epoch {args.start_epoch}.") except FileNotFoundError: - print(f"No checkpoint found at {latest_checkpoint}. Starting training from scratch.") + print( + f"No checkpoint found at {latest_checkpoint}. Starting training from scratch." + ) args.start_epoch = 0 else: print("No checkpoint found. 
Starting training from scratch.") @@ -380,7 +480,9 @@ def collate_fn(batch): args.start_epoch = 0 print("Zero-shot evaluation") if model_ema: - evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA") + evaluate( + model_ema, criterion, data_loader_test, device=device, log_suffix="EMA" + ) else: evaluate(model, criterion, data_loader_test, device=device) @@ -389,11 +491,23 @@ def collate_fn(batch): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler) + train_one_epoch( + model, + criterion, + optimizer, + data_loader, + device, + epoch, + args, + model_ema, + scaler, + ) lr_scheduler.step() evaluate(model, criterion, data_loader_test, device=device) if model_ema: - evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA") + evaluate( + model_ema, criterion, data_loader_test, device=device, log_suffix="EMA" + ) if args.output_dir: checkpoint = { "model": model_without_ddp.state_dict(), @@ -408,152 +522,18 @@ def collate_fn(batch): checkpoint["model_ema"] = model_ema.state_dict() if scaler: checkpoint["scaler"] = scaler.state_dict() - utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth")) - utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) + utils.save_on_master( + checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth") + ) + utils.save_on_master( + checkpoint, os.path.join(args.output_dir, "checkpoint.pth") + ) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print(f"Training time {total_time_str}") -def get_args_parser(add_help=True): - import argparse - - parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) - parser.add_argument("--data-path", type=str, help="dataset path") - parser.add_argument("--model", default="resnet18", type=str, help="model name") - parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") - parser.add_argument( - "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" - ) - parser.add_argument("--accumulation-steps", default=1, type=int, help="Number of steps to accumulate gradients over") - parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") - parser.add_argument( - "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" - ) - parser.add_argument("--opt", default="sgd", type=str, help="optimizer") - parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate") - parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") - parser.add_argument( - "--wd", - "--weight-decay", - default=1e-4, - type=float, - metavar="W", - help="weight decay (default: 1e-4)", - dest="weight_decay", - ) - parser.add_argument( - "--norm-weight-decay", - default=None, - type=float, - help="weight decay for Normalization layers (default: None, same value as --wd)", - ) - parser.add_argument( - "--bias-weight-decay", - default=None, - type=float, - help="weight decay for bias parameters of all layers (default: None, same value as --wd)", - ) - parser.add_argument( - "--transformer-embedding-decay", - default=None, - type=float, - help="weight decay for embedding parameters 
for vision transformer models (default: None, same value as --wd)", - ) - parser.add_argument( - "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" - ) - parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)") - parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)") - parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)") - parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)") - parser.add_argument( - "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)" - ) - parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") - parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") - parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") - parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)") - parser.add_argument("--print-freq", default=10, type=int, help="print frequency") - parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") - parser.add_argument('--resume', action='store_true', help='Resumes training from latest available checkpoint ("model_.pth")') - parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") - parser.add_argument( - "--cache-dataset", - dest="cache_dataset", - help="Cache the datasets for quicker initialization. It also serializes the transforms", - action="store_true", - ) - parser.add_argument( - "--sync-bn", - dest="sync_bn", - help="Use sync batch norm", - action="store_true", - ) - parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") - parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy") - parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy") - parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") - - # Mixed precision training parameters - parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") - - # distributed training parameters - parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") - parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - parser.add_argument( - "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters" - ) - parser.add_argument( - "--model-ema-steps", - type=int, - default=32, - help="the number of iterations that controls how often to update the EMA model (default: 32)", - ) - parser.add_argument( - "--model-ema-decay", - type=float, - default=0.99998, - help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)", - ) - parser.add_argument( - "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." 
- ) - parser.add_argument( - "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" - ) - parser.add_argument( - "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)" - ) - parser.add_argument( - "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" - ) - parser.add_argument( - "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" - ) - parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") - parser.add_argument( - "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" - ) - parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") - parser.add_argument("--weights-path", type=str) - - # NOTE: sparsity args - parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply') - parser.add_argument("--sparsity-linear", type=float, default=0.0) - parser.add_argument("--sp-linear-tile-size", type=int, default=1) - parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) - parser.add_argument("--sp-conv1x1-tile-size", type=int, default=1) - parser.add_argument("--sparsity-conv", type=float, default=0.0) - parser.add_argument("--sp-conv-tile-size", type=int, default=1) - parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") - parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)") - parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') - parser.add_argument('--meta', action='store_true', help='Use Meta internal imagenet structure') - return parser - - if __name__ == "__main__": - args = get_args_parser().parse_args() + args = utils.get_args_parser(train=True).parse_args() main(args) diff --git a/torchao/sparsity/prototype/superblock/transforms.py b/torchao/sparsity/prototype/superblock/transforms.py deleted file mode 100644 index 2375e3fc41..0000000000 --- a/torchao/sparsity/prototype/superblock/transforms.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. - -import math -from typing import Tuple - -import torch -from torch import Tensor -from torchvision.transforms import functional as F - - -class RandomMixup(torch.nn.Module): - """Randomly apply Mixup to the provided batch and targets. - The class implements the data augmentations as described in the paper - `"mixup: Beyond Empirical Risk Minimization" `_. - - Args: - num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 0.5. - alpha (float): hyperparameter of the Beta distribution used for mixup. - Default value is 1.0. - inplace (bool): boolean to make this transform inplace. Default set to False. - """ - - def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: - super().__init__() - - if num_classes < 1: - raise ValueError( - f"Please provide a valid positive value for the num_classes. 
Got num_classes={num_classes}" - ) - - if alpha <= 0: - raise ValueError("Alpha param can't be zero.") - - self.num_classes = num_classes - self.p = p - self.alpha = alpha - self.inplace = inplace - - def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """ - Args: - batch (Tensor): Float tensor of size (B, C, H, W) - target (Tensor): Integer tensor of size (B, ) - - Returns: - Tensor: Randomly transformed batch. - """ - if batch.ndim != 4: - raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") - if target.ndim != 1: - raise ValueError(f"Target ndim should be 1. Got {target.ndim}") - if not batch.is_floating_point(): - raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") - if target.dtype != torch.int64: - raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") - - if not self.inplace: - batch = batch.clone() - target = target.clone() - - if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) - - if torch.rand(1).item() >= self.p: - return batch, target - - # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1, 0) - - # Implemented as on mixup paper, page 3. - lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) - batch_rolled.mul_(1.0 - lambda_param) - batch.mul_(lambda_param).add_(batch_rolled) - - target_rolled.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_rolled) - - return batch, target - - def __repr__(self) -> str: - s = ( - f"{self.__class__.__name__}(" - f"num_classes={self.num_classes}" - f", p={self.p}" - f", alpha={self.alpha}" - f", inplace={self.inplace}" - f")" - ) - return s - - -class RandomCutmix(torch.nn.Module): - """Randomly apply Cutmix to the provided batch and targets. - The class implements the data augmentations as described in the paper - `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" - `_. - - Args: - num_classes (int): number of classes used for one-hot encoding. - p (float): probability of the batch being transformed. Default value is 0.5. - alpha (float): hyperparameter of the Beta distribution used for cutmix. - Default value is 1.0. - inplace (bool): boolean to make this transform inplace. Default set to False. - """ - - def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: - super().__init__() - if num_classes < 1: - raise ValueError("Please provide a valid positive value for the num_classes.") - if alpha <= 0: - raise ValueError("Alpha param can't be zero.") - - self.num_classes = num_classes - self.p = p - self.alpha = alpha - self.inplace = inplace - - def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """ - Args: - batch (Tensor): Float tensor of size (B, C, H, W) - target (Tensor): Integer tensor of size (B, ) - - Returns: - Tensor: Randomly transformed batch. - """ - if batch.ndim != 4: - raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") - if target.ndim != 1: - raise ValueError(f"Target ndim should be 1. Got {target.ndim}") - if not batch.is_floating_point(): - raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") - if target.dtype != torch.int64: - raise TypeError(f"Target dtype should be torch.int64. 
Got {target.dtype}") - - if not self.inplace: - batch = batch.clone() - target = target.clone() - - if target.ndim == 1: - target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) - - if torch.rand(1).item() >= self.p: - return batch, target - - # It's faster to roll the batch by one instead of shuffling it to create image pairs - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1, 0) - - # Implemented as on cutmix paper, page 12 (with minor corrections on typos). - lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) - _, H, W = F.get_dimensions(batch) - - r_x = torch.randint(W, (1,)) - r_y = torch.randint(H, (1,)) - - r = 0.5 * math.sqrt(1.0 - lambda_param) - r_w_half = int(r * W) - r_h_half = int(r * H) - - x1 = int(torch.clamp(r_x - r_w_half, min=0)) - y1 = int(torch.clamp(r_y - r_h_half, min=0)) - x2 = int(torch.clamp(r_x + r_w_half, max=W)) - y2 = int(torch.clamp(r_y + r_h_half, max=H)) - - batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] - lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) - - target_rolled.mul_(1.0 - lambda_param) - target.mul_(lambda_param).add_(target_rolled) - - return batch, target - - def __repr__(self) -> str: - s = ( - f"{self.__class__.__name__}(" - f"num_classes={self.num_classes}" - f", p={self.p}" - f", alpha={self.alpha}" - f", inplace={self.inplace}" - f")" - ) - return s diff --git a/torchao/sparsity/prototype/superblock/utils.py b/torchao/sparsity/prototype/superblock/utils.py index e779613f5c..cf865fd369 100644 --- a/torchao/sparsity/prototype/superblock/utils.py +++ b/torchao/sparsity/prototype/superblock/utils.py @@ -1,91 +1,190 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import argparse import copy import datetime import errno import hashlib +import math import os import time from collections import defaultdict, deque, OrderedDict from typing import List, Optional, Tuple import torch -import torch.distributed as dist -from torchao.quantization import quantize_, int8_dynamic_activation_int8_semi_sparse_weight -from torchao.sparsity import sparsify_, semi_sparse_weight -from torchao.sparsity.prototype.superblock.supermask import SupermaskLinear, apply_supermask +from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_ +from torchao.sparsity import semi_sparse_weight, sparsify_ +from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import ( + WeightNormSparsifier, +) from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight -from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import WeightNormSparsifier +from torchao.sparsity.prototype.superblock.supermask import ( + apply_supermask, + SupermaskLinear, +) +from torchvision.transforms import autoaugment, functional as F, transforms +from torchvision.transforms.functional import InterpolationMode + +def get_args_parser(train=False, evaluate=False, benchmark=False): + assert sum([train, evaluate, benchmark]) == 1, "One and only one of training, evaluation, or benchmark can be true" + + # Shared common args + parser = argparse.ArgumentParser(description="SuperBlock Imagenet Training/Evaluation/Benchmarking Script", add_help=True) + parser.add_argument("--data-path", type=str, help="IMAGENET dataset path") + parser.add_argument("--model", default="vit_b_16", choices=["vit_b_16", "vit_h_14"], type=str, help="ViT base model") + parser.add_argument("--device", default="cuda", type=str, help="device (Default: 
cuda)") + parser.add_argument("-b", "--batch-size", default=32, type=int, help="per device batch size") + parser.add_argument("--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)") + parser.add_argument("--sparsity", choices=["bsr", "semi_structured"], default=None, help='weight sparsification to apply') + parser.add_argument('--bsr', type=int, nargs='?', const=256, default=None, help='Convert sparsified weights to BSR format with optional block size (default: 256)') + parser.add_argument("--sparsity-linear", type=float, default=0.0) + parser.add_argument("--sparsity-conv1x1", type=float, default=0.0) + parser.add_argument("--sparsity-conv", type=float, default=0.0) + parser.add_argument("--skip-last-layer-sparsity", action="store_true", help="Skip applying sparsity to the last linear layer (for vit only)") + parser.add_argument("--skip-first-transformer-sparsity", action="store_true", help="Skip applying sparsity to the first transformer layer (for vit only)") + parser.add_argument("--quantization", action="store_true", help="Run with int8 dynamic quantization") + parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--weights-path", type=str, help="optional checkpoint to load weights after intialization") + parser.add_argument("--header", action="store_true", help="Print header for first run") + + # Eval a subset of training args + # lots of training args + if train or evaluate: + parser.add_argument("-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers") + parser.add_argument("--accumulation-steps", default=1, type=int, help="Number of steps to accumulate gradients over") + parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument("--opt", default="sgd", type=str, help="optimizer") + parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate") + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, metavar="W", help="weight decay", dest="weight_decay") + parser.add_argument("--norm-weight-decay", default=None, type=float, help="weight decay for Normalization layers (default: None, same value as --wd)") + parser.add_argument("--bias-weight-decay", default=None, type=float, help="weight decay for bias parameters of all layers (default: None, same value as --wd)") + parser.add_argument("--transformer-embedding-decay", default=None, type=float, help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)") + parser.add_argument("--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing") + parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)") + parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)") + parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)") + parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)") + parser.add_argument("--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)") + parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") + 
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") + parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)") + parser.add_argument("--print-freq", default=10, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") + parser.add_argument('--resume', action='store_true', help='Resumes training from latest available checkpoint ("model_.pth")') + parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") + parser.add_argument("--cache-dataset", dest="cache_dataset", help="Cache the datasets for quicker initialization. It also serializes the transforms", action="store_true") + parser.add_argument("--sync-bn", dest="sync_bn", help="Use sync batch norm", action="store_true") + parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") + parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy") + parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy") + parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") + # Mixed precision training parameters + parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") + # distributed training parameters + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") + parser.add_argument("--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters") + parser.add_argument("--model-ema-steps", type=int, default=32, help="the number of iterations that controls how often to update the EMA model (default: 32)") + parser.add_argument("--model-ema-decay", type=float, default=0.99998, help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)") + parser.add_argument("--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only.") + parser.add_argument("--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)") + parser.add_argument("--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)") + parser.add_argument("--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)") + parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") + parser.add_argument("--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)") + parser.add_argument('--meta', action='store_true', help='Use Meta internal imagenet structure') + + if benchmark: + parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], help="Data type", default="bfloat16") + parser.add_argument("--tune-kernel-params", action="store_true", help="Tune kernel params for BSR") + parser.add_argument("--profile", action="store_true", help="Dump Prefetto trace") + + return parser + -### Custom sparsification utils -def apply_sparsity(model): - for name, module in model.named_modules(): - if 
isinstance(module, SupermaskLinear) and "mlp" in name: - module.sparsify_offline() - -def verify_sparsity(model): - for name, module in model.named_modules(): - if isinstance(module, torch.nn.Linear): - total_weights = module.weight.numel() - sparse_weights = (module.weight == 0).sum().item() - sparsity_percentage = (sparse_weights / total_weights) * 100 - print(f"Sparsity verified in layer {name}: {sparsity_percentage:.2f}%") # filter functions def mlp_0_only(mod, name): - return isinstance(mod, torch.nn.Linear) and 'mlp.0' in name + return isinstance(mod, torch.nn.Linear) and "mlp.0" in name + def mlp_3_only(mod, name): - return isinstance(mod, torch.nn.Linear) and 'mlp.3' in name + return isinstance(mod, torch.nn.Linear) and "mlp.3" in name + def mlp_only(mod, name): - return isinstance(mod, torch.nn.Linear) and 'mlp' in name - + return isinstance(mod, torch.nn.Linear) and "mlp" in name + + def superblock_only(mod, name): return isinstance(mod, SupermaskLinear) and "mlp" in name -def mlp_only_with_args(mod, name, skip_last_layer_sparsity=False, skip_first_transformer_sparsity=False): + +def mlp_only_with_args( + mod, name, skip_last_layer_sparsity=False, skip_first_transformer_sparsity=False +): if skip_last_layer_sparsity and "heads.head" in name: return False if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in name: return False - if isinstance(mod, torch.nn.Linear) and "mlp" in name: + if isinstance(mod, torch.nn.Linear) and "mlp" in name: return True return False -### other + +### Custom sparsification utils +def apply_sparsity(model): + for name, module in model.named_modules(): + if isinstance(module, SupermaskLinear) and "mlp" in name: + module.sparsify_offline() + def accelerate_with_sparsity(model, args): if args.sparsity == "bsr": apply_sparsity(model) - verify_sparsity(model) - assert args.bsr is not None, "BSR requires a block size" - sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only) - + if args.quantization: + from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType + + quantize_( + model, + int8_dynamic_activation_int8_weight( + layout_type=BlockSparseLayoutType(blocksize=args.bsr) + ), + superblock_only, + ) + else: + assert args.bsr is not None, "BSR requires a block size" + sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only) elif args.sparsity == "semi_structured": if args.quantization: - quantize_(model, - int8_dynamic_activation_int8_semi_sparse_weight(), - mlp_0_only) - sparsify_(model, - semi_sparse_weight(), - mlp_3_only) + from torchao.dtypes.affine_quantized_tensor import SemiSparseLayoutType + + quantize_( + model, + int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), + mlp_0_only, + ) + sparsify_(model, semi_sparse_weight(), mlp_3_only) else: - sparsify_(model, - semi_sparse_weight(), - mlp_only) + sparsify_(model, semi_sparse_weight(), mlp_only) + else: + if args.quantization: + quantize_(model, int8_dynamic_activation_int8_weight(), mlp_only) + def simulate_sparsity(model, args): if args.sparsity == "bsr": apply_supermask( model, linear_sparsity=args.sparsity_linear, - linear_sp_tilesize=args.sp_linear_tile_size, + linear_sp_tilesize=args.bsr, conv1x1_sparsity=args.sparsity_conv1x1, - conv1x1_sp_tilesize=args.sp_conv1x1_tile_size, + conv1x1_sp_tilesize=args.bsr, conv_sparsity=args.sparsity_conv, - conv_sp_tilesize=args.sp_conv_tile_size, + conv_sp_tilesize=args.bsr, skip_last_layer_sparsity=args.skip_last_layer_sparsity, 
skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, device=args.device, @@ -94,24 +193,27 @@ def simulate_sparsity(model, args): elif args.sparsity == "semi_structured": sparse_config = [] for name, mod in model.named_modules(): - if mlp_only_with_args(mod, name, - skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, - skip_last_layer_sparsity=args.skip_last_layer_sparsity): + if mlp_only_with_args( + mod, + name, + skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, + skip_last_layer_sparsity=args.skip_last_layer_sparsity, + ): sparse_config.append({"tensor_fqn": f"{name}.weight"}) sparsifier = WeightNormSparsifier( sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2 ) sparsifier.prepare(model, sparse_config) - for line in sparse_config: - print(line) sparsifier.step() return sparsifier - else: - print("No sparsity applied!") -### Existing torchvision utils +# ------------------------------------------------------------ +# The following code contains torchvision reference code, +# largely copied from: https://github.com/pytorch/vision/tree/main/references/classification +# Please open issues in the original repository if you have questions. + class SmoothedValue: """Track a series of values and provide access to smoothed values over a @@ -164,7 +266,11 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, ) @@ -185,7 +291,9 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{attr}'" + ) def __str__(self): loss_str = [] @@ -223,7 +331,14 @@ def log_every(self, iterable, print_freq, header=None): ) else: log_msg = self.delimiter.join( - [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] ) MB = 1024.0 * 1024.0 for obj in iterable: @@ -248,7 +363,12 @@ def log_every(self, iterable, print_freq, header=None): else: print( log_msg.format( - i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), ) ) i += 1 @@ -316,9 +436,9 @@ def print(*args, **kwargs): def is_dist_avail_and_initialized(): - if not dist.is_available(): + if not torch.distributed.is_available(): return False - if not dist.is_initialized(): + if not torch.distributed.is_initialized(): return False return True @@ -326,13 +446,13 @@ def is_dist_avail_and_initialized(): def get_world_size(): if not is_dist_avail_and_initialized(): return 1 - return dist.get_world_size() + return torch.distributed.get_world_size() def get_rank(): if not is_dist_avail_and_initialized(): return 0 - return dist.get_rank() + return torch.distributed.get_rank() def is_main_process(): @@ -363,9 +483,12 @@ def init_distributed_mode(args): torch.cuda.set_device(args.gpu) args.dist_backend = "nccl" - print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) + print(f"| distributed init (rank {args.rank})", flush=True) torch.distributed.init_process_group( - 
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, ) torch.distributed.barrier() setup_for_distributed(args.rank == 0) @@ -390,7 +513,9 @@ def average_checkpoints(inputs): with open(fpath, "rb") as f: state = torch.load( f, - map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, "cpu") + ), ) # Copies over the settings from the first checkpoint if new_state is None: @@ -475,7 +600,9 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T # and remove unnecessary weights (such as auxiliaries, etc) if checkpoint_key == "model_ema": del checkpoint[checkpoint_key]["n_averaged"] - torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.") + torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( + checkpoint[checkpoint_key], "module." + ) model.load_state_dict(checkpoint[checkpoint_key], strict=strict) tmp_path = os.path.join(output_dir, str(model.__hash__())) @@ -500,8 +627,8 @@ def reduce_across_processes(val): return torch.tensor(val) t = torch.tensor(val, device="cuda") - dist.barrier() - dist.all_reduce(t) + torch.distributed.barrier() + torch.distributed.all_reduce(t) return t @@ -543,7 +670,9 @@ def _add_params(module, prefix=""): continue is_custom_key = False for key in custom_keys: - target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name + target_name = ( + f"{prefix}.{name}" if prefix != "" and "." in key else name + ) if key == target_name: params[key].append(p) is_custom_key = True @@ -563,5 +692,365 @@ def _add_params(module, prefix=""): param_groups = [] for key in params: if len(params[key]) > 0: - param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]}) + param_groups.append( + {"params": params[key], "weight_decay": params_weight_decay[key]} + ) return param_groups + + +# Presets for ImageNet training/eval taken from: https://github.com/pytorch/vision/blob/main/references/classification/presets.py + + +class ClassificationPresetTrain: + def __init__( + self, + *, + crop_size, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + interpolation=InterpolationMode.BILINEAR, + hflip_prob=0.5, + auto_augment_policy=None, + ra_magnitude=9, + augmix_severity=3, + random_erase_prob=0.0, + ): + trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + if hflip_prob > 0: + trans.append(transforms.RandomHorizontalFlip(hflip_prob)) + if auto_augment_policy is not None: + if auto_augment_policy == "ra": + trans.append( + autoaugment.RandAugment( + interpolation=interpolation, magnitude=ra_magnitude + ) + ) + elif auto_augment_policy == "ta_wide": + trans.append( + autoaugment.TrivialAugmentWide(interpolation=interpolation) + ) + elif auto_augment_policy == "augmix": + trans.append( + autoaugment.AugMix( + interpolation=interpolation, severity=augmix_severity + ) + ) + else: + aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) + trans.append( + autoaugment.AutoAugment( + policy=aa_policy, interpolation=interpolation + ) + ) + trans.extend( + [ + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) + if random_erase_prob > 0: + trans.append(transforms.RandomErasing(p=random_erase_prob)) + + 
self.transforms = transforms.Compose(trans) + + def __call__(self, img): + return self.transforms(img) + + +class ClassificationPresetEval: + def __init__( + self, + *, + crop_size, + resize_size=256, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + interpolation=InterpolationMode.BILINEAR, + ): + + self.transforms = transforms.Compose( + [ + transforms.Resize(resize_size, interpolation=interpolation), + transforms.CenterCrop(crop_size), + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) + + def __call__(self, img): + return self.transforms(img) + + +# transforms taken from: https://github.com/pytorch/vision/blob/main/references/classification/transforms.py + + +class RandomMixup(torch.nn.Module): + """Randomly apply Mixup to the provided batch and targets. + The class implements the data augmentations as described in the paper + `"mixup: Beyond Empirical Risk Minimization" `_. + + Args: + num_classes (int): number of classes used for one-hot encoding. + p (float): probability of the batch being transformed. Default value is 0.5. + alpha (float): hyperparameter of the Beta distribution used for mixup. + Default value is 1.0. + inplace (bool): boolean to make this transform inplace. Default set to False. + """ + + def __init__( + self, + num_classes: int, + p: float = 0.5, + alpha: float = 1.0, + inplace: bool = False, + ) -> None: + super().__init__() + + if num_classes < 1: + raise ValueError( + f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}" + ) + + if alpha <= 0: + raise ValueError("Alpha param can't be zero.") + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward( + self, batch: torch.Tensor, target: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") + if target.ndim != 1: + raise ValueError(f"Target ndim should be 1. Got {target.ndim}") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if target.dtype != torch.int64: + raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot( + target, num_classes=self.num_classes + ).to(dtype=batch.dtype) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # Implemented as on mixup paper, page 3. + lambda_param = float( + torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] + ) + batch_rolled.mul_(1.0 - lambda_param) + batch.mul_(lambda_param).add_(batch_rolled) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_classes={self.num_classes}" + f", p={self.p}" + f", alpha={self.alpha}" + f", inplace={self.inplace}" + f")" + ) + return s + + +class RandomCutmix(torch.nn.Module): + """Randomly apply Cutmix to the provided batch and targets. 
+    The class implements the data augmentations as described in the paper
+    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
+    `_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for cutmix.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        p: float = 0.5,
+        alpha: float = 1.0,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        if num_classes < 1:
+            raise ValueError(
+                "Please provide a valid positive value for the num_classes."
+            )
+        if alpha <= 0:
+            raise ValueError("Alpha param can't be zero.")
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(
+        self, batch: torch.Tensor, target: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tensor: Randomly transformed batch.
+        """
+        if batch.ndim != 4:
+            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
+        if target.ndim != 1:
+            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
+        if not batch.is_floating_point():
+            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
+        if target.dtype != torch.int64:
+            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
+
+        if not self.inplace:
+            batch = batch.clone()
+            target = target.clone()
+
+        if target.ndim == 1:
+            target = torch.nn.functional.one_hot(
+                target, num_classes=self.num_classes
+            ).to(dtype=batch.dtype)
+
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1, 0)
+
+        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
+        lambda_param = float(
+            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
+        )
+        _, H, W = F.get_dimensions(batch)
+
+        r_x = torch.randint(W, (1,))
+        r_y = torch.randint(H, (1,))
+
+        r = 0.5 * math.sqrt(1.0 - lambda_param)
+        r_w_half = int(r * W)
+        r_h_half = int(r * H)
+
+        x1 = int(torch.clamp(r_x - r_w_half, min=0))
+        y1 = int(torch.clamp(r_y - r_h_half, min=0))
+        x2 = int(torch.clamp(r_x + r_w_half, max=W))
+        y2 = int(torch.clamp(r_y + r_h_half, max=H))
+
+        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
+        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = (
+            f"{self.__class__.__name__}("
+            f"num_classes={self.num_classes}"
+            f", p={self.p}"
+            f", alpha={self.alpha}"
+            f", inplace={self.inplace}"
+            f")"
+        )
+        return s
+
+
+# RA Sampler implementation taken from: https://github.com/pytorch/vision/blob/main/references/classification/sampler.py
+
+
+class RASampler(torch.utils.data.Sampler):
+    """Sampler that restricts data loading to a subset of the dataset for distributed,
+    with repeated augmentation.
+    It ensures that each augmented version of a sample will be visible to a
+    different process (GPU).
+    Heavily based on 'torch.utils.data.DistributedSampler'.
+
+
+# RA Sampler implementation taken from: https://github.com/pytorch/vision/blob/main/references/classification/sampler.py
+
+
+class RASampler(torch.utils.data.Sampler):
+    """Sampler that restricts data loading to a subset of the dataset for distributed training,
+    with repeated augmentation.
+    It ensures that each augmented version of a sample will be visible to a
+    different process (GPU).
+    Heavily based on 'torch.utils.data.DistributedSampler'.
+
+    This is borrowed from the DeiT Repo:
+    https://github.com/facebookresearch/deit/blob/main/samplers.py
+    """
+
+    def __init__(
+        self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3
+    ):
+        if num_replicas is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available!")
+            num_replicas = torch.distributed.get_world_size()
+        if rank is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available!")
+            rank = torch.distributed.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = int(
+            math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas)
+        )
+        self.total_size = self.num_samples * self.num_replicas
+        self.num_selected_samples = int(
+            math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)
+        )
+        self.shuffle = shuffle
+        self.seed = seed
+        self.repetitions = repetitions
+
+    def __iter__(self):
+        if self.shuffle:
+            # Deterministically shuffle based on epoch
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # Add extra samples to make it evenly divisible
+        indices = [ele for ele in indices for i in range(self.repetitions)]
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # Subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices[: self.num_selected_samples])
+
+    def __len__(self):
+        return self.num_selected_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
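Because RASampler repeats each index `repetitions` times before sharding with a stride of `num_replicas`, consecutive copies of the same sample land on different ranks, each drawing its own augmentation. Like `DistributedSampler`, it relies on `set_epoch` to advance the shuffle seed. A minimal wiring sketch, where `dataset`, `world_size`, `rank`, and `epochs` are placeholders assumed to come from the surrounding DDP setup:

```python
import torch
from torch.utils.data import DataLoader

sampler = RASampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, repetitions=3)
loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)

for epoch in range(epochs):
    sampler.set_epoch(epoch)  # advance the deterministic shuffle seed
    for images, targets in loader:
        ...  # forward/backward as usual
```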
diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py
index e1d1a99627..ae343add9e 100644
--- a/torchao/sparsity/sparse_api.py
+++ b/torchao/sparsity/sparse_api.py
@@ -4,17 +4,17 @@
 from torch.ao.pruning import WeightNormSparsifier
 from torch.sparse import to_sparse_semi_structured
 from torchao.quantization.quant_api import (
+    _get_linear_subclass_inserter,
     _is_linear,
     _replace_with_custom_fn_if_matches_filter,
-    _get_linear_subclass_inserter,
     int8_dynamic_activation_int8_semi_sparse_weight,
 )
+
 # Sparsity helper functions
 def apply_fake_sparsity(model, **kwargs):
     """
     This function simulates 2:4 sparsity on all linear layers in a model.
-    It uses the torch.ao.pruning flow.
     """
     filter_fn = kwargs.pop("filter_fn", _is_linear)
     # torch.ao.pruning flow
@@ -30,15 +30,19 @@ def apply_fake_sparsity(model, **kwargs):
     sparsifier.step()
     sparsifier.squash_mask()
+
 def semi_sparse_weight():
     """
     Convert the weight of linear modules to semi-structured (2:4) sparsity
     """
     return _get_linear_subclass_inserter(to_sparse_semi_structured)
-def sparsify_(model: torch.nn.Module,
-              apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
-              filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module:
+
+def sparsify_(
+    model: torch.nn.Module,
+    apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
+    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
+) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`
     This function is essentially the same as quantize_, but for sparsity subclasses.
@@ -73,6 +77,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
-        apply_tensor_subclass,
+        apply_tensor_subclass,
         _is_linear if filter_fn is None else filter_fn,
     )
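The sparse_api entry points above keep a two-step flow: simulate 2:4 pruning on the dense weights, then swap in the sparse tensor subclass. A minimal end-to-end sketch, assuming a CUDA build of torchao with semi-structured sparse kernels available:

```python
import torch
from torch import nn
from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_

model = nn.Sequential(nn.Linear(128, 128)).half().cuda().eval()

# Step 1: prune the linear weights to a 2:4 pattern (simulated via torch.ao.pruning).
apply_fake_sparsity(model)

# Step 2: replace the pruned dense weights with the semi-structured sparse subclass.
sparsify_(model, semi_sparse_weight())

out = model(torch.rand(128, 128).half().cuda())
```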