Broken example in docstring of pm.set_data
The example given in the docstring of pm.set_data appears to be broken. When I run
>>> import pymc as pm
>>> with pm.Model() as model:
...     x = pm.MutableData('x', [1., 2., 3.])
...     y = pm.MutableData('y', [1., 2., 3.])
...     beta = pm.Normal('beta', 0, 1)
...     obs = pm.Normal('obs', x * beta, 1, observed=y)
...     idata = pm.sample(1000, tune=1000)
>>> with model:
...     pm.set_data({'x': [5., 6., 9.]})
...     y_test = pm.sample_posterior_predictive(idata)
>>> y_test['obs'].mean(axis=0)
array([4.6088569 , 5.54128318, 8.32953844])
I get this traceback:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
Traceback (most recent call last):
File "/home/usr/pymc/check_docstring.py", line 14, in <module>
y_test['obs'].mean(axis=0)
File "/home/usr/miniconda3/envs/pymc-dev/lib/python3.10/site-packages/arviz/data/inference_data.py", line 253, in __getitem__
raise KeyError(key)
KeyError: 'obs'
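This is because pm.sample_posterior_predictive returns an InferenceData object. Indexing an InferenceData object selects a group (such as posterior_predictive), not a variable, so y_test['obs'] raises a KeyError. A possible fix could be: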
with model:
    pm.set_data({'x': [5., 6., 9.]})
    y_test = pm.sample_posterior_predictive(idata)
y_test.posterior_predictive['obs'].mean(('chain', 'draw'))  # <-- fix
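The fix also changes how the mean is taken. The docstring's mean(axis=0) assumes a single leading sample axis, whereas the posterior_predictive group keeps separate chain and draw dimensions. The minimal NumPy sketch below uses random numbers as a hypothetical stand-in for real posterior predictive draws of obs, with an assumed shape of 4 chains x 1000 draws x 3 observations to match the run above. It shows that averaging over both the chain and draw axes gives the same result as averaging the flattened samples along axis 0:

```python
import numpy as np

# Hypothetical stand-in for posterior predictive draws of `obs`:
# 4 chains x 1000 draws x 3 observations (random numbers, not real PyMC output).
rng = np.random.default_rng(0)
draws = rng.normal(loc=[5.0, 6.0, 9.0], scale=1.0, size=(4, 1000, 3))

# NumPy analogue of `.mean(('chain', 'draw'))` on the xarray DataArray:
# average over the chain (axis 0) and draw (axis 1) dimensions.
mean_chain_draw = draws.mean(axis=(0, 1))

# NumPy analogue of the docstring's `mean(axis=0)` on a flat (n_samples, 3) array:
# stack chains and draws into one sample axis first, then average along it.
flat = draws.reshape(-1, draws.shape[-1])
mean_flat = flat.mean(axis=0)

print(mean_chain_draw.shape)  # (3,)
print(np.allclose(mean_chain_draw, mean_flat))  # True
```

Both reductions give equal weight to every draw because each chain has the same number of draws. That is why the two results agree up to floating-point rounding.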
Versions and main components
- PyMC Version: 4.1.3
- Aesara Version: 2.7.7
- Python Version: 3.10
- Operating system: Linux/Ubuntu
- How did you install PyMC/PyMC3: conda