Fixes observer attachment to model based on config for wanda sparsifier #1265


Merged
merged 4 commits into from Dec 18, 2024
33 changes: 33 additions & 0 deletions test/sparsity/test_wanda.py
@@ -113,6 +113,39 @@ def test_two_layer_mlp_unstructured(self):

sparsifier.squash_mask()

def test_two_layer_mlp_unstructured_custom_config(self):
model = nn.Sequential(
nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10)
) # C_in by C_out
X1 = torch.randn(100, 128) # B1 by C_in
X2 = torch.randn(50, 128) # B2 by C_in

# Define custom config to sparsify only the first Linear layer for testing
config = [{"tensor_fqn": "0.weight"}]

sparsifier = WandaSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=config)

model(X1)
model(X2)
sparsifier.step()

cnt = 0
for m in model.modules():
if isinstance(m, nn.Linear):
cnt += 1
sparsity_level = (m.weight == 0).float().mean()
if cnt == 1: # First Linear layer should have 50% sparsity
assert (
sparsity_level == 0.5
), f"sparsity for linear layer {cnt} should be 0.5"
else: # Other layers should not be sparsified
assert (
sparsity_level != 0.5
), f"sparsity for linear layer {cnt} should not be 0.5"

sparsifier.squash_mask()


if __name__ == "__main__":
unittest.main()
26 changes: 22 additions & 4 deletions torchao/sparsity/wanda.py
@@ -3,7 +3,7 @@

import torch
from torch import nn
from torch.ao.pruning import BaseSparsifier
from torch.ao.pruning import BaseSparsifier, get_arg_info_from_tensor_fqn
from torch.ao.quantization import QConfig, default_placeholder_observer
from torch.ao.quantization.quantize import _remove_qconfig

@@ -47,9 +47,27 @@ def __init__(
def prepare(self, model: nn.Module, config: List[Dict]) -> None:
# activation: use PerChannelNormObserver
# use no-op placeholder weight observer
model.qconfig = QConfig(
activation=PerChannelNormObserver, weight=default_placeholder_observer
) # type: ignore[assignment]
if config is None:
# If no config is provided, apply the qconfig to the entire model
model.qconfig = QConfig(
activation=PerChannelNormObserver, weight=default_placeholder_observer
) # type: ignore[assignment]
else:
for module_config in config:
tensor_fqn = module_config.get("tensor_fqn", None)
if tensor_fqn is None:
raise ValueError("Each config must contain a 'tensor_fqn'.")

# Extract module information from tensor_fqn
info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
module = info_from_tensor_fqn["module"]

# Apply the qconfig directly to the module if it exists
if module is not None:
module.qconfig = QConfig(
activation=PerChannelNormObserver,
weight=default_placeholder_observer,
) # type: ignore[assignment]
torch.ao.quantization.prepare(model, inplace=True)

# call superclass prepare
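
For context, here is a minimal end-to-end sketch of the config-driven path this PR adds. It assumes `WandaSparsifier` is importable from `torchao.sparsity.wanda` (the file changed above) and simply mirrors the new test case; it is not an official recipe.

```python
import torch
from torch import nn

from torchao.sparsity.wanda import WandaSparsifier

model = nn.Sequential(nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10))

# Only the first Linear layer is listed, so prepare() attaches the
# PerChannelNormObserver / placeholder-weight qconfig to module "0" only.
config = [{"tensor_fqn": "0.weight"}]

sparsifier = WandaSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=config)

# Calibration forward passes feed the per-channel activation-norm observers.
model(torch.randn(100, 128))

sparsifier.step()         # compute Wanda masks for the configured layer
sparsifier.squash_mask()  # fold masks into the weights and remove observers
```
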