Skip to content

Reshape operation in logp graph not supported in JAX backend #5927

@markgoodhead

Description

@markgoodhead

Description of your problem

Please provide a minimal, self-contained, and reproducible example.

import pymc as pm
import pymc.sampling_jax
import numpy as np
import pandas as pd
from aesara import shared, tensor as at
from patsy import dmatrix

# Synthetic regression data: y is noisy around x1 + x2, with two standard-normal features.
rng = np.random.default_rng(0)
size = 2_000
x1 = rng.normal(size=size)
x2 = rng.normal(size=size)
data = pd.DataFrame(
    {
        "x1": x1,
        "x2": x2,
        "y": rng.normal(loc=x1 + x2, size=size),
    }
)
features = ["x1", "x2"]

# B-spline basis settings; `df` is passed to patsy's bs() as the number of basis columns per feature.
DEGREES = 3
N_KNOT = 7
df = N_KNOT + DEGREES + 1
mat_str_end = " - 1"
mat_str_middle = " + "
np_features = data[features].values

# Build the patsy formula, e.g. "bs(x1, df=11, degree=3) + bs(x2, df=11, degree=3) - 1".
# Joining the terms avoids trimming a trailing separator with a hard-coded slice length
# (the previous `[:-2]` left a stray space and silently depended on len(mat_str_middle)).
mat_str = mat_str_middle.join(
    f"bs({feature}, df={df}, degree={DEGREES})" for feature in features
) + mat_str_end
# Each column of np_features (one per feature) is exposed to patsy under its feature name.
basis = dmatrix(mat_str, dict(zip(features, np_features.T)))

# Reshape the design matrix to (n_obs, n_features, n_basis).
# NOTE(review): assumes patsy emits each feature's basis columns contiguously,
# in `features` order — confirm if terms are ever reordered.
dmat_data = np.asarray(basis).reshape(np_features.shape[0], np_features.shape[1], -1)
dmat = shared(dmat_data)

# HalfNormal(s) has standard deviation s * sqrt(1 - 2/pi); scaling by this factor
# makes the value it multiplies (0.1 below) the prior's standard deviation.
HALFNORMAL_SCALE = 1.0 / np.sqrt(1.0 - 2.0 / np.pi)

with pm.Model() as model:
    # Registered in the model but not referenced below; kept so the model matches the report.
    mutable_data = pm.MutableData("data", np_features)
    # Per-feature random walk over spline coefficients:
    # grw[i, k] = mu[i] + sigma[i] * sum_{j <= k} delta[i, j]
    mu = pm.Normal("mu_grw", 0.0, 1.0, shape=dmat.shape[1])
    delta = pm.Normal("delta_grw", 0.0, 0.1 / 2.5, shape=(dmat.shape[1], dmat.shape[2]))
    sigma = pm.HalfNormal("sigma_grw", 0.1 * HALFNORMAL_SCALE, shape=dmat.shape[1])
    grw = pm.Deterministic("grw", mu[:, None] + sigma[:, None] * delta.cumsum(axis=1))
    # Contract dmat's trailing (feature, basis) axes with grw -> one mean per observation.
    f = at.tensordot(dmat, grw)
    # NOTE(review): the shapes above are symbolic (read from the shared variable `dmat`),
    # and tensordot reshapes its operands internally. The JAX samplers fail on a Reshape
    # whose target shape is not a concrete value; passing static sizes such as
    # dmat_data.shape[1] might avoid this — unverified.
    y = pm.MutableData("y", data["y"])
    eps = pm.HalfNormal("eps", sigma=1)
    normal = pm.Normal("normal", mu=f, sigma=eps, observed=y)
    results = pm.sample()
    # JAX-based alternative that raises the Reshape TypeError reported in this issue:
    # results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")

Please provide the full traceback.

Complete error traceback
With `pm.sample()`, the model samples as expected (after Ricardo's fix).

With the JAX-based samplers (numpyro/blackjax), sampling fails with the following traceback:

Traceback (most recent call last):
    results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 335, in sample_blackjax_nuts
    states, _ = map_fn(get_posterior_samples)(keys, init_params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 1485, in vmap_f
    out_flat = batching.batch(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 473, in cache_miss
    out_flat = xla.xla_call(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/batching.py", line 226, in process_call
    vals_out = call_primitive.bind(f_, *vals, **params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 678, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 182, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 285, in memoized_fun
    ans = call(fun, *args)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 230, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 272, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1893, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 199, in _blackjax_inference_loop
    last_state, kernel, _ = adapt.run(seed, init_position)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/kernels.py", line 567, in run
    init_state = algorithm.init(position, logprob_fn, logprob_grad_fn)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 78, in init
    potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 995, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 2457, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/ad.py", line 130, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/ad.py", line 119, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 70, in potential_fn
    return -logprob_fn(x)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 109, in logp_fn_wrap
    return logp_fn(*x)[0]
  File "/tmp/tmpwzpirvop", line 32, in jax_funcified_fgraph
    auto_129685 = reshape(auto_131175, auto_130149)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/aesara/link/jax/dispatch.py", line 731, in reshape
    return jnp.reshape(x, shape)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 727, in reshape
    return a.reshape(newshape, order=order)  # forward to method for ndarrays
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 745, in _reshape
    newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 738, in _compute_newshape
    newshape = core.canonicalize_shape(newshape if iterable else [newshape])
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1651, in canonicalize_shape
    raise _invalid_shape_error(shape, context)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got [22].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
    results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 335, in sample_blackjax_nuts
    states, _ = map_fn(get_posterior_samples)(keys, init_params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 199, in _blackjax_inference_loop
    last_state, kernel, _ = adapt.run(seed, init_position)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/kernels.py", line 567, in run
    init_state = algorithm.init(position, logprob_fn, logprob_grad_fn)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 78, in init
    potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 70, in potential_fn
    return -logprob_fn(x)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 109, in logp_fn_wrap
    return logp_fn(*x)[0]
  File "/tmp/tmpwzpirvop", line 32, in jax_funcified_fgraph
    auto_129685 = reshape(auto_131175, auto_130149)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/aesara/link/jax/dispatch.py", line 731, in reshape
    return jnp.reshape(x, shape)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 727, in reshape
    return a.reshape(newshape, order=order)  # forward to method for ndarrays
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 745, in _reshape
    newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 738, in _compute_newshape
    newshape = core.canonicalize_shape(newshape if iterable else [newshape])
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [22].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

Please provide any additional information below.

Versions and main components

  • PyMC/PyMC3 Version: 4.0.1
  • Aesara/Theano Version: 2.7.3
  • Python Version: 3.9
  • Operating system: Linux
  • How did you install PyMC/PyMC3: (conda/pip) pip

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions