-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Closed
Description
Description of your problem
Please provide a minimal, self-contained, and reproducible example.
# Minimal reproduction: an additive B-spline regression whose per-feature
# spline coefficients follow a Gaussian random walk.  As written the script
# samples with `pm.sample()`; swapping in the commented-out BlackJAX call at
# the bottom is the variant that fails with the reshape TypeError.
import pymc as pm
import pymc.sampling_jax
import numpy as np
import pandas as pd
from aesara import shared, tensor as at
from patsy import dmatrix

# Synthetic data: y is drawn around x1 + x2 with numpy's default unit scale.
rng = np.random.default_rng(0)
size = 2_000
x1 = rng.normal(size=size)
x2 = rng.normal(size=size)
data = pd.DataFrame(
    {
        "x1": x1,
        "x2": x2,
        "y": rng.normal(loc=x1 + x2, size=size),
    }
)

features = ["x1", "x2"]
DEGREES = 3
N_KNOT = 7
df = N_KNOT + DEGREES + 1  # `df` argument passed to each patsy bs() term

# Build the patsy formula "bs(x1, ...) + bs(x2, ...) - 1".  Joining on the
# separator avoids trimming a trailing " + " by a hard-coded slice length.
mat_str_end = " - 1"  # patsy syntax for dropping the intercept column
mat_str_middle = " + "
np_features = data[features].values
mat_str = (
    mat_str_middle.join(
        f"bs({feature}, df={df}, degree={DEGREES})" for feature in features
    )
    + mat_str_end
)

basis = dmatrix(
    mat_str,
    {feature: np_features[:, i] for i, feature in enumerate(features)},
)
# Split the flat design matrix into (rows, n_features, basis columns).
# NOTE(review): assumes patsy emits the columns grouped by term, in formula
# order -- confirm before relying on the per-feature layout.
dmat_data = np.asarray(basis).reshape(
    np_features.shape[0], np_features.shape[1], -1
)
dmat = shared(dmat_data)

with pm.Model() as model:
    mutable_data = pm.MutableData("data", np_features)
    HALFNORMAL_SCALE = 1.0 / np.sqrt(1.0 - 2.0 / np.pi)

    # NOTE(review): the shapes below are read from the shared variable `dmat`,
    # so they are symbolic rather than Python ints.  Presumably this is what
    # reaches JAX as the non-concrete reshape target `[22]` in the traceback --
    # confirm.  Left untouched on purpose so the script still reproduces it.
    mu = pm.Normal("mu_grw", 0.0, 1.0, shape=dmat.shape[1])
    delta = pm.Normal(
        "delta_grw", 0.0, 0.1 / 2.5, shape=(dmat.shape[1], dmat.shape[2])
    )
    sigma = pm.HalfNormal(
        "sigma_grw", 0.1 * HALFNORMAL_SCALE, shape=dmat.shape[1]
    )
    # Random walk per feature: cumulative sum of the increments along the
    # basis axis, scaled by sigma and shifted by mu.
    grw = pm.Deterministic(
        "grw", mu[:, None] + sigma[:, None] * delta.cumsum(axis=1)
    )
    # tensordot with its default `axes` argument; with a 3-D `dmat` and 2-D
    # `grw` this is expected to yield one mean per row -- verify against the
    # aesara documentation.
    f = at.tensordot(dmat, grw)

    y = pm.MutableData("y", data["y"])
    eps = pm.HalfNormal("eps", sigma=1)
    normal = pm.Normal("normal", mu=f, sigma=eps, observed=y)

    results = pm.sample()
    # Uncomment (and comment out the line above) to reproduce the failure:
    # results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
Please provide the full traceback.
Complete error traceback
With `pm.sample()`: works as expected (with Ricardo's fix).
With the JAX samplers (numpyro/blackjax), sampling fails with the following traceback:
Traceback (most recent call last):
results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 335, in sample_blackjax_nuts
states, _ = map_fn(get_posterior_samples)(keys, init_params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 1485, in vmap_f
out_flat = batching.batch(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 473, in cache_miss
out_flat = xla.xla_call(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
return call_bind(self, fun, *args, **params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/batching.py", line 226, in process_call
vals_out = call_primitive.bind(f_, *vals, **params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
return call_bind(self, fun, *args, **params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 678, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 182, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 285, in memoized_fun
ans = call(fun, *args)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 230, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 272, in lower_xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1893, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 199, in _blackjax_inference_loop
last_state, kernel, _ = adapt.run(seed, init_position)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/kernels.py", line 567, in run
init_state = algorithm.init(position, logprob_fn, logprob_grad_fn)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 78, in init
potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 995, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 2457, in _vjp
out_primal, out_vjp = ad.vjp(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/ad.py", line 130, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/ad.py", line 119, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 70, in potential_fn
return -logprob_fn(x)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 109, in logp_fn_wrap
return logp_fn(*x)[0]
File "/tmp/tmpwzpirvop", line 32, in jax_funcified_fgraph
auto_129685 = reshape(auto_131175, auto_130149)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/aesara/link/jax/dispatch.py", line 731, in reshape
return jnp.reshape(x, shape)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 727, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 745, in _reshape
newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 738, in _compute_newshape
newshape = core.canonicalize_shape(newshape if iterable else [newshape])
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1651, in canonicalize_shape
raise _invalid_shape_error(shape, context)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got [22].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 335, in sample_blackjax_nuts
states, _ = map_fn(get_posterior_samples)(keys, init_params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 199, in _blackjax_inference_loop
last_state, kernel, _ = adapt.run(seed, init_position)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/kernels.py", line 567, in run
init_state = algorithm.init(position, logprob_fn, logprob_grad_fn)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 78, in init
potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 70, in potential_fn
return -logprob_fn(x)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 109, in logp_fn_wrap
return logp_fn(*x)[0]
File "/tmp/tmpwzpirvop", line 32, in jax_funcified_fgraph
auto_129685 = reshape(auto_131175, auto_130149)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/aesara/link/jax/dispatch.py", line 731, in reshape
return jnp.reshape(x, shape)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 727, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 745, in _reshape
newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 738, in _compute_newshape
newshape = core.canonicalize_shape(newshape if iterable else [newshape])
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [22].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
Please provide any additional information below.
Versions and main components
- PyMC/PyMC3 Version: 4.0.1
- Aesara/Theano Version: 2.7.3
- Python Version: 3.9
- Operating system: Linux
- How did you install PyMC/PyMC3: pip