Skip to content

Conversation

bherwerth
Copy link
Contributor

The PR addresses #5954, which reported that calling sample_blackjax_nuts failed with chains=1 when prior parameters had different shapes. A fix is provided by:

  • A change in sampling_jax.py so that init_params has the same structure for chains=1 as for chains>1
  • Modification of the test case in test_sampling_jax.py so that it covers parameters with different shapes.
    ...

Checklist

Major / Breaking Changes

  • ...

Bugfixes / New features

  • Fix bug in sampling_jax.sample_blackjax_nuts failing when called for one chain and prior parameters of different shape

Docs / Maintenance

  • ...

@twiecki
Copy link
Member

twiecki commented Jul 12, 2022

Can you add a test for the chains=1 case?

@codecov
Copy link

codecov bot commented Jul 12, 2022

Codecov Report

Merging #5969 (a3e8f4f) into main (82d7b0e) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #5969      +/-   ##
==========================================
- Coverage   89.38%   89.37%   -0.01%     
==========================================
  Files          73       73              
  Lines       13249    13248       -1     
==========================================
- Hits        11843    11841       -2     
- Misses       1406     1407       +1     
Impacted Files Coverage Δ
pymc/sampling_jax.py 96.93% <100.00%> (-0.02%) ⬇️
pymc/step_methods/hmc/base_hmc.py 89.76% <0.00%> (-0.79%) ⬇️

@bherwerth
Copy link
Contributor Author

Hi @twiecki ,

The test for chains=1 was already there before my changes:

@pytest.mark.parametrize(
"chains",
[
pytest.param(1),
pytest.param(
2, marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices")
),
],
)
def test_transform_samples(sampler, postprocessing_backend, chains):

However, it did not fail because the prior parameters of the test case are both of the same shape. I changed the sigma parameter in the test case to be of shape (2,), cf. the diff on test_sampling_jax.py. With this change of the test case, the test fails before my change to sampling_jax.py and it works afterwards (when I run the test locally).

Please let me know in case you had in mind something else.

On the code coverage report, I don't understand the coverage change on base_hmc.py because I did not edit that file:
image

@twiecki twiecki merged commit 32ffaf5 into pymc-devs:main Jul 13, 2022
@twiecki
Copy link
Member

twiecki commented Jul 13, 2022

Thanks @bherwerth!

@bherwerth bherwerth deleted the fix-sampling-jax-5954 branch July 15, 2022 19:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants