Error in NUTS + ADVI init with deterministics #5732

@PedroSebe

Description of your problem

Using the NUTS sampler with ADVI initialization in a model that contains a Dirichlet distribution raises TypeError: Too many parameter passed to aesara function in PyMC 4.0.0b6. Note that this is not the same issue as #4733. The error can be reproduced with:

import pymc as pm
import numpy as np

with pm.Model() as model:
    var = pm.Dirichlet("var", np.ones((3,4)))
    trace = pm.sample(cores=1, init="advi", n_init=2000)
Full traceback:
Auto-assigning NUTS sampler...
Initializing NUTS using advi...
 65.10% [1302/2000 00:00<00:00 Average Loss = 1.3673]
Convergence achieved at 1900
Interrupted at 1,899 [94%]: Average Loss = 1.3242
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [7], in <cell line: 4>()
      4 with pm.Model() as model:
      5     var = pm.Dirichlet("var", np.ones((3,4)))
----> 6     trace = pm.sample(cores=1, init="advi", n_init=2000)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling.py:506, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    504 # One final check that shapes and logps at the starting points are okay.
    505 for ip in initial_points:
--> 506     model.check_start_vals(ip)
    507     _check_start_shape(model, ip)
    509 sample_args = {
    510     "draws": draws,
    511     "step": step,
   (...)
    522     "discard_tuned_samples": discard_tuned_samples,
    523 }

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py:1695, in Model.check_start_vals(self, start)
   1689     valid_keys = ", ".join(self.named_vars.keys())
   1690     raise KeyError(
   1691         "Some start parameters do not appear in the model!\n"
   1692         f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
   1693     )
-> 1695 initial_eval = self.point_logps(point=elem)
   1697 if not all(np.isfinite(v) for v in initial_eval.values()):
   1698     raise SamplingError(
   1699         "Initial evaluation of model at starting point failed!\n"
   1700         f"Starting values:\n{elem}\n\n"
   1701         f"Initial evaluation results:\n{initial_eval}"
   1702     )

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py:1736, in Model.point_logps(self, point, round_vals)
   1730 factors = self.basic_RVs + self.potentials
   1731 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
   1732 return {
   1733     factor.name: np.round(np.asarray(factor_logp), round_vals)
   1734     for factor, factor_logp in zip(
   1735         factors,
-> 1736         self.compile_fn(factor_logps_fn)(point),
   1737     )
   1738 }

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py:1835, in PointFunc.__call__(self, state)
   1834 def __call__(self, state):
-> 1835     return self.f(**state)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:836, in Function.__call__(self, *args, **kwargs)
    833     c.provided = 0
    835 if len(args) + len(kwargs) > len(self.input_storage):
--> 836     raise TypeError("Too many parameter passed to aesara function")
    838 # Set positional arguments
    839 i = 0

TypeError: Too many parameter passed to aesara function
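
The last two frames show where the error is raised: PointFunc.__call__ unpacks the whole starting point as keyword arguments (self.f(**state)), and the compiled function rejects the call when it receives more arguments than it has inputs. A minimal pure-Python sketch of that check is below. It does not use PyMC or Aesara. CompiledFn, keep_only_inputs and the key names "var_simplex__" and "var" are hypothetical stand-ins, and the assumption that the ADVI starting point carries an extra entry is one possible reading of the traceback, not something confirmed here.

```python
# Stand-in for the argument-count check seen in the traceback.
# Everything here is hypothetical illustration, not PyMC/Aesara API.

class CompiledFn:
    """Mimics a compiled function with a fixed set of named inputs."""

    def __init__(self, input_names):
        self.input_names = list(input_names)

    def __call__(self, **kwargs):
        # Same kind of check as the one that raises in the traceback:
        # more arguments supplied than the function has inputs.
        if len(kwargs) > len(self.input_names):
            raise TypeError("Too many parameter passed to aesara function")
        # Return the names that were accepted, in sorted order.
        return sorted(kwargs)


def keep_only_inputs(point, input_names):
    """Drop every entry of `point` that the function does not take."""
    return {name: point[name] for name in input_names}


fn = CompiledFn(["var_simplex__"])

# Assumed shape of the problem: the point holds the transformed value
# plus one extra key that the compiled function does not expect.
point = {"var_simplex__": [0.0, 0.0, 0.0], "var": [0.25, 0.25, 0.25, 0.25]}

try:
    fn(**point)
except TypeError as err:
    print("raised:", err)

# Filtering the point down to the expected inputs makes the call succeed.
print(fn(**keep_only_inputs(point, fn.input_names)))
```

If that reading is right, any fix would have to make the starting point and the function inputs agree on their keys. The sketch only illustrates the mismatch and is not a proposed patch.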

Versions and main components

  • PyMC/PyMC3 Version: 4.0.0b6
  • Aesara/Theano Version: 2.5.1
  • Python Version: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:39:04) [GCC 10.3.0]
  • Operating system: Ubuntu 20.04.4 LTS x86-64
  • How did you install PyMC/PyMC3: pip. I followed the instructions from the installation guide, changing jax to jax[cpu].

Labels

VI (Variational Inference), bug
