Skip to content

Commit 781d974

Browse files
michaelosthegericardoV94
authored andcommitted
Reintroduce sampling_jax.py for backward compatibility
This is a separate commit to make sure that git tracks the rename of the old `sampling_jax.py` to `sampling/jax.py` correctly.
1 parent 80fc108 commit 781d974

File tree

4 files changed

+24
-0
lines changed

4 files changed

+24
-0
lines changed

pymc/sampling/jax.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@
3737
warnings.warn("This module is experimental.")
3838

3939

40+
__all__ = (
41+
"get_jaxified_graph",
42+
"get_jaxified_logp",
43+
"sample_blackjax_nuts",
44+
"sample_numpyro_nuts",
45+
)
46+
47+
4048
@jax_funcify.register(Assert)
4149
@jax_funcify.register(CheckParameterValue)
4250
@jax_funcify.register(SpecifyShape)

pymc/sampling_jax.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# This file exists only for backward-compatibility with imports like
2+
# `import pymc.sampling_jax` or `from pymc import sampling_jax`.
3+
4+
# pylint: disable=wildcard-import
5+
# pylint: disable=unused-wildcard-import
6+
7+
from pymc.sampling.jax import *

pymc/tests/sampling/test_jax.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@
1616

1717
import pymc as pm
1818

19+
20+
def test_old_import_route():
21+
import pymc.sampling.jax as new_sj
22+
import pymc.sampling_jax as old_sj
23+
24+
assert set(new_sj.__all__) <= set(dir(old_sj))
25+
26+
1927
with pytest.warns(UserWarning, match="module is experimental"):
2028
from pymc.sampling.jax import (
2129
_get_batched_jittered_initial_points,

scripts/run_mypy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
pymc/ode/ode.py
5151
pymc/ode/utils.py
5252
pymc/plots/__init__.py
53+
pymc/sampling_jax.py
5354
pymc/sampling/__init__.py
5455
pymc/sampling/forward.py
5556
pymc/sampling/mcmc.py

0 commit comments

Comments
 (0)