Closed
Description of your problem
`pymc.sampling_jax.sample_blackjax_nuts` crashes with `ValueError: all input arrays must have the same shape` when called with `chains=1`.
Please provide a minimal, self-contained, and reproducible example.
```python
# %%
import numpy as np
import pandas as pd
import pymc as pm
import pymc.sampling_jax
import pymc.distributions.transforms
import aesara.tensor as at
import aesara

np.set_printoptions(2)

# %%
data = pd.DataFrame(np.random.randn(100, 23))

# %%
K = 15  # number of mixture components
n_f = len(data.columns)  # number of features

# %%
with pm.Model() as model:
    alpha = pm.Gamma("alpha", 0.7, 1.0)
    # K - 1 stick-breaking fractions yield a weight vector of length K
    w = pm.StickBreakingWeights("w", alpha=alpha, K=K - 1)
    cov = np.identity(n_f)
    # One mean vector per component
    mu = pm.MvNormal(
        "mu",
        mu=data.mean(axis=0),
        cov=cov,
        shape=(K, n_f),
    )
    components = [
        pm.MvNormal.dist(mu=mu[i, :], cov=cov, shape=(n_f,))
        for i in range(K)
    ]
    obs = pm.Mixture(
        "obs",
        w=w,
        comp_dists=components,
        observed=data,
    )

# %%
with model:
    # Fails with chains=1
    trace = pymc.sampling_jax.sample_blackjax_nuts(target_accept=0.9, chains=1)
```
Please provide the full traceback.
```
Compiling...
Traceback (most recent call last):
  File "/usr/lib64/python3.9/code.py", line 90, in runcode
    exec(code, self.locals)
  File "<input>", line 2, in <module>
  File "/home/guillaume/.cache/pypoetry/virtualenvs/tabular-T_K5n0PG-py3.9/lib/python3.9/site-packages/pymc/sampling_jax.py", line 301, in sample_blackjax_nuts
    init_params = [np.stack(init_params)]
  File "<__array_function__ internals>", line 180, in stack
  File "/home/guillaume/.cache/pypoetry/virtualenvs/tabular-T_K5n0PG-py3.9/lib64/python3.9/site-packages/numpy/core/shape_base.py", line 426, in stack
    raise ValueError('all input arrays must have the same shape')
ValueError: all input arrays must have the same shape
```
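The failing line in the traceback is `init_params = [np.stack(init_params)]`. Below is a minimal NumPy-only sketch of why such a line can raise this error. It assumes (not confirmed by the traceback) that `init_params` at that point is a list holding one initial-value array per model variable for a single chain. The shapes are hypothetical, loosely modelled on `alpha`, `w` and `mu` in the model above, and the helper names `stack_all_variables` and `add_chain_axis` are illustrative only. `np.stack` requires every input to have the same shape, so stacking a scalar, a vector and a matrix together fails, while giving each variable its own leading chain axis does not. This is one plausible remedy, not necessarily how PyMC resolved the issue.

```python
import numpy as np

# Hypothetical per-variable initial values for ONE chain (assumed layout):
# a scalar, a vector and a matrix, loosely mirroring alpha, w and mu.
init_params = [
    np.float64(0.5),     # scalar, like a transformed `alpha`
    np.zeros(14),        # vector, like a transformed `w`
    np.zeros((15, 23)),  # matrix, like `mu` with K=15, n_f=23
]


def stack_all_variables(params):
    """Mimics the failing pattern: stacks arrays of different variables together."""
    return [np.stack(params)]


def add_chain_axis(params):
    """Gives each variable its own leading chain axis of length 1."""
    return [np.expand_dims(p, axis=0) for p in params]


try:
    stack_all_variables(init_params)
except ValueError as err:
    print("ValueError:", err)

fixed = add_chain_axis(init_params)
print([p.shape for p in fixed])  # prints [(1,), (1, 14), (1, 15, 23)]
```

Stacking happens per variable rather than across variables, so differently shaped parameters never need to be combined into a single array.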
Versions and main components
- PyMC/PyMC3 Version: PyMC 4.1.1
- Aesara/Theano Version: 2.7.4
- Python Version: 3.9
- Operating system: Linux (Fedora 36)
- How did you install PyMC/PyMC3: pip (via Poetry)