Add start_sigma
to ADVI 2
#6132
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Recreation of #6096 after failing test (fixed in latest commit).
Introduces the
start_sigma
argument which allows to set a starting value for the sigmas of mean field approximation inADVI
. I am using theordering
property to create the mapping between the variables and the flat 1d array required byADVI
.See also:
https://discourse.pymc.io/t/quality-of-life-improvements-to-advi/10254
Checklist
Major / Breaking Changes
Bugfixes / New features
start_sigma
being given tofit
.ForASVGD
forward the key word argumentsstart
andrandom_seed
to the constructor of the default approximationFullRank
instead of sending them tosuper().__init__
which can't process them_iterate_with_loss
to run withn=0
by giving an appropriate logger message. This allows to check initialization values, which is useful for testing, but could also be useful for debugging a user modelDocs / Maintenance
start
argument was outdated, it claimed typePoint
, but it gets passed through to a function that requiresStartDict
. Since users are probably unfamiliar with this custom type (Dict[Union[Variable, str], Union[np.ndarray, Variable, str]]
), I decided to also mention the most relevant and most commonly used subtypedict[str, np.ndarray]
.