Support different reg_reduction in Captum STG (#1090)
Summary:
Pull Request resolved: #1090
Add a new `str` argument `reg_reduction` to the Captum STG classes, which specifies how the returned regularization is reduced. Following the design of PyTorch's loss functions, it supports 3 modes: `sum`, `mean`, and `none`. The default is `sum`.
(Other modes, such as `weighted_sum`, may be needed in the future. With a customized `mask`, each gate may handle a different number of elements, and an application may want to use as few elements as possible rather than as few gates as possible. For now, such use cases can use the `none` option and do the reduction themselves.)
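As a rough illustration of the three modes, here is a minimal pure-Python sketch. The helper `reduce_reg`, the per-gate values, and `elements_per_gate` are hypothetical and are not the Captum implementation; they only mirror the semantics described above, including how a caller could build a `weighted_sum`-style reduction on top of `none`.

```python
from typing import List, Union


def reduce_reg(
    gate_reg: List[float], reg_reduction: str = "sum"
) -> Union[float, List[float]]:
    """Reduce per-gate regularization values (hypothetical helper)."""
    if reg_reduction == "sum":
        return sum(gate_reg)
    if reg_reduction == "mean":
        return sum(gate_reg) / len(gate_reg)
    if reg_reduction == "none":
        # Return the per-gate values unreduced so the caller can reduce them.
        return list(gate_reg)
    raise ValueError(f"unsupported reg_reduction: {reg_reduction!r}")


# Assumed per-gate regularization values, for illustration only.
gate_reg = [0.2, 0.4, 0.6]

# Assumed number of input elements each gate controls under a custom mask.
elements_per_gate = [1, 4, 5]

# A caller-side weighted sum built from the "none" output.
per_gate = reduce_reg(gate_reg, "none")
weighted = sum(r * n for r, n in zip(per_gate, elements_per_gate))
```

The weighting by element count is one possible choice; an application could plug in any other reduction over the unreduced values.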
Although we previously used `mean`, we decided to change the default to `sum` for 3 reasons:
1. The original paper "LEARNING SPARSE NEURAL NETWORKS THROUGH L0 REGULARIZATION" used `sum` both in its writing and its [implementation](https://github.com/AMLab-Amsterdam/L0_regularization/blob/master/l0_layers.py#L70)
2. L^1 and L^2 regularization also `sum` over all parameters without averaging over the total number of parameters in the model. See [PyTorch's implementation](https://github.com/pytorch/pytorch/blob/df569367ef444dc9831ef0dde3bc611bcabcfbf9/torch/optim/adagrad.py#L268)
3. When there are multiple STGs of imbalanced lengths, the results are comparable under `sum` but not under `mean`. If the model has 2 STGs, one with 100 gates and the other with a single gate, `mean` divides the regularization of each gate in the 1st STG by 100, which makes the regularization on each gate of the 1st STG 100 times weaker than on the gate of the 2nd STG. Users usually do not expect this.
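The imbalance in reason 3 can be checked with a few lines of arithmetic. This is only a sketch: `per_gate_scale` is a hypothetical helper that returns the factor applied to each gate's regularization term inside one STG's reduced output, not a Captum function.

```python
def per_gate_scale(n_gates: int, reg_reduction: str) -> float:
    """Factor multiplying each gate's term in one STG's reduced regularization."""
    if reg_reduction == "sum":
        return 1.0
    if reg_reduction == "mean":
        return 1.0 / n_gates
    raise ValueError(f"unsupported reg_reduction: {reg_reduction!r}")


large_stg_gates = 100  # the 1st STG in the example above
small_stg_gates = 1    # the 2nd STG in the example above

# How much more strongly a gate in the small STG is regularized
# than a gate in the large STG, under each reduction.
ratio_mean = per_gate_scale(small_stg_gates, "mean") / per_gate_scale(
    large_stg_gates, "mean"
)
ratio_sum = per_gate_scale(small_stg_gates, "sum") / per_gate_scale(
    large_stg_gates, "sum"
)
```

Under `mean` the ratio is about 100, while under `sum` every gate carries the same weight regardless of which STG it belongs to.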
Using `mean` or `sum` will not impact performance when there is only one BSN layer, because users can tune `reg_weight` to counter the difference. The authors of "Feature selection using Stochastic Gates" mixed `sum` and `mean` in [their implementation](https://github.com/runopti/stg/blob/master/python/stg/models.py#L164-L195)
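To see why a single layer is unaffected, note that `mean` differs from `sum` only by a constant factor equal to the number of gates, which `reg_weight` can absorb. A small sketch with made-up per-gate values and weights (all numbers here are assumptions for illustration):

```python
# Assumed per-gate regularization values of a single layer.
gate_reg = [0.1, 0.5, 0.9, 0.3]
n_gates = len(gate_reg)

# A weight tuned for "sum", and the rescaled weight that matches it under "mean".
reg_weight_sum = 0.01
reg_weight_mean = reg_weight_sum * n_gates

penalty_sum = reg_weight_sum * sum(gate_reg)
penalty_mean = reg_weight_mean * (sum(gate_reg) / n_gates)
```

Both penalties agree up to floating-point error, so the two modes describe the same objective once `reg_weight` is rescaled. With several layers of different sizes, no single rescaling works for all of them.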
For backward compatibility, `reg_reduction = "mean"` is now explicitly specified for all existing usages in Pyper and MVAI.
Reviewed By: cyrjano, edqwerty10
Differential Revision: D41991741
fbshipit-source-id: 698db938fc373747db0df1b1145c6e9943476142