Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,8 @@ def rv_op(
):
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param, *dummy_dist_params]
# RNGs are not passed as explicit inputs (because we usually don't know how many are needed)
# We retrieve them here. This will also raise if the user forgot to specify some update in a Scan Op
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))

rv_type = type(
Expand Down Expand Up @@ -1001,12 +1003,7 @@ def change_custom_dist_size(op, rv, new_size, expand):

return new_rv

# RNGs are not passed as explicit inputs (because we usually don't know how many are needed)
# We retrieve them here
updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
rngs = updates_dict.keys()
rngs_updates = updates_dict.values()

rngs, rngs_updates = zip(*dummy_updates_dict.items())
inputs = [*dummy_params, *rngs]
outputs = [dummy_rv, *rngs_updates]
signature = cls._infer_final_signature(
Expand Down
54 changes: 30 additions & 24 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,28 +106,34 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
]
graph_inputs = [*rv_inputs, lower, upper]

rv = dist.owner.op.make_node(*rv_inputs).default_output()
# Variables with `_` suffix identify dummy inputs for the OpFromGraph
graph_inputs_ = [
inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
]
*rv_inputs_, lower_, upper_ = graph_inputs_

rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()

# Try to use inverted cdf sampling
# truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
try:
logcdf_lower, logcdf_upper = cls._create_logcdf_exprs(rv, rv, lower, upper)
logcdf_lower_, logcdf_upper_ = TruncatedRV._create_logcdf_exprs(
rv_, rv_, lower_, upper_
)
# We use the first RNG from the base RV, so we don't have to introduce a new one
# This is not problematic because the RNG won't be used in the RV logcdf graph
uniform_rng = next(inp for inp in rv_inputs if isinstance(inp.type, RandomType))
uniform_next_rng, uniform = pt.random.uniform(
pt.exp(logcdf_lower),
pt.exp(logcdf_upper),
rng=uniform_rng,
size=rv.shape,
uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
uniform_next_rng_, uniform_ = pt.random.uniform(
pt.exp(logcdf_lower_),
pt.exp(logcdf_upper_),
rng=uniform_rng_,
size=rv_.shape,
).owner.outputs
# So icdf does not see the random graph of uniform
uniform_type = uniform.type()
truncated_rv = graph_replace(icdf(rv, uniform_type), {uniform_type: uniform})
truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs,
outputs=[truncated_rv, uniform_next_rng],
inputs=graph_inputs_,
outputs=[truncated_rv_, uniform_next_rng_],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)
Expand All @@ -154,25 +160,25 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):

return (
(truncated_rv, reject_draws),
collect_default_updates(new_truncated_rv, inputs=rv_inputs),
collect_default_updates(new_truncated_rv),
until(~pt.any(reject_draws)),
)

(truncated_rv, reject_draws_), updates = scan(
(truncated_rv_, reject_draws_), updates = scan(
loop_fn,
outputs_info=[
pt.zeros_like(rv),
pt.ones_like(rv, dtype=bool),
pt.zeros_like(rv_),
pt.ones_like(rv_, dtype=bool),
],
non_sequences=[lower, upper, *rv_inputs],
non_sequences=[lower_, upper_, *rv_inputs_],
n_steps=max_n_steps,
strict=True,
)

truncated_rv = truncated_rv[-1]
convergence = ~pt.any(reject_draws_[-1])
truncated_rv = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
truncated_rv, convergence
truncated_rv_ = truncated_rv_[-1]
convergence_ = ~pt.any(reject_draws_[-1])
truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
truncated_rv_, convergence_
)

# Sort updates of each RNG so that they show in the same order as the input RNGs
Expand All @@ -184,8 +190,8 @@ def sort_updates(update):

return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs,
outputs=[truncated_rv, *next_rngs],
inputs=graph_inputs_,
outputs=[truncated_rv_, *next_rngs],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)
Expand Down
14 changes: 14 additions & 0 deletions tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,17 @@ def test_truncated_identity_input(dist_op):

rv_out = Truncated.dist(dist=dist_op(mu_identity, 5), lower=0, upper=1)
assert np.ptp(draw(rv_out, draws=500)) < 1


@pytest.mark.parametrize("rv_op", [icdf_normal, rejection_normal])
def test_truncated_custom_dist_indexed_argument(rv_op):
# Regression test for https://github.com/pymc-devs/pymc/issues/7312

def dist(scale, size):
return pt.exp(rv_op(scale=scale, size=size))

scale = Exponential.dist(scale=[1, 2, 3])
latent = CustomDist.dist(scale[[0, 0, 1, 1, 2, 2]], dist=dist)
rv_out = Truncated.dist(latent, upper=7)

assert np.ptp(draw(rv_out, draws=100)) < 7