Description
When a normal distribution is wrapped in `pm.Truncated` and given several observations, building the model fails with the following `ShapeError`:
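```python
import pymc as pm

with pm.Model() as truncated_model:
    mu = pm.Normal("mu", 0.0, 5.0)
    sigma = pm.HalfCauchy("sigma", 2.5)
    # Untruncated base distribution, created with .dist() so it is not registered in the model
    normal_dist = pm.Normal.dist(mu=mu, sigma=sigma)
    # Truncate to [-2, 2] and condition on three observations;
    # this call raises the ShapeError shown in the traceback below
    truncated_normal = pm.Truncated(
        "truncated_normal",
        dist=normal_dist,
        lower=-2.0,
        upper=2.0,
        observed=[-1, 0, 1],
    )
```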
The problem is probably in the code path where `pm.Truncated` returns a `TruncatedNormalRV` for a normal base distribution.
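To make the suspected failure mode concrete, here is a toy sketch. It assumes, without having confirmed it in the PyMC source, that the specialized path builds the `TruncatedNormalRV` without forwarding the shape that `Distribution.__new__` derives from the observed data (`kwargs["shape"] = tuple(observed.shape)` in the traceback below). `ToyRV`, `toy_truncated_normal` and the two `toy_truncated_*` functions are hypothetical stand-ins, not PyMC code.

```python
# Toy model of the suspected bug. Every name here is hypothetical; none of it is PyMC code.
from dataclasses import dataclass


@dataclass
class ToyRV:
    """Stand-in for a random variable that only records its shape."""
    op_name: str
    shape: tuple = ()

    @property
    def ndim(self) -> int:
        return len(self.shape)


def toy_truncated_normal(mu, sigma, lower, upper, shape=()):
    # Specialized constructor: respects whatever shape it is given.
    return ToyRV("TruncatedNormalRV", tuple(shape))


def toy_truncated_dropping_shape(mu, sigma, lower, upper, **kwargs):
    # Suspected buggy path: `shape` arrives in kwargs but is never forwarded,
    # so the returned RV is always a scalar (ndim 0).
    return toy_truncated_normal(mu, sigma, lower, upper)


def toy_truncated_forwarding_shape(mu, sigma, lower, upper, **kwargs):
    # Corrected path: the requested shape is passed through.
    return toy_truncated_normal(mu, sigma, lower, upper, shape=kwargs.get("shape", ()))


observed = [-1, 0, 1]
# Mirrors `kwargs["shape"] = tuple(observed.shape)` from Distribution.__new__
kwargs = {"shape": (len(observed),)}

buggy = toy_truncated_dropping_shape(0.0, 1.0, -2.0, 2.0, **kwargs)
fixed = toy_truncated_forwarding_shape(0.0, 1.0, -2.0, 2.0, **kwargs)
print(buggy.ndim, fixed.ndim)  # 0 1
```

If this hypothesis holds, the scalar RV returned by the buggy path is exactly what the `ndim` check in `make_obs_var` rejects.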
Traceback
```
---------------------------------------------------------------------------
ShapeError Traceback (most recent call last)
Input In [2], in <cell line: 1>()
3 sigma = pm.HalfCauchy("sigma", 2.5)
4 normal_dist = pm.Normal.dist(mu=mu, sigma=sigma)
----> 5 truncated_normal = pm.Truncated(
6 "truncated_normal",
7 dist=normal_dist,
8 lower=-2.,
9 upper=2.,
10 observed=[-1, 0, 1]
11 )
13 pm.model_to_graphviz(truncated_model)
File ~/Documents/GitHub/pymc/pymc/distributions/distribution.py:292, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
288 kwargs["shape"] = tuple(observed.shape)
290 rv_out = cls.dist(*args, **kwargs)
--> 292 rv_out = model.register_rv(
293 rv_out,
294 name,
295 observed,
296 total_size,
297 dims=dims,
298 transform=transform,
299 initval=initval,
300 )
302 # add in pretty-printing support
303 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
File ~/Documents/GitHub/pymc/pymc/model.py:1386, in Model.register_rv(self, rv_var, name, data, total_size, dims, transform, initval)
1379 raise TypeError(
1380 "Variables that depend on other nodes cannot be used for observed data."
1381 f"The data variable was: {data}"
1382 )
1384 # `rv_var` is potentially changed by `make_obs_var`,
1385 # for example into a new graph for imputation of missing data.
-> 1386 rv_var = self.make_obs_var(rv_var, data, dims, transform)
1388 return rv_var
File ~/Documents/GitHub/pymc/pymc/model.py:1412, in Model.make_obs_var(self, rv_var, data, dims, transform)
1409 data = convert_observed_data(data).astype(rv_var.dtype)
1411 if data.ndim != rv_var.ndim:
-> 1412 raise ShapeError(
1413 "Dimensionality of data and RV don't match.", actual=data.ndim, expected=rv_var.ndim
1414 )
1416 if aesara.config.compute_test_value != "off":
1417 test_value = getattr(rv_var.tag, "test_value", None)
ShapeError: Dimensionality of data and RV don't match. (actual 1 != expected 0)
```
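The error comes from the `ndim` comparison in `Model.make_obs_var` shown above: the observed list `[-1, 0, 1]` becomes a 1-d array, while the truncated RV is 0-d. A minimal NumPy sketch of that comparison follows. `check_observed_ndim` and this `ShapeError` class are simplified, hypothetical stand-ins, not PyMC's actual implementation.

```python
import numpy as np


class ShapeError(Exception):
    """Hypothetical, simplified stand-in for PyMC's ShapeError."""


def check_observed_ndim(data, rv_ndim):
    # Mirrors the check in Model.make_obs_var from the traceback:
    # the observed array must have the same number of dimensions as the RV.
    data = np.asarray(data, dtype=float)
    if data.ndim != rv_ndim:
        raise ShapeError(
            f"Dimensionality of data and RV don't match. (actual {data.ndim} != expected {rv_ndim})"
        )
    return data


# A scalar RV (ndim 0) cannot take three observations ...
try:
    check_observed_ndim([-1, 0, 1], rv_ndim=0)
except ShapeError as err:
    print(err)  # Dimensionality of data and RV don't match. (actual 1 != expected 0)

# ... whereas a vector RV (ndim 1) accepts them.
print(check_observed_ndim([-1, 0, 1], rv_ndim=1).shape)  # (3,)
```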
cc @ricardoV94