Closed
Description of your problem
Using the NUTS sampler with ADVI initialization on a model containing a Dirichlet distribution raises `TypeError: Too many parameter passed to aesara function` in PyMC 4.0.0b6. Note that this is not the same issue as #4733. The error can be reproduced as follows:
```python
import pymc as pm
import numpy as np

with pm.Model() as model:
    var = pm.Dirichlet("var", np.ones((3,4)))
    trace = pm.sample(cores=1, init="advi", n_init=2000)
```

Full traceback:
```
Auto-assigning NUTS sampler...
Initializing NUTS using advi...
65.10% [1302/2000 00:00<00:00 Average Loss = 1.3673]
Convergence achieved at 1900
Interrupted at 1,899 [94%]: Average Loss = 1.3242
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [7], in <cell line: 4>()
4 with pm.Model() as model:
5 var = pm.Dirichlet("var", np.ones((3,4)))
----> 6 trace = pm.sample(cores=1, init="advi", n_init=2000)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling.py:506, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
504 # One final check that shapes and logps at the starting points are okay.
505 for ip in initial_points:
--> 506 model.check_start_vals(ip)
507 _check_start_shape(model, ip)
509 sample_args = {
510 "draws": draws,
511 "step": step,
(...)
522 "discard_tuned_samples": discard_tuned_samples,
523 }
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py:1695, in Model.check_start_vals(self, start)
1689 valid_keys = ", ".join(self.named_vars.keys())
1690 raise KeyError(
1691 "Some start parameters do not appear in the model!\n"
1692 f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
1693 )
-> 1695 initial_eval = self.point_logps(point=elem)
1697 if not all(np.isfinite(v) for v in initial_eval.values()):
1698 raise SamplingError(
1699 "Initial evaluation of model at starting point failed!\n"
1700 f"Starting values:\n{elem}\n\n"
1701 f"Initial evaluation results:\n{initial_eval}"
1702 )
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py:1736, in Model.point_logps(self, point, round_vals)
1730 factors = self.basic_RVs + self.potentials
1731 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
1732 return {
1733 factor.name: np.round(np.asarray(factor_logp), round_vals)
1734 for factor, factor_logp in zip(
1735 factors,
-> 1736 self.compile_fn(factor_logps_fn)(point),
1737 )
1738 }
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py:1835, in PointFunc.__call__(self, state)
1834 def __call__(self, state):
-> 1835 return self.f(**state)
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:836, in Function.__call__(self, *args, **kwargs)
833 c.provided = 0
835 if len(args) + len(kwargs) > len(self.input_storage):
--> 836 raise TypeError("Too many parameter passed to aesara function")
838 # Set positional arguments
839 i = 0
TypeError: Too many parameter passed to aesara function
```

Versions and main components
- PyMC/PyMC3 Version: 4.0.0b6
- Aesara/Theano Version: 2.5.1
- Python Version: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:39:04) [GCC 10.3.0]
- Operating system: Ubuntu 20.04.4 LTS x86-64
- How did you install PyMC/PyMC3: pip. I followed the instructions from the installation guide, changing `jax` to `jax[cpu]`.
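The last two frames of the traceback show where the mismatch surfaces: `PointFunc.__call__` forwards the entire starting point as keyword arguments (`self.f(**state)`), and aesara's `Function.__call__` raises as soon as it receives more arguments than it has inputs. The sketch below mimics that guard with two hypothetical stand-in classes (`CompiledFn`, `PointFunc`), not the real PyMC or aesara objects, to show that a single surplus key in the point is enough to trigger this exact error. Which extra key the ADVI-initialized point actually carries is an assumption: the example supposes it holds both a transformed value named `var_simplex__` and the untransformed `var`. `restrict_point` is likewise a hypothetical workaround, not a PyMC API.

```python
# Hypothetical stand-ins that mimic the argument-count guard from the
# traceback (aesara Function.__call__) and PyMC's PointFunc, which unpacks
# the whole point dict as keyword arguments. Not the real library classes.


class CompiledFn:
    """Stand-in for a compiled function with a fixed set of named inputs."""

    def __init__(self, input_names):
        self.input_names = list(input_names)

    def __call__(self, *args, **kwargs):
        # Same condition as in the traceback: more arguments than inputs fails.
        if len(args) + len(kwargs) > len(self.input_names):
            raise TypeError("Too many parameter passed to aesara function")
        # Echo back the received inputs so the call can be inspected.
        return {name: kwargs.get(name) for name in self.input_names}


class PointFunc:
    """Stand-in that forwards every key of a point as a keyword argument."""

    def __init__(self, f):
        self.f = f

    def __call__(self, state):
        return self.f(**state)


def restrict_point(point, input_names):
    """Hypothetical workaround: keep only the keys the function accepts."""
    allowed = set(input_names)
    return {k: v for k, v in point.items() if k in allowed}


# ASSUMPTION: the starting point carries both the transformed value
# "var_simplex__" and the untransformed "var". Values are placeholders,
# not the real shapes of a (3, 4) Dirichlet.
point = {"var_simplex__": [0.0, 0.0, 0.0], "var": [0.25, 0.25, 0.25, 0.25]}
fn = PointFunc(CompiledFn(["var_simplex__"]))

try:
    fn(point)
except TypeError as err:
    print(err)  # Too many parameter passed to aesara function

print(fn(restrict_point(point, ["var_simplex__"])))  # {'var_simplex__': [0.0, 0.0, 0.0]}
```

In the real model, the accepted names would have to be derived from the model's value variables rather than hard-coded as they are here.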
ricardoV94