From c2f80d09082a29a0d3ec9cdf25c8755300a0ab67 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 18 Feb 2025 13:13:44 -0800
Subject: [PATCH 01/11] moved supermask out of prototype

---
 test/sparsity/test_supermask.py |  66 ++++++++++++++++
 torchao/sparsity/supermask.py   | 132 ++++++++++++++++++++++++++++++++
 2 files changed, 198 insertions(+)
 create mode 100644 test/sparsity/test_supermask.py
 create mode 100644 torchao/sparsity/supermask.py

diff --git a/test/sparsity/test_supermask.py b/test/sparsity/test_supermask.py
new file mode 100644
index 0000000000..e246c2e169
--- /dev/null
+++ b/test/sparsity/test_supermask.py
@@ -0,0 +1,66 @@
+import copy
+import logging
+import unittest
+import math
+
+import torch
+from torch import nn
+from torch.testing._internal import common_utils
+
+from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
+from torchao.quantization.quant_api import (
+    int4_weight_only,
+    int8_dynamic_activation_int8_weight,
+    quantize_,
+)
+from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
+from torchao.sparsity.utils import create_block_sparse_tensor
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+)
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+
+class TestSupermask(common_utils.TestCase):
+
+    @common_utils.parametrize("sparsity_level", [0.25, 0.5])
+    @common_utils.parametrize("blocksize", [2, 4, 8])
+    def test_supermask(self, sparsity_level, blocksize):
+        input = torch.randn((1, 16)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(16, 16, bias=False),
+            )
+            .half()
+            .cuda()
+            .eval()
+        )
+
+        from torchao.sparsity import SupermaskLinear
+
+        M, N = model[0].weight.shape
+        sparsify_(model, lambda x: SupermaskLinear.from_linear(x, sparsity_level=sparsity_level, blocksize=blocksize))
+        sparsify_(model, SupermaskLinear.to_linear)
+        weight_bsr = model[0].weight.to_sparse_bsr(blocksize=blocksize)
+
+        # Test correct sparsity level
+        nnz = weight_bsr._nnz()
+        expected = round((M // blocksize) * (N // blocksize) * (1 - sparsity_level))
+        assert nnz == expected, f"Expected {expected} nonzeros, got {nnz}"
+
+    def test_from_linear(self):
+        from torchao.sparsity import SupermaskLinear
+        linear = nn.Linear(128, 128)
+        supermask_linear = SupermaskLinear.from_linear(linear, sparsity_level=0.5, blocksize=4)
+        assert supermask_linear.weight.shape == linear.weight.shape
+
+
+common_utils.instantiate_parametrized_tests(TestSupermask)
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchao/sparsity/supermask.py b/torchao/sparsity/supermask.py
new file mode 100644
index 0000000000..0f2fec55f3
--- /dev/null
+++ b/torchao/sparsity/supermask.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import torch.nn as nn
+import math
+import torch
+from torch.autograd import Variable
+import torch.nn.functional as F
+import numpy as np
+
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+
+# original supermask
+scores_min=None
+scores_max=9e9
+
+def percentile(t, q):
+    """Return the value that is larger than q% of t"""
+    k = 1 + round(.01 * float(q) * (t.numel() - 1))
+    return t.view(-1).kthvalue(k).values
+
+
+class GetSubnet(torch.autograd.Function):
+    """Supermask STE function"""
+    @staticmethod
+    def forward(ctx, scores, zeros, ones, sparsity):
+        clamped_scores = scores.clamp(min=scores_min,max=scores_max)
+        k_val = percentile(clamped_scores, sparsity*100)
+        return torch.where(clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device))
+
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None
+
+
+class ApplyMask(torch.autograd.Function):
+    """Supermask STE function"""
+    @staticmethod
+    def forward(ctx, weight, scores):
+        return weight * scores
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_weight = grad_scores = None
+        if ctx.needs_input_grad[0]:
+            grad_weight = grad_output
+        if ctx.needs_input_grad[1]:
+            grad_scores = grad_output
+        return grad_weight, grad_scores
+
+
+class SupermaskLinear(nn.Linear):
+    """Supermask class for Linear layer"""
+    def __init__(self, sparsity_level, blocksize, fixed_mask, fixed_weight, *args, **kwargs):
+        super(SupermaskLinear, self).__init__(*args, **kwargs)
+        # calculate the maximum sparsity given blocksize for the layer
+        max_sparsity_level = 1 - (1 / math.prod([math.ceil(k / blocksize) for k in self.weight.size()]))
+        self.sparsity_level = sparsity_level
+        if self.sparsity_level > max_sparsity_level:
+            print(
+                f"reducing sparsity from {self.sparsity} to {max_sparsity}",
+                f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {blocksize})"
+            )
+            self.sparsity_level = max_sparsity_level
+        self.blocksize = blocksize
+        self.sparsify_weights = False
+        self.scores = nn.Parameter(
+            torch.empty(
+                [max(1, int(math.ceil(wn / blocksize))) for wn in self.weight.size()]
+            ),
+            requires_grad=not fixed_mask,
+        )
+        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))
+
+        # NOTE: the previous implementation of Supermask supported quantizing the weights, this has been removed.
+
+        self.weight.requires_grad = not fixed_weight
+
+    def get_mask(self):
+        subnet = GetSubnet.apply(self.scores,
+            torch.zeros_like(self.scores),
+            torch.ones_like(self.scores),
+            self.sparsity_level)
+
+        if self.blocksize != 1:
+            for i, k in enumerate(self.weight.shape):
+                subnet = subnet.repeat_interleave(self.blocksize, dim=i)
+                subnet = torch.narrow(subnet, i, 0, k)
+
+        return subnet
+
+
+    def forward(self, x):
+        subnet = self.get_mask()
+        w = ApplyMask.apply(self.weight, subnet)
+        return F.linear(x, w, self.bias)
+
+    @classmethod
+    def from_linear(cls, linear, sparsity_level=0.0, blocksize=1, ):
+        """
+        Main entrypoint for creating a SupermaskLinear from a Linear layer.
+        """
+        assert isinstance(linear, torch.nn.Linear)
+
+        supermask_linear = SupermaskLinear(
+            sparsity_level, blocksize, False, False,
+            linear.in_features,
+            linear.out_features,
+            bias=linear.bias is not None,
+        ).to(device=linear.weight.device, dtype=linear.weight.dtype)
+        supermask_linear.weight.data.copy_(linear.weight.data)
+        if linear.bias is not None:
+            supermask_linear.bias.data.copy_(linear.bias.data)
+        return supermask_linear
+
+    @classmethod
+    def to_linear(cls, supermask_linear):
+        """
+        Convert a SupermaskLinear to a Linear layer.
+        Replaces the old sparsify_offline() function.
+        """
+        self = supermask_linear
+
+        linear = torch.nn.Linear(
+            self.in_features,
+            self.out_features,
+            bias=self.bias is not None,
+        ).to(device=self.weight.device, dtype=self.weight.dtype)
+
+        mask = self.get_mask()
+        linear.weight.data.copy_(self.weight * mask)
+        if self.bias is not None:
+            linear.bias.data.copy_(self.bias.data)
+        return linear

From 4e1c9f3090705388a1f4f1f3e08ab43b71c5dc03 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 18 Feb 2025 13:14:30 -0800
Subject: [PATCH 02/11] updated api

---
 .../sparsity/superblock/supermask.py | 365 ------------------
 torchao/sparsity/__init__.py         |   2 +
 2 files changed, 2 insertions(+), 365 deletions(-)
 delete mode 100644 torchao/prototype/sparsity/superblock/supermask.py

diff --git a/torchao/prototype/sparsity/superblock/supermask.py b/torchao/prototype/sparsity/superblock/supermask.py
deleted file mode 100644
index abd23c566e..0000000000
--- a/torchao/prototype/sparsity/superblock/supermask.py
+++ /dev/null
@@ -1,365 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-# original supermask
-scores_min = None
-scores_max = 9e9
-uniform_init_01 = False
-
-# adjusted supermask, initialize scores with uniform distribution in [0,1], clamp scores in each step in [0,1]
-# scores_min=0.
-# scores_max=1.
-# uniform_init_01 = True
-
-
-def percentile(t, q):
-    """Return the value that is larger than q% of t"""
-    k = 1 + round(0.01 * float(q) * (t.numel() - 1))
-    return t.view(-1).kthvalue(k).values
-
-
-class GetSubnet(torch.autograd.Function):
-    """Supermask STE function"""
-
-    @staticmethod
-    def forward(ctx, scores, zeros, ones, sparsity):
-        clamped_scores = scores.clamp(min=scores_min, max=scores_max)
-        k_val = percentile(clamped_scores, sparsity * 100)
-        return torch.where(
-            clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device)
-        )
-
-    @staticmethod
-    def backward(ctx, g):
-        return g, None, None, None
-
-
-class SupermaskLinear(nn.Linear):
-    """Supermask class for Linear layer"""
-
-    def __init__(
-        self,
-        sparsity,
-        fixed_mask,
-        fixed_weight,
-        bitwidth,
-        transform,
-        fixed_transform,
-        *args,
-        **kwargs,
-    ):
-        tile_size = kwargs.pop("tile_size", 1)
-        super(SupermaskLinear, self).__init__(*args, **kwargs)
-        # initialize the scores
-        max_sparsity = 1 - (
-            1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])
-        )
-        self.sparsity = sparsity
-        if self.sparsity > max_sparsity:
-            print(
-                f"reducing sparsity from {self.sparsity} to {max_sparsity}",
-                f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})",
-            )
-            self.sparsity = max_sparsity
-        self.tile_size = tile_size
-        self.sparsify_weights = False
-        self.scores = nn.Parameter(
-            torch.empty(
-                [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()]
-            ),
-            requires_grad=not fixed_mask,
-        )
-        nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(
-            self.scores, a=math.sqrt(5)
-        )
-
-        # the shift and the scale are transformation parameters
-        # the actually used weights = self.weight*self.scale+self.shift
-        # the transformation is activated only for quantized weights
-        self.shift = nn.Parameter(torch.Tensor(1).fill_(0.0), requires_grad=False)
-        self.scale = nn.Parameter(torch.Tensor(1).fill_(1.0), requires_grad=False)
-
-        with torch.no_grad():
-            # if bitwidth is None, then use floating point values in self.weight
-            # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth)
-            # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1
-            # these quantized values are uniformly distributed
-            if bitwidth is not None:
-                weights_max = torch.max(self.weight).item()
-                weights_min = torch.min(self.weight).item()
-                least_step = (weights_max - weights_min) / pow(2, bitwidth)
-                left_bound = weights_min - 1e-6
-                right_bound = weights_min + least_step + 1e-6
-                # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0])
-                # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1])
-                # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a];
-                self.shift = nn.Parameter(
-                    torch.Tensor(1).fill_(
-                        0.0 if transform[0] is None else transform[0]
-                    ),
-                    requires_grad=not fixed_transform[0],
-                )
-                self.scale = nn.Parameter(
-                    torch.Tensor(1).fill_(
-                        1.0 if transform[1] is None else transform[1]
-                    ),
-                    requires_grad=not fixed_transform[1],
-                )
-                for i in range(-int(pow(2, bitwidth - 1)), int(pow(2, bitwidth - 1))):
-                    self.weight[
-                        torch.logical_and(
-                            self.weight > left_bound, self.weight <= right_bound
-                        )
-                    ] = i
-                    left_bound = right_bound
-                    right_bound += least_step
-
-        self.weight.requires_grad = not fixed_weight
-
-    def get_mask(self):
-        subnet = GetSubnet.apply(
-            self.scores,
-            torch.zeros_like(self.scores),
-            torch.ones_like(self.scores),
-            self.sparsity,
-        )
-
-        if self.tile_size != 1:
-            for i, k in enumerate(self.weight.shape):
-                subnet = subnet.repeat_interleave(self.tile_size, dim=i)
-                subnet = torch.narrow(subnet, i, 0, k)
-
-        return subnet
-
-    def sparsify_offline(self):
-        subnet = self.get_mask()
-        self.weight.data = (self.weight * self.scale + self.shift) * subnet
-        self.sparsify_weights = True
-
-    def forward(self, x):
-        if not self.sparsify_weights:
-            subnet = self.get_mask()
-            w = (self.weight * self.scale + self.shift) * subnet
-        else:
-            w = self.weight
-        return F.linear(x, w, self.bias)
-
-
-class SupermaskConv2d(nn.Conv2d):
-    """Supermask class for Conv2d layer"""
-
-    def __init__(
-        self,
-        sparsity,
-        fixed_mask,
-        fixed_weight,
-        bitwidth,
-        transform,
-        fixed_transform,
-        *args,
-        **kwargs,
-    ):
-        tile_size = kwargs.pop("tile_size", 1)
-        super(SupermaskConv2d, self).__init__(*args, **kwargs)
-        # initialize the scores
-        max_sparsity = 1 - (
-            1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()])
-        )
-        self.sparsity = sparsity
-        if self.sparsity > max_sparsity:
-            print(
-                f"reducing sparsity from {self.sparsity} to {max_sparsity}",
-                f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})",
-            )
-            self.sparsity = max_sparsity
-        self.tile_size = tile_size
-        self.scores = nn.Parameter(
-            torch.empty(
-                [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()]
-            ),
-            requires_grad=not fixed_mask,
-        )
-        nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_(
-            self.scores, a=math.sqrt(5)
-        )
-
-        # the shift and the scale are transformation parameters
-        # the actually used weights = self.weight*self.scale+self.shift
-        # the transformation is activated only for quantized weights
-        self.shift = nn.Parameter(torch.Tensor(1).fill_(0.0), requires_grad=False)
-        self.scale = nn.Parameter(torch.Tensor(1).fill_(1.0), requires_grad=False)
-
-        with torch.no_grad():
-            # if bitwidth is None, then use floating point values in self.weight
-            # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth)
-            # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1
-            # these quantized values are uniformly distributed
-            if bitwidth is not None:
-                weights_max = torch.max(self.weight).item()
-                weights_min = torch.min(self.weight).item()
-                least_step = (weights_max - weights_min) / pow(2, bitwidth)
-                left_bound = weights_min - 1e-6
-                right_bound = weights_min + least_step + 1e-6
-                # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0])
-                # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1]), requires_grad=not fixed_transform[1])
-                # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a];
-                self.shift = nn.Parameter(
-                    torch.Tensor(1).fill_(
-                        0.0 if transform[0] is None else transform[0]
-                    ),
-                    requires_grad=not fixed_transform[0],
-                )
-                self.scale = nn.Parameter(
-                    torch.Tensor(1).fill_(
-                        1.0 if transform[1] is None else transform[1]
-                    ),
-                    requires_grad=not fixed_transform[1],
-                )
-                for i in range(-int(pow(2, bitwidth - 1)), int(pow(2, bitwidth - 1))):
-                    self.weight[
-                        torch.logical_and(
-                            self.weight > left_bound, self.weight <= right_bound
-                        )
-                    ] = i
-                    left_bound = right_bound
-                    right_bound += least_step
-
-        self.weight.requires_grad = not fixed_weight
-
-    def forward(self, x):
-        subnet = GetSubnet.apply(
-            self.scores,
-            torch.zeros_like(self.scores),
-            torch.ones_like(self.scores),
-            self.sparsity,
-        )
-
-        if self.tile_size != 1:
-            for i, k in enumerate(self.weight.shape):
-                # if k == 1: continue
-                subnet = subnet.repeat_interleave(self.tile_size, dim=i)
-                subnet = torch.narrow(subnet, i, 0, k)
-
-        w = (self.weight * self.scale + self.shift) * subnet
-        return F.conv2d(
-            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
-        )
-
-
-def apply_supermask(
-    model,
-    linear_sparsity=0.0,
-    linear_sp_tilesize=1,
-    conv1x1_sparsity=0.0,
-    conv1x1_sp_tilesize=1,
-    conv_sparsity=0.0,
-    conv_sp_tilesize=1,
-    skip_last_layer_sparsity=False,
-    skip_first_transformer_sparsity=False,
-    device="cuda",
-    verbose=False,
-):
-    sparsified_modules = {}
-
-    for n, m in model.named_modules():
-        # check conditions for skipping sparsity
-        if skip_last_layer_sparsity and n == "heads.head":
-            continue
-        if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n:
-            continue
-
-        # convert 1x1 convolutions
-        if (
-            conv1x1_sparsity != 0.0
-            and isinstance(m, torch.nn.Conv2d)
-            and m.kernel_size == (1, 1)
-        ):
-            new_m = SupermaskConv2d(
-                conv1x1_sparsity,
-                False,
-                False,
-                None,
-                None,
-                None,
-                m.in_channels,
-                m.out_channels,
-                m.kernel_size,
-                stride=m.stride,
-                padding=m.padding,
-                dilation=m.dilation,
-                groups=m.groups,
-                bias=m.bias is not None,
-                padding_mode=m.padding_mode,
-                device=device,
-                tile_size=conv1x1_sp_tilesize,
-            )
-            new_m.weight.data.copy_(m.weight.data)
-            if m.bias is not None:
-                new_m.bias.data.copy_(m.bias.data)
-            sparsified_modules[n] = new_m
-            continue
-
-        # convert all other convolutions (not tested!)
-        if conv_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d):
-            new_m = SupermaskConv2d(
-                conv_sparsity,
-                False,
-                False,
-                None,
-                None,
-                None,
-                m.in_channels,
-                m.out_channels,
-                m.kernel_size,
-                stride=m.stride,
-                padding=m.padding,
-                dilation=m.dilation,
-                groups=m.groups,
-                bias=m.bias is not None,
-                padding_mode=m.padding_mode,
-                device=device,
-                tile_size=conv_sp_tilesize,
-            )
-            new_m.weight.data.copy_(m.weight.data)
-            if m.bias is not None:
-                new_m.bias.data.copy_(m.bias.data)
-            sparsified_modules[n] = new_m
-            continue
-
-        if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear):
-            new_m = SupermaskLinear(
-                linear_sparsity,
-                False,
-                False,
-                None,
-                None,
-                None,
-                m.in_features,
-                m.out_features,
-                bias=m.bias is not None,
-                device=device,
-                tile_size=linear_sp_tilesize,
-            )
-            new_m.weight.data.copy_(m.weight.data)
-            if m.bias is not None:
-                new_m.bias.data.copy_(m.bias.data)
-            sparsified_modules[n] = new_m
-            continue
-
-    # add modules to model
-    for k, v in sparsified_modules.items():
-        sm_name, ch_name = k.rsplit(".", 1)
-        sm = model.get_submodule(sm_name)
-        sm.add_module(ch_name, v)
-
-        if verbose:
-            print(
-                f'sparsified module "{k}" with sparsity={v.sparsity}, tile size={v.tile_size}'
-            )
-
-    return model
diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py
index 77ccd2c00b..3c36bc42ca 100644
--- a/torchao/sparsity/__init__.py
+++ b/torchao/sparsity/__init__.py
@@ -15,9 +15,11 @@
 )
 from .utils import PerChannelNormObserver  # noqa: F403
 from .wanda import WandaSparsifier  # noqa: F403
+from .supermask import SupermaskLinear
 
 __all__ = [
     "WandaSparsifier",
+    "SupermaskLinear",
     "PerChannelNormObserver",
     "apply_fake_sparsity",
     "sparsify_",

From e5175afb0856581ab6a23815297b0b025330022b Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 18 Feb 2025 13:18:17 -0800
Subject: [PATCH 03/11] updated prototype folder inside sparsity

---
 torchao/sparsity/prototype/superblock/supermask.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/torchao/sparsity/prototype/superblock/supermask.py b/torchao/sparsity/prototype/superblock/supermask.py
index f502d1f2ad..97d0b36c79 100644
--- a/torchao/sparsity/prototype/superblock/supermask.py
+++ b/torchao/sparsity/prototype/superblock/supermask.py
@@ -1,11 +1,7 @@
-from torchao.prototype.sparsity.superblock.supermask import (
-    GetSubnet,
-    SupermaskConv2d,
+from torchao.sparsity.supermask import (
     SupermaskLinear,
 )
 
 __all__ = [
-    "GetSubnet",
-    "SupermaskConv2d",
     "SupermaskLinear",
 ]

From d9803e6efadf957623aa5501e943518c227cdbd9 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 18 Feb 2025 13:24:30 -0800
Subject: [PATCH 04/11] ruff formatting

---
 torchao/sparsity/supermask.py | 56 ++++++++++++++++++++++++-----------
 1 file changed, 38 insertions(+), 18 deletions(-)

diff --git a/torchao/sparsity/supermask.py b/torchao/sparsity/supermask.py
index 0f2fec55f3..e2204b7cae 100644
--- a/torchao/sparsity/supermask.py
+++ b/torchao/sparsity/supermask.py
@@ -10,22 +10,26 @@
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
 # original supermask
-scores_min=None
-scores_max=9e9
+scores_min = None
+scores_max = 9e9
+
 
 def percentile(t, q):
     """Return the value that is larger than q% of t"""
-    k = 1 + round(.01 * float(q) * (t.numel() - 1))
+    k = 1 + round(0.01 * float(q) * (t.numel() - 1))
     return t.view(-1).kthvalue(k).values
 
 
 class GetSubnet(torch.autograd.Function):
     """Supermask STE function"""
+
     @staticmethod
     def forward(ctx, scores, zeros, ones, sparsity):
-        clamped_scores = scores.clamp(min=scores_min,max=scores_max)
-        k_val = percentile(clamped_scores, sparsity*100)
-        return torch.where(clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device))
+        clamped_scores = scores.clamp(min=scores_min, max=scores_max)
+        k_val = percentile(clamped_scores, sparsity * 100)
+        return torch.where(
+            clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device)
+        )
 
     @staticmethod
     def backward(ctx, g):
@@ -34,9 +38,11 @@ def backward(ctx, g):
 
 class ApplyMask(torch.autograd.Function):
     """Supermask STE function"""
+
     @staticmethod
     def forward(ctx, weight, scores):
         return weight * scores
+
     @staticmethod
     def backward(ctx, grad_output):
         grad_weight = grad_scores = None
@@ -49,15 +55,20 @@ def backward(ctx, grad_output):
 
 class SupermaskLinear(nn.Linear):
     """Supermask class for Linear layer"""
-    def __init__(self, sparsity_level, blocksize, fixed_mask, fixed_weight, *args, **kwargs):
+
+    def __init__(
+        self, sparsity_level, blocksize, fixed_mask, fixed_weight, *args, **kwargs
+    ):
         super(SupermaskLinear, self).__init__(*args, **kwargs)
         # calculate the maximum sparsity given blocksize for the layer
-        max_sparsity_level = 1 - (1 / math.prod([math.ceil(k / blocksize) for k in self.weight.size()]))
+        max_sparsity_level = 1 - (
+            1 / math.prod([math.ceil(k / blocksize) for k in self.weight.size()])
+        )
         self.sparsity_level = sparsity_level
         if self.sparsity_level > max_sparsity_level:
             print(
                 f"reducing sparsity from {self.sparsity} to {max_sparsity}",
-                f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {blocksize})"
+                f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {blocksize})",
             )
             self.sparsity_level = max_sparsity_level
         self.blocksize = blocksize
@@ -70,15 +81,17 @@ def __init__(self, sparsity_level, blocksize, fixed_mask, fixed_weight, *args, *
         )
         nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))
 
-        # NOTE: the previous implementation of Supermask supported quantizing the weights, this has been removed. 
+        # NOTE: the previous implementation of Supermask supported quantizing the weights, this has been removed.
 
         self.weight.requires_grad = not fixed_weight
 
     def get_mask(self):
-        subnet = GetSubnet.apply(self.scores,
-            torch.zeros_like(self.scores),
-            torch.ones_like(self.scores),
-            self.sparsity_level)
+        subnet = GetSubnet.apply(
+            self.scores,
+            torch.zeros_like(self.scores),
+            torch.ones_like(self.scores),
+            self.sparsity_level,
+        )
 
         if self.blocksize != 1:
             for i, k in enumerate(self.weight.shape):
@@ -86,7 +99,6 @@ def get_mask(self):
                 subnet = torch.narrow(subnet, i, 0, k)
 
         return subnet
-
 
     def forward(self, x):
         subnet = self.get_mask()
@@ -94,14 +106,22 @@ def forward(self, x):
         return F.linear(x, w, self.bias)
 
     @classmethod
-    def from_linear(cls, linear, sparsity_level=0.0, blocksize=1, ):
+    def from_linear(
+        cls,
+        linear,
+        sparsity_level=0.0,
+        blocksize=1,
+    ):
         """
         Main entrypoint for creating a SupermaskLinear from a Linear layer.
""" assert isinstance(linear, torch.nn.Linear) supermask_linear = SupermaskLinear( - sparsity_level, blocksize, False, False, + sparsity_level, + blocksize, + False, + False, linear.in_features, linear.out_features, bias=linear.bias is not None, @@ -109,7 +129,7 @@ def from_linear(cls, linear, sparsity_level=0.0, blocksize=1, ): supermask_linear.weight.data.copy_(linear.weight.data) if linear.bias is not None: supermask_linear.bias.data.copy_(linear.bias.data) - return supermask_linear + return supermask_linear @classmethod def to_linear(cls, supermask_linear): From 16411c6d7c9191b2a39c0d59342c0f51709aca7b Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 18 Feb 2025 13:28:31 -0800 Subject: [PATCH 05/11] more ruff formatting --- torchao/sparsity/supermask.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/torchao/sparsity/supermask.py b/torchao/sparsity/supermask.py index e2204b7cae..a04b824428 100644 --- a/torchao/sparsity/supermask.py +++ b/torchao/sparsity/supermask.py @@ -1,17 +1,13 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -import torch.nn as nn import math + import torch -from torch.autograd import Variable +import torch.nn as nn import torch.nn.functional as F -import numpy as np - -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -# original supermask -scores_min = None -scores_max = 9e9 +SCORES_MIN = None +SCORES_MAX = 9e9 def percentile(t, q): @@ -25,7 +21,7 @@ class GetSubnet(torch.autograd.Function): @staticmethod def forward(ctx, scores, zeros, ones, sparsity): - clamped_scores = scores.clamp(min=scores_min, max=scores_max) + clamped_scores = scores.clamp(min=SCORES_MIN, max=SCORES_MAX) k_val = percentile(clamped_scores, sparsity * 100) return torch.where( clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device) @@ -67,7 +63,7 @@ def __init__( self.sparsity_level = sparsity_level if self.sparsity_level > max_sparsity_level: print( - f"reducing sparsity from {self.sparsity} to {max_sparsity}", + f"reducing sparsity from {self.sparsity} to {max_sparsity_level}", f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {blocksize})", ) self.sparsity_level = max_sparsity_level From 0b9e0be7aae59530f74556d5e18dcd6988915f07 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 18 Feb 2025 13:30:26 -0800 Subject: [PATCH 06/11] ruff test --- test/sparsity/test_supermask.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/test/sparsity/test_supermask.py b/test/sparsity/test_supermask.py index e246c2e169..001356af19 100644 --- a/test/sparsity/test_supermask.py +++ b/test/sparsity/test_supermask.py @@ -1,26 +1,10 @@ -import copy import logging import unittest -import math -import torch from torch import nn from torch.testing._internal import common_utils -from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout -from torchao.quantization.quant_api import ( - int4_weight_only, - int8_dynamic_activation_int8_weight, - quantize_, -) -from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ -from torchao.sparsity.utils import create_block_sparse_tensor -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, -) +from torchao.sparsity import sparsify_ logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -31,7 +15,6 @@ class TestSupermask(common_utils.TestCase): 
     @common_utils.parametrize("sparsity_level", [0.25, 0.5])
     @common_utils.parametrize("blocksize", [2, 4, 8])
     def test_supermask(self, sparsity_level, blocksize):
-        input = torch.randn((1, 16)).half().cuda()
         model = (
             nn.Sequential(
                 nn.Linear(16, 16, bias=False),

From 7a1b9eaa7aadde09154fa12325c3a5fcdb138097 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 18 Feb 2025 13:31:12 -0800
Subject: [PATCH 07/11] fix __init__ formatting

---
 torchao/sparsity/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py
index 3c36bc42ca..c13bb4209c 100644
--- a/torchao/sparsity/__init__.py
+++ b/torchao/sparsity/__init__.py
@@ -13,9 +13,9 @@
     semi_sparse_weight,
     sparsify_,
 )
+from .supermask import SupermaskLinear
 from .utils import PerChannelNormObserver  # noqa: F403
 from .wanda import WandaSparsifier  # noqa: F403
-from .supermask import SupermaskLinear
 
 __all__ = [
     "WandaSparsifier",

From ae007d379d9ea13380b46219e8b535df857319c4 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 18 Feb 2025 13:35:37 -0800
Subject: [PATCH 08/11] ruff reformat

---
 test/sparsity/test_supermask.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/test/sparsity/test_supermask.py b/test/sparsity/test_supermask.py
index 001356af19..21326e1d12 100644
--- a/test/sparsity/test_supermask.py
+++ b/test/sparsity/test_supermask.py
@@ -10,8 +10,8 @@
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
 
-class TestSupermask(common_utils.TestCase):
 
+class TestSupermask(common_utils.TestCase):
     @common_utils.parametrize("sparsity_level", [0.25, 0.5])
     @common_utils.parametrize("blocksize", [2, 4, 8])
     def test_supermask(self, sparsity_level, blocksize):
@@ -27,19 +27,27 @@ def test_supermask(self, sparsity_level, blocksize):
         from torchao.sparsity import SupermaskLinear
 
         M, N = model[0].weight.shape
-        sparsify_(model, lambda x: SupermaskLinear.from_linear(x, sparsity_level=sparsity_level, blocksize=blocksize))
+        sparsify_(
+            model,
+            lambda x: SupermaskLinear.from_linear(
+                x, sparsity_level=sparsity_level, blocksize=blocksize
+            ),
+        )
         sparsify_(model, SupermaskLinear.to_linear)
         weight_bsr = model[0].weight.to_sparse_bsr(blocksize=blocksize)
 
         # Test correct sparsity level
-        nnz = weight_bsr._nnz() 
+        nnz = weight_bsr._nnz()
         expected = round((M // blocksize) * (N // blocksize) * (1 - sparsity_level))
         assert nnz == expected, f"Expected {expected} nonzeros, got {nnz}"
 
     def test_from_linear(self):
         from torchao.sparsity import SupermaskLinear
+
         linear = nn.Linear(128, 128)
-        supermask_linear = SupermaskLinear.from_linear(linear, sparsity_level=0.5, blocksize=4)
+        supermask_linear = SupermaskLinear.from_linear(
+            linear, sparsity_level=0.5, blocksize=4
+        )
         assert supermask_linear.weight.shape == linear.weight.shape
 
 

From 190a0f0c8efb9ed5fcfe1efce9f29c0d354bfb54 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 18 Feb 2025 13:51:32 -0800
Subject: [PATCH 09/11] skip CPU tests

---
 test/sparsity/test_supermask.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/test/sparsity/test_supermask.py b/test/sparsity/test_supermask.py
index 21326e1d12..bbfcbbd747 100644
--- a/test/sparsity/test_supermask.py
+++ b/test/sparsity/test_supermask.py
@@ -12,6 +12,8 @@
 
 
 class TestSupermask(common_utils.TestCase):
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     @common_utils.parametrize("sparsity_level", [0.25, 0.5])
     @common_utils.parametrize("blocksize", [2, 4, 8])
     def test_supermask(self, sparsity_level, blocksize):
@@ -41,6 +43,7 @@ def test_supermask(self, sparsity_level, blocksize):
         expected = round((M // blocksize) * (N // blocksize) * (1 - sparsity_level))
         assert nnz == expected, f"Expected {expected} nonzeros, got {nnz}"
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_from_linear(self):
         from torchao.sparsity import SupermaskLinear
 

From 86cbb5de331d60012d3c3eb1da3c0b4fde0b6051 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 18 Feb 2025 13:57:34 -0800
Subject: [PATCH 10/11] ruff formatting

---
 test/sparsity/test_supermask.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/sparsity/test_supermask.py b/test/sparsity/test_supermask.py
index bbfcbbd747..92058a1d29 100644
--- a/test/sparsity/test_supermask.py
+++ b/test/sparsity/test_supermask.py
@@ -1,6 +1,8 @@
 import logging
 import unittest
 
+import pytest
+import torch
 from torch import nn
 from torch.testing._internal import common_utils
 

From 343f3fe3ce33499cfb059a4994f2042bab33916e Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 18 Feb 2025 13:59:31 -0800
Subject: [PATCH 11/11] update

---
 test/sparsity/test_supermask.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/sparsity/test_supermask.py b/test/sparsity/test_supermask.py
index 92058a1d29..fa86850a07 100644
--- a/test/sparsity/test_supermask.py
+++ b/test/sparsity/test_supermask.py
@@ -14,7 +14,6 @@
 
 
 class TestSupermask(common_utils.TestCase):
-
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     @common_utils.parametrize("sparsity_level", [0.25, 0.5])
     @common_utils.parametrize("blocksize", [2, 4, 8])