
Commit ada8c0d

aobo-y authored and facebook-github-bot committed
move STG to captum module (#1064)
Summary: Pull Request resolved: #1064 move STGBase, BinaryConcreteSTG and GaussianSTG to path `captum/module` Reviewed By: vivekmig Differential Revision: D41000738 fbshipit-source-id: 0892004ba7c9f6f817c80a497fc27b30a0c59e3a
1 parent a7610be commit ada8c0d

File tree

6 files changed, +1292 -0 lines changed


captum/module/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from captum.module.binary_concrete_stochastic_gates import ( # noqa
    BinaryConcreteStochasticGates,
)
from captum.module.gaussian_stochastic_gates import GaussianStochasticGates # noqa
from captum.module.stochastic_gates_base import StochasticGatesBase # noqa
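After this change, the gates are importable directly from `captum.module`. A quick smoke test of the new path (assuming captum is installed at this revision):

from captum.module import (
    BinaryConcreteStochasticGates,
    GaussianStochasticGates,
    StochasticGatesBase,
)

# both gate flavors added in this commit share the StochasticGatesBase interface
assert issubclass(BinaryConcreteStochasticGates, StochasticGatesBase)
assert issubclass(GaussianStochasticGates, StochasticGatesBase)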
captum/module/binary_concrete_stochastic_gates.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
#!/usr/bin/env python3

import math
from typing import Optional

import torch
from captum.module.stochastic_gates_base import StochasticGatesBase
from torch import nn, Tensor


def _torch_empty(batch_size: int, n_gates: int, device: torch.device) -> Tensor:
    return torch.empty(batch_size, n_gates, device=device)


# torch.fx is introduced in 1.8.0
if hasattr(torch, "fx"):
    torch.fx.wrap(_torch_empty)


def _logit(inp):
    # torch.logit is introduced in 1.7.0
    if hasattr(torch, "logit"):
        return torch.logit(inp)
    else:
        return torch.log(inp) - torch.log(1 - inp)


class BinaryConcreteStochasticGates(StochasticGatesBase):
    """
    Stochastic Gates with binary concrete distribution.

    Stochastic Gates is a practical solution to add L0 norm regularization to neural
    networks. L0 regularization, which explicitly penalizes any present (non-zero)
    parameters, can help with network pruning and feature selection, but directly
    optimizing L0 is a non-differentiable combinatorial problem. As a surrogate for
    L0, Stochastic Gates use certain continuous probability distributions (e.g.,
    Concrete, Gaussian) with hard-sigmoid rectification as a continuous smoothed
    Bernoulli distribution determining the weight of a parameter, i.e., a gate.
    Then L0 is equal to the gates' non-zero probability, represented by the
    parameters of the continuous probability distribution. The gate value can also
    be reparameterized from the distribution parameters and a noise variable, so
    the expected L0 can be optimized by learning the distribution parameters via
    stochastic gradients.

    BinaryConcreteStochasticGates adopts a "stretched" binary concrete distribution
    as the smoothed Bernoulli distribution of the gate. The binary concrete
    distribution does not include its lower and upper boundaries, 0 and 1, which are
    required by a Bernoulli distribution, so it needs to be linearly stretched
    beyond both boundaries. Hard-sigmoid rectification is then used to "fold" the
    parts smaller than 0 or larger than 1 back to 0 and 1.

    More details can be found in the
    `original paper <https://arxiv.org/abs/1712.01312>`.
    """

    def __init__(
        self,
        n_gates: int,
        mask: Optional[Tensor] = None,
        reg_weight: float = 1.0,
        temperature: float = 2.0 / 3,
        lower_bound: float = -0.1,
        upper_bound: float = 1.1,
        eps: float = 1e-8,
    ):
        """
        Args:
            n_gates (int): number of gates.

            mask (Optional[Tensor]): If provided, this allows grouping multiple
                input tensor elements to share the same stochastic gate.
                This tensor should be broadcastable to match the input shape
                and contain integers in the range 0 to n_gates - 1.
                Indices grouped to the same stochastic gate should have the same
                value. If not provided, each element in the input tensor
                (on dimensions other than dim 0, the batch dim) is gated separately.
                Default: None

            reg_weight (float): rescaling weight for the L0 regularization term.
                Default: 1.0

            temperature (float): temperature of the concrete distribution; controls
                the degree of approximation, as 0 means the original Bernoulli
                without relaxation. The value should be between 0 and 1.
                Default: 2/3

            lower_bound (float): the lower bound to "stretch" the binary concrete
                distribution.
                Default: -0.1

            upper_bound (float): the upper bound to "stretch" the binary concrete
                distribution.
                Default: 1.1

            eps (float): term to improve numerical stability in binary concrete
                sampling.
                Default: 1e-8
        """
        super().__init__(n_gates, mask=mask, reg_weight=reg_weight)

        # avoid changing the tensor's variable name:
        # when the module is used after compilation,
        # users may directly access this tensor by name
        log_alpha_param = torch.empty(n_gates)
        nn.init.normal_(log_alpha_param, mean=0.0, std=0.01)
        self.log_alpha_param = nn.Parameter(log_alpha_param)

        assert (
            0 < temperature < 1
        ), f"the temperature should be between 0 and 1, received {temperature}"
        self.temperature = temperature

        assert (
            lower_bound < 0
        ), f"the stretch lower bound should be smaller than 0, received {lower_bound}"
        self.lower_bound = lower_bound
        assert (
            upper_bound > 1
        ), f"the stretch upper bound should be larger than 1, received {upper_bound}"
        self.upper_bound = upper_bound

        self.eps = eps

        # pre-calculate the fixed term used in the active prob
        self.active_prob_offset = temperature * math.log(-lower_bound / upper_bound)

    def forward(self, *args, **kwargs):
        """
        Args:
            input_tensor (Tensor): Tensor to be gated with stochastic gates

        Outputs:
            gated_input (Tensor): Tensor of the same shape weighted by the sampled
                gate values

            l0_reg (Tensor): L0 regularization term to be optimized together with
                the model loss, e.g. loss(model_out, target) + l0_reg
        """
        return super().forward(*args, **kwargs)

    def _sample_gate_values(self, batch_size: int) -> Tensor:
        """
        Sample gate values for each example in the batch from the binary concrete
        distributions.

        Args:
            batch_size (int): input batch size

        Returns:
            gate_values (Tensor): gate value tensor of shape (batch_size, n_gates)
        """
        if self.training:
            u = _torch_empty(
                batch_size, self.n_gates, device=self.log_alpha_param.device
            )
            u.uniform_(self.eps, 1 - self.eps)
            s = torch.sigmoid((_logit(u) + self.log_alpha_param) / self.temperature)

        else:
            s = torch.sigmoid(self.log_alpha_param)
            s = s.expand(batch_size, self.n_gates)

        s_bar = s * (self.upper_bound - self.lower_bound) + self.lower_bound

        return s_bar

    def _get_gate_values(self) -> Tensor:
        """
        Get the gate values derived from the learned log_alpha_param after the
        model is trained.

        Returns:
            gate_values (Tensor): value of each gate after the model is trained
        """
        gate_values = (
            torch.sigmoid(self.log_alpha_param) * (self.upper_bound - self.lower_bound)
            + self.lower_bound
        )
        return torch.clamp(gate_values, min=0, max=1)

    def _get_gate_active_probs(self) -> Tensor:
        """
        Get the active probability of each gate, i.e., P(gate value > 0), in the
        binary concrete distributions.

        Returns:
            probs (Tensor): tensor of probabilities that the gates are active,
                of shape (n_gates)
        """
        return torch.sigmoid(self.log_alpha_param - self.active_prob_offset)
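For reference, a minimal usage sketch of the module added above. The feature count, `reg_weight`, and the commented-out training objective are illustrative, not part of the diff; the forward contract (gated tensor plus L0 penalty) follows the docstring:

import torch

from captum.module import BinaryConcreteStochasticGates

n_features = 16  # hypothetical feature count, one gate per feature
stg = BinaryConcreteStochasticGates(n_gates=n_features, reg_weight=0.1)

x = torch.randn(4, n_features)  # batch of 4 examples
gated_x, l0_reg = stg(x)  # gated tensor (same shape as x) and the L0 penalty
assert gated_x.shape == x.shape

# during training, fold the penalty into the task objective, e.g.:
# loss = criterion(model(gated_x), target) + l0_reg

# in eval mode, sampling is deterministic: sigmoid(log_alpha_param),
# stretched to [lower_bound, upper_bound], with no concrete noise
stg.eval()
gated_x_eval, _ = stg(x)

Note that `_get_gate_active_probs` above uses the pre-computed `active_prob_offset`, i.e. P(gate > 0) = sigmoid(log_alpha - temperature * log(-lower_bound / upper_bound)), which is the closed form for the stretched binary concrete distribution.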
captum/module/gaussian_stochastic_gates.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
#!/usr/bin/env python3

import math
from typing import Optional

import torch
from captum.module.stochastic_gates_base import StochasticGatesBase
from torch import nn, Tensor


class GaussianStochasticGates(StochasticGatesBase):
    """
    Stochastic Gates with Gaussian distribution.

    Stochastic Gates is a practical solution to add L0 norm regularization to neural
    networks. L0 regularization, which explicitly penalizes any present (non-zero)
    parameters, can help with network pruning and feature selection, but directly
    optimizing L0 is a non-differentiable combinatorial problem. As a surrogate for
    L0, Stochastic Gates use certain continuous probability distributions (e.g.,
    Concrete, Gaussian) with hard-sigmoid rectification as a continuous smoothed
    Bernoulli distribution determining the weight of a parameter, i.e., a gate.
    Then L0 is equal to the gates' non-zero probability, represented by the
    parameters of the continuous probability distribution. The gate value can also
    be reparameterized from the distribution parameters and a noise variable, so
    the expected L0 can be optimized by learning the distribution parameters via
    stochastic gradients.

    GaussianStochasticGates adopts a Gaussian distribution as the smoothed Bernoulli
    distribution of the gate. While the smoothed Bernoulli distribution should be
    within 0 and 1, a Gaussian does not have boundaries, so hard-sigmoid
    rectification is used to "fold" the parts smaller than 0 or larger than 1 back
    to 0 and 1.

    More details can be found in the
    `original paper <https://arxiv.org/abs/1810.04247>`.
    """

    def __init__(
        self,
        n_gates: int,
        mask: Optional[Tensor] = None,
        reg_weight: Optional[float] = 1.0,
        std: Optional[float] = 0.5,
    ):
        """
        Args:
            n_gates (int): number of gates.

            mask (Optional[Tensor]): If provided, this allows grouping multiple
                input tensor elements to share the same stochastic gate.
                This tensor should be broadcastable to match the input shape
                and contain integers in the range 0 to n_gates - 1.
                Indices grouped to the same stochastic gate should have the same
                value. If not provided, each element in the input tensor
                (on dimensions other than dim 0, the batch dim) is gated separately.
                Default: None

            reg_weight (Optional[float]): rescaling weight for the L0 regularization
                term.
                Default: 1.0

            std (Optional[float]): standard deviation of the gate noise, fixed
                throughout training.
                Default: 0.5 (per the original paper)
        """
        super().__init__(n_gates, mask=mask, reg_weight=reg_weight)

        mu = torch.empty(n_gates)
        nn.init.normal_(mu, mean=0.5, std=0.01)
        self.mu = nn.Parameter(mu)

        assert 0 < std, f"the standard deviation should be positive, received {std}"
        self.std = std

    def forward(self, *args, **kwargs):
        """
        Args:
            input_tensor (Tensor): Tensor to be gated with stochastic gates

        Outputs:
            gated_input (Tensor): Tensor of the same shape weighted by the sampled
                gate values

            l0_reg (Tensor): L0 regularization term to be optimized together with
                the model loss, e.g. loss(model_out, target) + l0_reg
        """
        return super().forward(*args, **kwargs)

    def _sample_gate_values(self, batch_size: int) -> Tensor:
        """
        Sample gate values for each example in the batch from the Gaussian
        distribution.

        Args:
            batch_size (int): input batch size

        Returns:
            gate_values (Tensor): gate value tensor of shape (batch_size, n_gates)
        """
        if self.training:
            n = torch.empty(batch_size, self.n_gates, device=self.mu.device)
            n.normal_(mean=0, std=self.std)
            return self.mu + n

        return self.mu.expand(batch_size, self.n_gates)

    def _get_gate_values(self) -> Tensor:
        """
        Get the gate values derived from the learned mu after the model is trained.

        Returns:
            gate_values (Tensor): value of each gate after the model is trained
        """
        return torch.clamp(self.mu, min=0, max=1)

    def _get_gate_active_probs(self) -> Tensor:
        """
        Get the active probability of each gate, i.e., P(gate value > 0), in the
        Gaussian distribution.

        Returns:
            probs (Tensor): tensor of probabilities that the gates are active,
                of shape (n_gates)
        """
        x = self.mu / self.std
        return 0.5 * (1 + torch.erf(x / math.sqrt(2)))
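Similarly, a sketch for the Gaussian variant, this time using `mask` to share one gate across several input elements. The grouping below is hypothetical and follows the mask contract in the docstring (broadcastable to the input shape, grouped indices share a gate index):

import torch

from captum.module import GaussianStochasticGates

# six input elements share three gates: elements {0,1} -> gate 0,
# {2,3} -> gate 1, {4,5} -> gate 2
mask = torch.tensor([0, 0, 1, 1, 2, 2])
stg = GaussianStochasticGates(n_gates=3, mask=mask, std=0.5)

x = torch.randn(8, 6)  # batch of 8, six elements each
gated_x, l0_reg = stg(x)  # elements in the same group receive the same gate value
assert gated_x.shape == x.shape

# the active probability implemented in _get_gate_active_probs is the
# Gaussian CDF evaluated at 0: P(gate > 0) = 0.5 * (1 + erf(mu / (std * sqrt(2))))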
