Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
from pymc.logprob.basic import logp
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.model import BlockModelAccess
from pymc.model import new_or_existing_block_model_access
from pymc.printing import str_for_dist
from pymc.pytensorf import collect_default_updates, convert_observed_data, floatX
from pymc.util import UNSET, _add_future_warning_tag
Expand Down Expand Up @@ -645,7 +645,7 @@ def rv_op(
size = normalize_size_param(size)
dummy_size_param = size.type()
dummy_dist_params = [dist_param.type() for dist_param in dist_params]
with BlockModelAccess(
with new_or_existing_block_model_access(
error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API"
):
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
Expand Down Expand Up @@ -1048,7 +1048,7 @@ def is_symbolic_random(self, random, dist_params):
# Try calling random with symbolic inputs
try:
size = normalize_size_param(None)
with BlockModelAccess(
with new_or_existing_block_model_access(
error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API to create such variables."
):
out = random(*dist_params, size)
Expand Down
12 changes: 10 additions & 2 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __init__(
cls._context_class = context_class
super().__init__(name, bases, nmspc)

def get_context(cls, error_if_none=True) -> Optional[T]:
def get_context(cls, error_if_none=True, allow_block_model_access=False) -> Optional[T]:
"""Return the most recently pushed context object of type ``cls``
on the stack, or ``None``. If ``error_if_none`` is True (default),
raise a ``TypeError`` instead of returning ``None``."""
Expand All @@ -155,7 +155,7 @@ def get_context(cls, error_if_none=True) -> Optional[T]:
if error_if_none:
raise TypeError(f"No {cls} on context stack")
return None
if isinstance(candidate, BlockModelAccess):
if isinstance(candidate, BlockModelAccess) and not allow_block_model_access:
raise BlockModelAccessError(candidate.error_msg_on_access)
return candidate

Expand Down Expand Up @@ -1889,6 +1889,14 @@ def __init__(self, *args, error_msg_on_access="Model access is blocked", **kwarg
self.error_msg_on_access = error_msg_on_access


def new_or_existing_block_model_access(*args, **kwargs):
"""Return a BlockModelAccess in the stack or create a new one if none is found."""
model = Model.get_context(error_if_none=False, allow_block_model_access=True)
if isinstance(model, BlockModelAccess):
return model
return BlockModelAccess(*args, **kwargs)


def set_data(new_data, model=None, *, coords=None):
"""Sets the value of one or more data container variables. Note that the shape is also
dynamic, it is updated when the value is changed. See the examples below for two common
Expand Down
16 changes: 16 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,22 @@ def dist(size):

assert pm.CustomDist.dist(dist=dist)

def test_nested_custom_dist(self):
"""Test we can create CustomDist that creates another CustomDist"""

def dist(size=None):
def inner_dist(size=None):
return pm.Normal.dist(size=size)

inner_dist = pm.CustomDist.dist(dist=inner_dist, size=size)
return pt.exp(inner_dist)

rv = pm.CustomDist.dist(dist=dist)
np.testing.assert_allclose(
pm.logp(rv, 1.0).eval(),
pm.logp(pm.LogNormal.dist(), 1.0).eval(),
)


class TestSymbolicRandomVariable:
def test_inline(self):
Expand Down