
Broken example in docstring of pm.set_data  #6004

@ltoniazzi

Description


The example in the docstring of pm.set_data appears to be broken. When I run it:

>>> import pymc as pm
>>> with pm.Model() as model:
...     x = pm.MutableData('x', [1., 2., 3.])
...     y = pm.MutableData('y', [1., 2., 3.])
...     beta = pm.Normal('beta', 0, 1)
...     obs = pm.Normal('obs', x * beta, 1, observed=y)
...     idata = pm.sample(1000, tune=1000)

>>> with model:
...     pm.set_data({'x': [5., 6., 9.]})
...     y_test = pm.sample_posterior_predictive(idata)
>>> y_test['obs'].mean(axis=0)
array([4.6088569 , 5.54128318, 8.32953844])

I get the following output and traceback:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
Traceback (most recent call last):
  File "/home/usr/pymc/check_docstring.py", line 14, in <module>
    y_test['obs'].mean(axis=0)
  File "/home/usr/miniconda3/envs/pymc-dev/lib/python3.10/site-packages/arviz/data/inference_data.py", line 253, in __getitem__
    raise KeyError(key)
KeyError: 'obs'

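This happens because pm.sample_posterior_predictive returns an ArviZ InferenceData object rather than a dictionary of arrays, so indexing it directly with 'obs' raises a KeyError. The samples live in its posterior_predictive group, with dimensions (chain, draw, ...). A possible fix: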

with model:
    pm.set_data({'x': [5., 6., 9.]})
    y_test = pm.sample_posterior_predictive(idata)
y_test.posterior_predictive['obs'].mean(('chain', 'draw'))  # <-- fix
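To illustrate why the fix averages over both 'chain' and 'draw' rather than a single axis, here is a minimal stdlib-only sketch. It assumes the InferenceData layout [chain][draw][observation] for 'obs' and uses made-up sample values; the helper names mean_over_chain_and_draw and mean_axis0_of_flattened are hypothetical. It shows that reducing over both dimensions yields the same per-observation means that mean(axis=0) would give if chains and draws were stacked into one leading sample axis, which is the layout the docstring's original call appears to assume.

```python
from statistics import fmean

# Made-up posterior-predictive draws for 'obs', for illustration only.
# Assumed layout (as in InferenceData): samples[chain][draw][observation].
samples = [
    [[4.0, 5.0, 8.0], [5.0, 6.0, 9.0]],  # chain 0: two draws of 3 observations
    [[4.5, 5.5, 8.5], [4.5, 5.5, 8.5]],  # chain 1: two draws of 3 observations
]


def mean_over_chain_and_draw(samples):
    """Average over the chain and draw axes, keeping one value per observation.

    Hypothetical helper mirroring what
    posterior_predictive['obs'].mean(('chain', 'draw')) computes.
    """
    n_obs = len(samples[0][0])
    return [
        fmean(draw[i] for chain in samples for draw in chain)
        for i in range(n_obs)
    ]


def mean_axis0_of_flattened(samples):
    """Stack chains and draws into one sample axis, then average that axis.

    Hypothetical helper mimicking mean(axis=0) on a (n_samples, n_obs) array.
    """
    flat = [draw for chain in samples for draw in chain]
    # zip(*flat) groups the values column-wise, i.e. per observation.
    return [fmean(column) for column in zip(*flat)]


print(mean_over_chain_and_draw(samples))
print(mean_axis0_of_flattened(samples))
```

Both calls print [4.5, 5.5, 8.5]: one mean per observed point, the same shape as the array shown in the docstring's expected output.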

Versions and main components

  • PyMC Version: 4.1.3
  • Aesara Version: 2.7.7
  • Python Version: 3.10
  • Operating system: Linux/Ubuntu
  • How did you install PyMC/PyMC3: conda
