-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Pass user-provided NUTS kwargs to Numpyro #6021
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## main #6021 +/- ##
=======================================
Coverage 89.17% 89.17%
=======================================
Files 72 72
Lines 12905 12911 +6
=======================================
+ Hits 11508 11514 +6
Misses 1397 1397
|
@jhrcook Can you add a test? |
@twiecki I've added tests for the keyword-argument updating function and another set of tests that run a short MCMC with some different NUTS keyword arguments, but I have limited ability to actually determine if |
pymc/tests/test_sampling_jax.py
Outdated
{"adapt_step_size": False, "adapt_mass_matrix": True, "dense_mass": True}, | ||
{"adapt_step_size": False, "step_size": 0.13}, | ||
], | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's try to monkey patch? If it's not possible I think we can survive without explicit tests. This one seems too inneficient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll give it shot. I agree that the current method I used is inefficient and isn't really testing the desired behavior.
I used a monkeypatch to intercept the I confirmed that the monkeypatch was intercepting the I have limited experience with using monkeypatches in tests, so let me know if you think there would be a better construction/format. |
When using monkey patching there's a helper method, something like |
Thank you @ricardoV94 for the help. I've replaced the monkey patching system with a rather simple (too simple?) mocking system. I used mocking to allow me to peer into when the |
Looks perfect. The simpler the better! |
Thanks @jhrcook! |
What is this PR about?
There are currently defaults for a few keyword arguments for the Numpyro NUTS initialization method that prevent the PyMC user for passing certain keyword arguments. This PR refactors the code to update the user-provided keyword arguments dictionary with the defaults if they weren't included.
Addresses issue #6020
Checklist
Major / Breaking Changes
Bugfixes / New features
Docs / Maintenance
sample_numpyro_nuts
to include information for theinitvals
argumentsample_numpyro_nuts
definitionsample_numpyro_nuts