
pymc.sampling_jax.sample_blackjax_nuts crashes with chains=1 #5954

@aussetg

Description of your problem

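pymc.sampling_jax.sample_blackjax_nuts crashes when called with chains=1, raising "ValueError: all input arrays must have the same shape" (full traceback below).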

Please provide a minimal, self-contained, and reproducible example.

# %%
import numpy as np
import pandas as pd
import pymc as pm
import pymc.sampling_jax
import pymc.distributions.transforms
import aesara.tensor as at
import aesara
np.set_printoptions(2)

# %%
data = pd.DataFrame(np.random.randn(100, 23))

# %%
K = 15
n_f = len(data.columns)

# %%
with pm.Model() as model:
    alpha = pm.Gamma("alpha", 0.7, 1.0)
    w = pm.StickBreakingWeights("w", alpha=alpha, K=K-1)

    cov = np.identity(n_f)

    mu = pm.MvNormal(
        "mu",
        mu=data.mean(axis=0),
        cov=cov,
        shape=(K, n_f),
    )

    components = [
        pm.MvNormal.dist(mu=mu[i, :], cov=cov, shape=(n_f,))
        for i in range(K)
    ]

    obs = pm.Mixture(
        "obs",
        w=w,
        comp_dists=components,
        observed=data,
    )

# %%
with model:
    # chains=1 triggers the ValueError shown in the traceback below
    trace = pymc.sampling_jax.sample_blackjax_nuts(target_accept=0.9, chains=1)
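
The model passes K=K-1 to pm.StickBreakingWeights. As we read the PyMC documentation, that distribution returns K + 1 weights (the last one being whatever is left of the stick), so K=K-1 yields 15 weights, one per entry in comp_dists. The NumPy sketch below illustrates that count with the generic truncated stick-breaking construction and a fixed alpha. It is not PyMC's implementation, and stick_breaking_weights is a hypothetical helper.

```python
import numpy as np


def stick_breaking_weights(alpha, n_breaks, rng):
    """Generic truncated stick-breaking (illustration only, not PyMC code).

    n_breaks Beta(1, alpha) fractions produce n_breaks + 1 weights:
    one per break, plus the leftover stick as the final weight.
    """
    betas = rng.beta(1.0, alpha, size=n_breaks)
    # Stick length remaining before each break: 1, (1-b1), (1-b1)(1-b2), ...
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - betas)])
    # Weight k is the fraction broken off the remaining stick; the last weight is the leftover.
    return np.concatenate([betas * remaining[:-1], remaining[-1:]])


rng = np.random.default_rng(0)
K = 15
# alpha fixed at 0.7 for illustration; in the model it is Gamma-distributed.
w = stick_breaking_weights(0.7, K - 1, rng)
print(w.shape)                    # (15,)
print(np.isclose(w.sum(), 1.0))   # True
```

The weights sum to one by construction: each broken-off piece equals the difference between consecutive remaining lengths, so the sum telescopes to the initial unit stick.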

Please provide the full traceback.

Compiling...
Traceback (most recent call last):
  File "/usr/lib64/python3.9/code.py", line 90, in runcode
    exec(code, self.locals)
  File "<input>", line 2, in <module>
  File "/home/guillaume/.cache/pypoetry/virtualenvs/tabular-T_K5n0PG-py3.9/lib/python3.9/site-packages/pymc/sampling_jax.py", line 301, in sample_blackjax_nuts
    init_params = [np.stack(init_params)]
  File "<__array_function__ internals>", line 180, in stack
  File "/home/guillaume/.cache/pypoetry/virtualenvs/tabular-T_K5n0PG-py3.9/lib64/python3.9/site-packages/numpy/core/shape_base.py", line 426, in stack
    raise ValueError('all input arrays must have the same shape')
ValueError: all input arrays must have the same shape
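
The traceback points at np.stack(init_params) in sample_blackjax_nuts. np.stack requires every input array to have the same shape. If init_params at that point is a list with one array per free variable (our reading of the traceback, not verified against the PyMC source), the call fails whenever the model's variables have different shapes, as alpha, w and mu do here. The sketch below reproduces that failure with plain NumPy and shows one possible way to batch initial values instead: stack each variable across chains, so every array gains a leading chain axis, which also covers the single-chain case. The shapes are illustrative assumptions, and stack_fails and batch_by_variable are hypothetical helpers, not PyMC functions.

```python
import numpy as np

# Assumed per-variable initial values for ONE chain (illustrative shapes,
# loosely modeled on the model above: scalar alpha, vector w, (K, n_f) mu).
K, n_f = 15, 23
one_chain = [np.zeros(()), np.zeros(K - 1), np.zeros((K, n_f))]


def stack_fails(params):
    """Return True if np.stack over a list of per-variable arrays raises ValueError."""
    try:
        np.stack(params)
    except ValueError:
        return True
    return False


def batch_by_variable(chain_points):
    """Stack each variable across chains, giving arrays of shape (n_chains, *var_shape).

    chain_points has one entry per chain; each entry is a list with one array
    per variable. The same code path handles one chain or many.
    """
    return [np.stack(per_var) for per_var in zip(*chain_points)]


print(stack_fails(one_chain))       # True: shapes (), (14,), (15, 23) differ
batched = batch_by_variable([one_chain])
print([b.shape for b in batched])   # [(1,), (1, 14), (1, 15, 23)]
```

Stacking per variable rather than across variables is the design point: it never mixes arrays of different shapes, so it does not depend on the model's variables happening to share a shape.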

Versions and main components

  • PyMC/PyMC3 Version: PyMC 4.1.1
  • Aesara/Theano Version: 2.7.4
  • Python Version: 3.9
  • Operating system: Linux (Fedora 36)
  • How did you install PyMC/PyMC3 (conda/pip): pip (via Poetry)
