
Shape bug for observed truncated normal distribution initialized via pm.Truncated #6156

@larryshamalama

Description


When using a normal distribution within `pm.Truncated` with several observed values, we encounter the following `ShapeError`:

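```python
import pymc as pm

with pm.Model() as truncated_model:
    mu = pm.Normal("mu", 0., 5.)
    sigma = pm.HalfCauchy("sigma", 2.5)
    # Untruncated base distribution; pm.Truncated applies the bounds
    normal_dist = pm.Normal.dist(mu=mu, sigma=sigma)
    # Three observed values, so the RV's shape should be inferred as (3,).
    # This call raises the ShapeError shown in the traceback below.
    truncated_normal = pm.Truncated(
        "truncated_normal",
        dist=normal_dist,
        lower=-2.,
        upper=2.,
        observed=[-1, 0, 1]
    )
```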

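The error reports that the observed data is 1-dimensional while the registered RV is 0-dimensional, so the shape implied by `observed` (here `(3,)`) is lost along the way. The problem is likely in the branch of `pm.Truncated` that returns a `TruncatedNormalRV`.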
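For reference, the likelihood term here should be vector-valued: each of the three observations contributes its own truncated normal log-density, log f(x) = log φ(z) − log σ − log(Φ((upper − μ)/σ) − Φ((lower − μ)/σ)) with z = (x − μ)/σ. A minimal stdlib-only sketch of that density follows. It assumes fixed, illustrative values `mu=0.0` and `sigma=1.0` in place of the random `mu` and `sigma` from the model, and the helper names `truncated_normal_logpdf` and `std_normal_cdf` are hypothetical, not PyMC functions.

```python
import math


def std_normal_cdf(t):
    # Standard normal CDF expressed through the error function
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))


def truncated_normal_logpdf(x, mu, sigma, lower, upper):
    """Log-density of Normal(mu, sigma) truncated to [lower, upper] (hypothetical helper)."""
    if not (lower <= x <= upper):
        # Values outside the bounds have zero density
        return -math.inf
    z = (x - mu) / sigma
    log_phi = -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)
    # Normalizing constant: probability mass of the base normal inside the bounds
    mass = std_normal_cdf((upper - mu) / sigma) - std_normal_cdf((lower - mu) / sigma)
    return log_phi - math.log(sigma) - math.log(mass)


observed = [-1.0, 0.0, 1.0]
# One log-density per observation: the likelihood term has shape (3,), not ()
logps = [truncated_normal_logpdf(x, mu=0.0, sigma=1.0, lower=-2.0, upper=2.0) for x in observed]
```

The point of the sketch is only the shape: three observations yield three log-density terms, which is what a correctly shaped `TruncatedNormalRV` would need to support.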

Traceback
```
---------------------------------------------------------------------------
ShapeError                                Traceback (most recent call last)
Input In [2], in <cell line: 1>()
      3     sigma = pm.HalfCauchy("sigma", 2.5)
      4     normal_dist = pm.Normal.dist(mu=mu, sigma=sigma)
----> 5     truncated_normal = pm.Truncated(
      6         "truncated_normal", 
      7         dist=normal_dist, 
      8         lower=-2., 
      9         upper=2.,
     10         observed=[-1, 0, 1]
     11     )
     13 pm.model_to_graphviz(truncated_model)

File ~/Documents/GitHub/pymc/pymc/distributions/distribution.py:292, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    288         kwargs["shape"] = tuple(observed.shape)
    290 rv_out = cls.dist(*args, **kwargs)
--> 292 rv_out = model.register_rv(
    293     rv_out,
    294     name,
    295     observed,
    296     total_size,
    297     dims=dims,
    298     transform=transform,
    299     initval=initval,
    300 )
    302 # add in pretty-printing support
    303 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)

File ~/Documents/GitHub/pymc/pymc/model.py:1386, in Model.register_rv(self, rv_var, name, data, total_size, dims, transform, initval)
   1379         raise TypeError(
   1380             "Variables that depend on other nodes cannot be used for observed data."
   1381             f"The data variable was: {data}"
   1382         )
   1384     # `rv_var` is potentially changed by `make_obs_var`,
   1385     # for example into a new graph for imputation of missing data.
-> 1386     rv_var = self.make_obs_var(rv_var, data, dims, transform)
   1388 return rv_var

File ~/Documents/GitHub/pymc/pymc/model.py:1412, in Model.make_obs_var(self, rv_var, data, dims, transform)
   1409 data = convert_observed_data(data).astype(rv_var.dtype)
   1411 if data.ndim != rv_var.ndim:
-> 1412     raise ShapeError(
   1413         "Dimensionality of data and RV don't match.", actual=data.ndim, expected=rv_var.ndim
   1414     )
   1416 if aesara.config.compute_test_value != "off":
   1417     test_value = getattr(rv_var.tag, "test_value", None)

ShapeError: Dimensionality of data and RV don't match. (actual 1 != expected 0)
```
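To make the two relevant frames concrete, here is a toy, pure-Python model of them: `Distribution.__new__` infers a shape from `observed` and passes it to `.dist`, then `Model.make_obs_var` checks that the data and the RV have the same `ndim`. Everything below (`ToyRV`, `shape_of`, `register_observed`, the two factories, and this local `ShapeError`) is hypothetical and is not PyMC's API. It shows one way the exact message above can arise: if the variable returned by the dist factory does not honour the inferred shape, it stays scalar and the ndim check fails. Whether that is precisely what happens in the `TruncatedNormalRV` branch is an assumption.

```python
from dataclasses import dataclass


class ShapeError(Exception):
    """Local stand-in for pymc's ShapeError (hypothetical, illustration only)."""


@dataclass
class ToyRV:
    shape: tuple

    @property
    def ndim(self):
        return len(self.shape)


def shape_of(data):
    """Shape of a possibly nested list, e.g. [-1, 0, 1] -> (3,); a scalar -> ()."""
    shape = []
    while isinstance(data, (list, tuple)):
        shape.append(len(data))
        data = data[0] if data else None
    return tuple(shape)


def make_rv_respecting_shape(shape=()):
    # A dist factory that honours the shape it is given
    return ToyRV(shape=tuple(shape))


def make_rv_dropping_shape(shape=()):
    # A dist factory that silently ignores `shape` (hypothetical failure mode)
    return ToyRV(shape=())


def register_observed(dist_factory, observed):
    # Step 1 (cf. Distribution.__new__): infer the shape from the observed data
    shape = shape_of(observed)
    rv = dist_factory(shape=shape)
    # Step 2 (cf. Model.make_obs_var): data and RV must have the same ndim
    data_ndim = len(shape)
    if data_ndim != rv.ndim:
        raise ShapeError(
            f"Dimensionality of data and RV don't match. (actual {data_ndim} != expected {rv.ndim})"
        )
    return rv
```

With `make_rv_respecting_shape`, `register_observed(..., [-1, 0, 1])` returns a variable of shape `(3,)`; with `make_rv_dropping_shape`, it raises the same "(actual 1 != expected 0)" message as the traceback.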

cc @ricardoV94
