Skip to content

Conversation

jhrcook
Copy link
Contributor

@jhrcook jhrcook commented Jul 31, 2022

What is this PR about?

There are currently defaults for a few keyword arguments for the Numpyro NUTS initialization method that prevent the PyMC user for passing certain keyword arguments. This PR refactors the code to update the user-provided keyword arguments dictionary with the defaults if they weren't included.

Addresses issue #6020

Checklist

Major / Breaking Changes

Bugfixes / New features

  • Enable the user to pass any keyword arguments to the Numpyro NUTS initialization.

Docs / Maintenance

  • Updated docstring of sample_numpyro_nuts to include information for the initvals argument
  • fixed some incorrect typehints of arguments in the sample_numpyro_nuts definition
  • added a return typehint to sample_numpyro_nuts

@codecov
Copy link

codecov bot commented Jul 31, 2022

Codecov Report

Merging #6021 (66dd22a) into main (18bbcbb) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #6021   +/-   ##
=======================================
  Coverage   89.17%   89.17%           
=======================================
  Files          72       72           
  Lines       12905    12911    +6     
=======================================
+ Hits        11508    11514    +6     
  Misses       1397     1397           
Impacted Files Coverage Δ
pymc/sampling_jax.py 97.05% <100.00%> (+0.08%) ⬆️

@cluhmann cluhmann requested a review from twiecki August 1, 2022 16:47
@twiecki
Copy link
Member

twiecki commented Aug 1, 2022

@jhrcook Can you add a test?

@jhrcook
Copy link
Contributor Author

jhrcook commented Aug 2, 2022

@twiecki I've added tests for the keyword-argument updating function and another set of tests that run a short MCMC with some different NUTS keyword arguments, but I have limited ability to actually determine if numpyro.NUTS actually used the arguments. The only one I could think of was to set the step_size and confirm that it was in the result trace.sample_stats. I could try monkey-patching into the NUTS initializer method, but it could be tricky because the sample_numpyro_nuts imports NUTS internally. Let me know if you'd like me to try, though.

{"adapt_step_size": False, "adapt_mass_matrix": True, "dense_mass": True},
{"adapt_step_size": False, "step_size": 0.13},
],
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try to monkey patch? If it's not possible I think we can survive without explicit tests. This one seems too inneficient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll give it shot. I agree that the current method I used is inefficient and isn't really testing the desired behavior.

@jhrcook
Copy link
Contributor Author

jhrcook commented Aug 3, 2022

I used a monkeypatch to intercept the NUTS object in the initialization of the Numpyro MCMC object when called in pymc.sampling_jax.sample_numpyro_jax(). I just made a MockMCMC that checks some of the properties of its NUTS sampler and asserts they are what I passed to sample_numpyro_jax(..., nuts_kwargs={...}).

I confirmed that the monkeypatch was intercepting the MCMC.__init__() by purposefully making the tests fail with incorrect assertions.

I have limited experience with using monkeypatches in tests, so let me know if you think there would be a better construction/format.

@ricardoV94
Copy link
Member

When using monkey patching there's a helper method, something like assert_called_once_with(...) where you can directly test your monkey patched object was called once with specific arguments and kwargs.

@jhrcook
Copy link
Contributor Author

jhrcook commented Aug 3, 2022

Thank you @ricardoV94 for the help. I've replaced the monkey patching system with a rather simple (too simple?) mocking system. I used mocking to allow me to peer into when the numpyro.infer.MCMC object is made and then inspect the NUTS sampler it was passed. This puts all of the assertions in a single function and removed the need to make a new mock class.

@ricardoV94
Copy link
Member

Thank you @ricardoV94 for the help. I've replaced the monkey patching system with a rather simple (too simple?) mocking system. I used mocking to allow me to peer into when the numpyro.infer.MCMC object is made and then inspect the NUTS sampler it was passed. This puts all of the assertions in a single function and removed the need to make a new mock class.

Looks perfect. The simpler the better!

@twiecki twiecki merged commit a8279d7 into pymc-devs:main Aug 5, 2022
@twiecki
Copy link
Member

twiecki commented Aug 5, 2022

Thanks @jhrcook!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants