File tree Expand file tree Collapse file tree 4 files changed +24
-0
lines changed Expand file tree Collapse file tree 4 files changed +24
-0
lines changed Original file line number Diff line number Diff line change 37
37
warnings .warn ("This module is experimental." )
38
38
39
39
40
+ __all__ = (
41
+ "get_jaxified_graph" ,
42
+ "get_jaxified_logp" ,
43
+ "sample_blackjax_nuts" ,
44
+ "sample_numpyro_nuts" ,
45
+ )
46
+
47
+
40
48
@jax_funcify .register (Assert )
41
49
@jax_funcify .register (CheckParameterValue )
42
50
@jax_funcify .register (SpecifyShape )
Original file line number Diff line number Diff line change
1
+ # This file exists only for backward-compatibility with imports like
2
+ # `import pymc.sampling_jax` or `from pymc import sampling_jax`.
3
+
4
+ # pylint: disable=wildcard-import
5
+ # pylint: disable=unused-wildcard-import
6
+
7
+ from pymc .sampling .jax import *
Original file line number Diff line number Diff line change 16
16
17
17
import pymc as pm
18
18
19
+
20
+ def test_old_import_route ():
21
+ import pymc .sampling .jax as new_sj
22
+ import pymc .sampling_jax as old_sj
23
+
24
+ assert set (new_sj .__all__ ) <= set (dir (old_sj ))
25
+
26
+
19
27
with pytest .warns (UserWarning , match = "module is experimental" ):
20
28
from pymc .sampling .jax import (
21
29
_get_batched_jittered_initial_points ,
Original file line number Diff line number Diff line change 50
50
pymc/ode/ode.py
51
51
pymc/ode/utils.py
52
52
pymc/plots/__init__.py
53
+ pymc/sampling_jax.py
53
54
pymc/sampling/__init__.py
54
55
pymc/sampling/forward.py
55
56
pymc/sampling/mcmc.py
You can’t perform that action at this time.
0 commit comments