Skip to content

Commit dcb87d3

Browse files
aobo-y authored and facebook-github-bot committed
Support different reg_reduction in Captum STG (#1090)
Summary: Pull Request resolved: #1090 Add a new `str` argument `reg_reduction` in Captum STG classes, which specifies how the returned regularization should be reduced. Following Pytorch Loss's design, support 3 modes: `sum`, `mean`, and `none`. The default is `sum`. (There may be needs for other modes in future, like `weighted_sum`. With customized `mask`, each gate may handle different number of elements. The application may want to use as few elements as possible instead of as few gates. For now, such use cases can use `none` option and reduce themselves) Although we previously used `mean`, we decided to change to `sum` as default for 3 reasons: 1. The original paper "LEARNING SPARSE NEURAL NETWORKS THROUGH L0 REGULARIZATION" used `sum` both in its writing and its [implementation](https://github.com/AMLab-Amsterdam/L0_regularization/blob/master/l0_layers.py#L70) {F822978249} 2. L^1 and L^2 regularization also `sum` over each parameter without averaging over total number of parameters within a model. See [Pytorch's implementation](https://github.com/pytorch/pytorch/blob/df569367ef444dc9831ef0dde3bc611bcabcfbf9/torch/optim/adagrad.py#L268) 3. When there are multiple STG of imbalanced lengths, the results are comparable in `sum` but not `mean`. If the model has 2 STG, where one has 100 gates and the other has one single gate, the regularization of each gate in the 1st STG will be divided by 100 in `mean`, which makes the 1st STG 100 times weaker than the 2nd STG. This is usually unexpected for users. Using `mean` or `sum` will not impact the performance when there is only one BSN layer, coz people can tune `reg_weight` to counter the difference. 
The authors of "Feature selection using Stochastic Gates" mixed using `sum` and `mean` in [their implementation](https://github.com/runopti/stg/blob/master/python/stg/models.py#L164-L195) For backward compatibility, explicitly specified `reg_reduction = "mean"` for all existing usages in Pyper and MVAI. Reviewed By: cyrjano, edqwerty10 Differential Revision: D41991741 fbshipit-source-id: 698db938fc373747db0df1b1145c6e9943476142
1 parent e7b58af commit dcb87d3

File tree

5 files changed

+107
-12
lines changed

5 files changed

+107
-12
lines changed

captum/module/binary_concrete_stochastic_gates.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
lower_bound: float = -0.1,
6161
upper_bound: float = 1.1,
6262
eps: float = 1e-8,
63+
reg_reduction: str = "sum",
6364
):
6465
"""
6566
Args:
@@ -93,8 +94,18 @@ def __init__(
9394
eps (float): term to improve numerical stability in binary concerete
9495
sampling
9596
Default: 1e-8
97+
98+
reg_reduction (str, optional): the reduction to apply to
99+
the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
100+
applied and it will be the same as the return of get_active_probs,
101+
'mean': the sum of the gates non-zero probabilities will be divided by
102+
the number of gates, 'sum': the gates non-zero probabilities will
103+
be summed.
104+
Default: 'sum'
96105
"""
97-
super().__init__(n_gates, mask=mask, reg_weight=reg_weight)
106+
super().__init__(
107+
n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
108+
)
98109

99110
# avoid changing the tensor's variable name
100111
# when the module is used after compilation,

captum/module/gaussian_stochastic_gates.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
mask: Optional[Tensor] = None,
3939
reg_weight: Optional[float] = 1.0,
4040
std: Optional[float] = 0.5,
41+
reg_reduction: str = "sum",
4142
):
4243
"""
4344
Args:
@@ -58,8 +59,17 @@ def __init__(
5859
std (Optional[float]): standard deviation that will be fixed throughout.
5960
Default: 0.5 (by paper reference)
6061
62+
reg_reduction (str, optional): the reduction to apply to
63+
the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
64+
applied and it will be the same as the return of get_active_probs,
65+
'mean': the sum of the gates non-zero probabilities will be divided by
66+
the number of gates, 'sum': the gates non-zero probabilities will
67+
be summed.
68+
Default: 'sum'
6169
"""
62-
super().__init__(n_gates, mask=mask, reg_weight=reg_weight)
70+
super().__init__(
71+
n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
72+
)
6373

6474
mu = torch.empty(n_gates)
6575
nn.init.normal_(mu, mean=0.5, std=0.01)

captum/module/stochastic_gates_base.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ class StochasticGatesBase(Module, ABC):
2929
"""
3030

3131
def __init__(
32-
self, n_gates: int, mask: Optional[Tensor] = None, reg_weight: float = 1.0
32+
self,
33+
n_gates: int,
34+
mask: Optional[Tensor] = None,
35+
reg_weight: float = 1.0,
36+
reg_reduction: str = "sum",
3337
):
3438
"""
3539
Args:
@@ -46,6 +50,14 @@ def __init__(
4650
4751
reg_weight (Optional[float]): rescaling weight for L0 regularization term.
4852
Default: 1.0
53+
54+
reg_reduction (str, optional): the reduction to apply to
55+
the regularization: 'none'|'mean'|'sum'. 'none': no reduction will be
56+
applied and it will be the same as the return of get_active_probs,
57+
'mean': the sum of the gates non-zero probabilities will be divided by
58+
the number of gates, 'sum': the gates non-zero probabilities will
59+
be summed.
60+
Default: 'sum'
4961
"""
5062
super().__init__()
5163

@@ -57,6 +69,12 @@ def __init__(
5769
" should correspond to a gate"
5870
)
5971

72+
valid_reg_reduction = ["none", "mean", "sum"]
73+
assert (
74+
reg_reduction in valid_reg_reduction
75+
), f"reg_reduction must be one of [none, mean, sum], received: {reg_reduction}"
76+
self.reg_reduction = reg_reduction
77+
6078
self.n_gates = n_gates
6179
self.register_buffer(
6280
"mask", mask.detach().clone() if mask is not None else None
@@ -106,7 +124,14 @@ def forward(self, input_tensor: Tensor) -> Tuple[Tensor, Tensor]:
106124
gated_input = input_tensor * gate_values
107125

108126
prob_density = self._get_gate_active_probs()
109-
l0_reg = self.reg_weight * prob_density.mean()
127+
if self.reg_reduction == "sum":
128+
l0_reg = prob_density.sum()
129+
elif self.reg_reduction == "mean":
130+
l0_reg = prob_density.mean()
131+
else:
132+
l0_reg = prob_density
133+
134+
l0_reg *= self.reg_weight
110135

111136
return gated_input, l0_reg
112137

tests/module/test_binary_concrete_stochastic_gates.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_bcstg_1d_input(self) -> None:
3232
).to(self.testing_device)
3333

3434
gated_input, reg = bcstg(input_tensor)
35-
expected_reg = 0.8316
35+
expected_reg = 2.4947
3636

3737
if self.testing_device == "cpu":
3838
expected_gated_input = [[0.0000, 0.0212, 0.1892], [0.1839, 0.3753, 0.4937]]
@@ -42,6 +42,30 @@ def test_bcstg_1d_input(self) -> None:
4242
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
4343
assertTensorAlmostEqual(self, reg, expected_reg)
4444

45+
def test_bcstg_1d_input_with_reg_reduction(self) -> None:
46+
47+
dim = 3
48+
mean_bcstg = BinaryConcreteStochasticGates(dim, reg_reduction="mean").to(
49+
self.testing_device
50+
)
51+
none_bcstg = BinaryConcreteStochasticGates(dim, reg_reduction="none").to(
52+
self.testing_device
53+
)
54+
input_tensor = torch.tensor(
55+
[
56+
[0.0, 0.1, 0.2],
57+
[0.3, 0.4, 0.5],
58+
]
59+
).to(self.testing_device)
60+
61+
mean_gated_input, mean_reg = mean_bcstg(input_tensor)
62+
none_gated_input, none_reg = none_bcstg(input_tensor)
63+
expected_mean_reg = 0.8316
64+
expected_none_reg = torch.tensor([0.8321, 0.8310, 0.8325])
65+
66+
assertTensorAlmostEqual(self, mean_reg, expected_mean_reg)
67+
assertTensorAlmostEqual(self, none_reg, expected_none_reg)
68+
4569
def test_bcstg_1d_input_with_n_gates_error(self) -> None:
4670

4771
dim = 3
@@ -85,7 +109,7 @@ def test_bcstg_1d_input_with_mask(self) -> None:
85109
).to(self.testing_device)
86110

87111
gated_input, reg = bcstg(input_tensor)
88-
expected_reg = 0.8321
112+
expected_reg = 1.6643
89113

90114
if self.testing_device == "cpu":
91115
expected_gated_input = [[0.0000, 0.0000, 0.1679], [0.0000, 0.0000, 0.2223]]
@@ -118,7 +142,7 @@ def test_bcstg_2d_input(self) -> None:
118142

119143
gated_input, reg = bcstg(input_tensor)
120144

121-
expected_reg = 0.8317
145+
expected_reg = 4.9903
122146
if self.testing_device == "cpu":
123147
expected_gated_input = [
124148
[[0.0000, 0.0990], [0.0261, 0.2431], [0.0551, 0.3863]],
@@ -179,7 +203,7 @@ def test_bcstg_2d_input_with_mask(self) -> None:
179203
).to(self.testing_device)
180204

181205
gated_input, reg = bcstg(input_tensor)
182-
expected_reg = 0.8316
206+
expected_reg = 2.4947
183207

184208
if self.testing_device == "cpu":
185209
expected_gated_input = [

tests/module/test_gaussian_stochastic_gates.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def test_gstg_1d_input(self) -> None:
2525

2626
dim = 3
2727
gstg = GaussianStochasticGates(dim).to(self.testing_device)
28+
2829
input_tensor = torch.tensor(
2930
[
3031
[0.0, 0.1, 0.2],
@@ -33,7 +34,7 @@ def test_gstg_1d_input(self) -> None:
3334
).to(self.testing_device)
3435

3536
gated_input, reg = gstg(input_tensor)
36-
expected_reg = 0.8404
37+
expected_reg = 2.5213
3738

3839
if self.testing_device == "cpu":
3940
expected_gated_input = [[0.0000, 0.0198, 0.1483], [0.1848, 0.3402, 0.1782]]
@@ -43,6 +44,30 @@ def test_gstg_1d_input(self) -> None:
4344
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
4445
assertTensorAlmostEqual(self, reg, expected_reg)
4546

47+
def test_gstg_1d_input_with_reg_reduction(self) -> None:
48+
dim = 3
49+
mean_gstg = GaussianStochasticGates(dim, reg_reduction="mean").to(
50+
self.testing_device
51+
)
52+
none_gstg = GaussianStochasticGates(dim, reg_reduction="none").to(
53+
self.testing_device
54+
)
55+
56+
input_tensor = torch.tensor(
57+
[
58+
[0.0, 0.1, 0.2],
59+
[0.3, 0.4, 0.5],
60+
]
61+
).to(self.testing_device)
62+
63+
_, mean_reg = mean_gstg(input_tensor)
64+
_, none_reg = none_gstg(input_tensor)
65+
expected_mean_reg = 0.8404
66+
expected_none_reg = torch.tensor([0.8424, 0.8384, 0.8438])
67+
68+
assertTensorAlmostEqual(self, mean_reg, expected_mean_reg)
69+
assertTensorAlmostEqual(self, none_reg, expected_none_reg)
70+
4671
def test_gstg_1d_input_with_n_gates_error(self) -> None:
4772

4873
dim = 3
@@ -65,7 +90,7 @@ def test_gstg_1d_input_with_mask(self) -> None:
6590
).to(self.testing_device)
6691

6792
gated_input, reg = gstg(input_tensor)
68-
expected_reg = 0.8424
93+
expected_reg = 1.6849
6994

7095
if self.testing_device == "cpu":
7196
expected_gated_input = [[0.0000, 0.0000, 0.1225], [0.0583, 0.0777, 0.3779]]
@@ -111,7 +136,7 @@ def test_gstg_2d_input(self) -> None:
111136
).to(self.testing_device)
112137

113138
gated_input, reg = gstg(input_tensor)
114-
expected_reg = 0.8410
139+
expected_reg = 5.0458
115140

116141
if self.testing_device == "cpu":
117142
expected_gated_input = [
@@ -173,7 +198,7 @@ def test_gstg_2d_input_with_mask(self) -> None:
173198
).to(self.testing_device)
174199

175200
gated_input, reg = gstg(input_tensor)
176-
expected_reg = 0.8404
201+
expected_reg = 2.5213
177202

178203
if self.testing_device == "cpu":
179204
expected_gated_input = [

0 commit comments

Comments (0)