Skip to content

Commit 02446fa

Browse files
committed
added test
1 parent a4ac90b commit 02446fa

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

test/sparsity/test_fast_sparse_training.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,48 @@ def test_runtime_weight_sparsification(self):
6868
for name, mod in model_c.named_modules():
6969
assert not isinstance(mod, SemiSparseLinear)
7070

71+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "pytorch 2.4+ feature")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_runtime_weight_sparsification_compile(self):
    """Check SemiSparseLinear against a dense reference under torch.compile.

    A reference model has every ``torch.nn.Linear`` weight replaced by the
    dense form of its ``prune_dense_static_sort`` result. A deep copy of the
    original model has ``linear1``/``linear2`` swapped for
    ``SemiSparseLinear``. Both are compiled with ``fullgraph=True``; their
    forward outputs and weight gradients must agree within loose fp16
    tolerances. Finally the swap is reverted and no ``SemiSparseLinear``
    may remain.
    """
    # need this import inside to not break 2.2 tests
    from torch.sparse import SparseSemiStructuredTensorCUSPARSELT

    tolerances = {"rtol": 1e-1, "atol": 1e-1}

    batch = torch.rand((128, 128)).half().cuda()
    upstream_grad = torch.rand((128, 128)).half().cuda()
    reference = TestModel().half().cuda()
    # Copy before pruning so the swapped model starts from the unpruned weights.
    swapped = copy.deepcopy(reference)

    # Bake the pruned weights into the dense reference model.
    for layer in reference.modules():
        if not isinstance(layer, torch.nn.Linear):
            continue
        pruned = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(
            layer.weight.detach()
        )
        layer.weight = torch.nn.Parameter(pruned.to_dense())

    reference = torch.compile(reference, fullgraph=True)
    dense_result = reference(batch)

    # map from fqn to replacement linear module
    sparse_config = dict.fromkeys(("linear1", "linear2"), SemiSparseLinear)

    swap_linear_with_semi_sparse_linear(swapped, sparse_config)
    swapped = torch.compile(swapped, fullgraph=True)
    sparse_result = swapped(batch)

    assert torch.allclose(dense_result, sparse_result, **tolerances)

    dense_result.backward(upstream_grad)
    sparse_result.backward(upstream_grad)

    # check grad
    for layer_name in ("linear1", "linear2"):
        dense_grad = getattr(reference, layer_name).weight.grad
        sparse_grad = getattr(swapped, layer_name).weight.grad
        assert torch.allclose(dense_grad, sparse_grad, **tolerances)

    # check that swap back works
    swap_semi_sparse_linear_with_linear(swapped)
    leftovers = [
        mod for mod in swapped.modules() if isinstance(mod, SemiSparseLinear)
    ]
    assert not leftovers
112+
71113

72114
if __name__ == "__main__":
73115
unittest.main()

torchao/sparsity/training/README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Accelerated Sparse Training
22

3-
This folder contains an implementation of accelerated sparse training, utilizing the runtime semi-structured (2:4) sparsification [kernels]() present in core.
3+
This folder contains an implementation of accelerated sparse training.
4+
<!--For more information about our API and how it works, please see our blog post. (Will add link when it's public)-->
5+
6+
Special thanks to @danthe3rd for writing the runtime semi-structured (2:4) sparsification [kernels](https://github.com/pytorch/pytorch/pull/122350) in core.
47

58
### Quickstart
69

@@ -55,5 +58,3 @@ For VIT-L MLP shapes on a NVIDIA A100 we see the following results:
5558
5659
Times are in microseconds (us).
5760
```
58-
59-
For more information about our API and how it works, please see our blog post.

0 commit comments

Comments
 (0)