Skip to content

BUG: NotConstantValueError when using coordinates with observed data that contains missing values in pymc==5.14.0 #7304

@jhandsel

Description

@jhandsel

Describe the issue:

In pymc 5.14.0, I get a NotConstantValueError when sampling with observed data that contains missing values. This only occurs when passing coords to the model.

The same code runs without an error with pymc==5.12.0.

Reproducible code example:

# Minimal reproduction for pymc issue #7304. Sampling a model whose observed
# data contains missing values raises NotConstantValueError on pymc==5.14.0
# when the observed variable is given `dims`. Per the report, the same code
# runs without error on pymc==5.12.0.
import numpy as np
import pymc as pm
size = 30
# Two synthetic feature columns stacked into a (size, 2) observation matrix.
# NOTE: no RNG seed is set, so the values differ between runs; only the
# position of the missing entry below is fixed.
f1 = np.random.normal(scale=1.5, size=size)
f2 = np.random.normal(scale=0.5, size=size)
features = np.stack((f1, f2), axis=1)
# A single missing entry is enough. PyMC emits an ImputationWarning and
# automatically imputes it from the sampling distribution (see error output).
features[0, 0] = np.nan
n_features = features.shape[1]

# Named dimensions for the observed variable. Per the report, the error only
# occurs when coords are passed to the model.
coords = {"location": range(size),
          "feature": ("f1", "f2")}

with pm.Model(coords=coords) as model:
    # Per-feature standard-deviation prior, passed to an LKJ prior over the
    # Cholesky-factored covariance. With compute_corr=True this returns three
    # values; only the Cholesky factor `chol` is used below.
    sd_dist = pm.Exponential.dist(1.0, shape=n_features)
    chol, _, _ = pm.LKJCholeskyCov('chol_cov', n=n_features, eta=2,
                                   sd_dist=sd_dist, compute_corr=True)
    # This observed MvNormal combines NaN-containing observations with
    # dims=("location", "feature"), which is the combination that fails.
    vals = pm.MvNormal('vals', mu=np.zeros(n_features), chol=chol,
                       observed=features, dims=("location", "feature"))
    # On pymc 5.14.0 this raises NotConstantValueError. Per the traceback,
    # partial_observed_rv_logprob calls constant_fold([dist.shape]), and the
    # shape does not fold to a constant.
    # NOTE(review): presumably this is because `dims` make the shape depend on
    # non-constant model dimension lengths. Confirm against pymc internals.
    trace = pm.sample()

Error message:

/Users/jennifer/env/lib/python3.11/site-packages/pymc/model/core.py:1349: ImputationWarning: Data in vals contains missing values and will be automatically imputed from the sampling distribution.
  warnings.warn(impute_message, ImputationWarning)
---------------------------------------------------------------------------
NotConstantValueError                     Traceback (most recent call last)
Cell In[2], line 19
     15 chol, _, _ = pm.LKJCholeskyCov('chol_cov', n=n_features, eta=2,
     16                                sd_dist=sd_dist, compute_corr=True)
     17 vals = pm.MvNormal('vals', mu=np.zeros(n_features), chol=chol,
     18                    observed=features, dims=("location", "feature"))
---> 19 trace = pm.sample()

File ~/env/lib/python3.11/site-packages/pymc/sampling/mcmc.py:684, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    681         auto_nuts_init = False
    683 initial_points = None
--> 684 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    686 if nuts_sampler != "pymc":
    687     if not isinstance(step, NUTS):

File ~/env/lib/python3.11/site-packages/pymc/sampling/mcmc.py:212, in assign_step_methods(model, step, methods, step_kwargs)
    210 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
    211 selected_steps: dict[type[BlockedStep], list] = {}
--> 212 model_logp = model.logp()
    214 for var in model.value_vars:
    215     if var not in assigned_vars:
    216         # determine if a gradient can be computed

File ~/env/lib/python3.11/site-packages/pymc/model/core.py:725, in Model.logp(self, vars, jacobian, sum)
    723 rv_logps: list[TensorVariable] = []
    724 if rvs:
--> 725     rv_logps = transformed_conditional_logp(
    726         rvs=rvs,
    727         rvs_to_values=self.rvs_to_values,
    728         rvs_to_transforms=self.rvs_to_transforms,
    729         jacobian=jacobian,
    730     )
    731     assert isinstance(rv_logps, list)
    733 # Replace random variables by their value variables in potential terms

File ~/env/lib/python3.11/site-packages/pymc/logprob/basic.py:611, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    608     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore
    610 kwargs.setdefault("warn_rvs", False)
--> 611 temp_logp_terms = conditional_logp(
    612     rvs_to_values,
    613     extra_rewrites=transform_rewrite,
    614     use_jacobian=jacobian,
    615     **kwargs,
    616 )
    618 # The function returns the logp for every single value term we provided to it.
    619 # This includes the extra values we plugged in above, so we filter those we
    620 # actually wanted in the same order they were given in.
    621 logp_terms = {}

File ~/env/lib/python3.11/site-packages/pymc/logprob/basic.py:541, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
    538 q_values = remapped_vars[: len(q_values)]
    539 q_rv_inputs = remapped_vars[len(q_values) :]
--> 541 q_logprob_vars = _logprob(
    542     node.op,
    543     q_values,
    544     *q_rv_inputs,
    545     **kwargs,
    546 )
    548 if not isinstance(q_logprob_vars, list | tuple):
    549     q_logprob_vars = [q_logprob_vars]

File /opt/local/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/env/lib/python3.11/site-packages/pymc/distributions/distribution.py:1633, in partial_observed_rv_logprob(op, values, dist, mask, **kwargs)
   1631 [obs_value, unobs_value] = values
   1632 antimask = ~mask
-> 1633 joined_value = pt.empty(constant_fold([dist.shape])[0])
   1634 joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
   1635 joined_value = pt.set_subtensor(joined_value[antimask], obs_value)

File ~/env/lib/python3.11/site-packages/pymc/pytensorf.py:1037, in constant_fold(xs, raise_not_constant)
   1034 folded_xs = rewrite_graph(fg).outputs
   1036 if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
-> 1037     raise NotConstantValueError
   1039 return tuple(
   1040     folded_x.data if isinstance(folded_x, Constant) else folded_x for folded_x in folded_xs
   1041 )

NotConstantValueError:

PyMC version information:

pymc==5.14.0
pytensor==2.20.0

MacOS Sonoma 14.4.1
Installation: pip

Context for the issue:

I am using PyMC to analyze data for a customer. Covariates for certain locations have missing values, but I still need to provide predictions for these locations to the customer. Passing in coordinates is important for keeping track of the different locations in the trace for visualization.

Currently, I am working around the issue by pinning pymc==5.12.0.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions