4 changes: 3 additions & 1 deletion pymc/distributions/distribution.py
@@ -593,7 +593,9 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):

def update(self, node: Node):
op = node.op
inner_updates = collect_default_updates(op.inner_inputs, op.inner_outputs)
inner_updates = collect_default_updates(
op.inner_inputs, op.inner_outputs, must_be_shared=False
)

# Map inner updates to outer inputs/outputs
updates = {}
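For context, here is a minimal sketch (not part of this diff) of why `must_be_shared=False` is needed when collecting updates from an inner graph: inside a `SymbolicRandomVariable`, the RNG inputs are plain variables rather than `SharedVariable`s, so the default filter would skip them. The names `inner_rng`, `next_rng`, and `x` below are illustrative only.

import numpy as np
import pytensor
import pytensor.tensor as pt

from pymc.pytensorf import collect_default_updates

# A non-shared RNG variable, analogous to the RNG inputs of an inner graph
inner_rng = pytensor.shared(np.random.default_rng()).type()
next_rng, x = pt.random.normal(rng=inner_rng).owner.outputs

# With the default must_be_shared=True, the non-shared RNG is skipped
assert collect_default_updates([inner_rng], [x]) == {}
# With must_be_shared=False, its symbolic update expression is still recovered
assert collect_default_updates([inner_rng], [x], must_be_shared=False) == {inner_rng: next_rng}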
101 changes: 72 additions & 29 deletions pymc/pytensorf.py
@@ -42,7 +42,6 @@
Variable,
clone_get_equiv,
graph_inputs,
vars_between,
walk,
)
from pytensor.graph.fg import FunctionGraph
@@ -51,6 +50,7 @@
from pytensor.tensor.basic import _as_tensor_variable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
@@ -1000,42 +1000,85 @@ def reseed_rngs(


def collect_default_updates(
inputs: Sequence[Variable], outputs: Sequence[Variable]
inputs: Sequence[Variable],
outputs: Sequence[Variable],
must_be_shared: bool = True,
) -> Dict[Variable, Variable]:
"""Collect default update expression of RVs between inputs and outputs"""
"""Collect default update expression for shared-variable RNGs used by RVs between inputs and outputs.

If `must_be_shared` is False, update expressions will also be returned for non-shared input RNGs.
This can be useful to obtain the symbolic update expressions from inner graphs.
"""

# Avoid circular import
from pymc.distributions.distribution import SymbolicRandomVariable

def find_default_update(clients, rng: Variable) -> Union[None, Variable]:
rng_clients = clients.get(rng, None)

# Root case, RNG is not used elsewhere
if not rng_clients:
return rng

if len(rng_clients) > 1:
warnings.warn(
f"RNG Variable {rng} has multiple clients. This is likely an inconsistent random graph.",
UserWarning,
)
return None

[client, _] = rng_clients[0]

# RNG is an output of the function, this is not a problem
if client == "output":
return rng

# RNG is used by another operator, which should output an update for the RNG
if isinstance(client.op, RandomVariable):
# RandomVariable first output is always the update of the input RNG
next_rng = client.outputs[0]

elif isinstance(client.op, SymbolicRandomVariable):
# SymbolicRandomVariables have an explicit method that returns an
# SymbolicRandomVariable have an explicit method that returns an
# update mapping for their RNG(s)
next_rng = client.op.update(client).get(rng)
if next_rng is None:
raise ValueError(
f"No update mapping found for RNG used in SymbolicRandomVariable Op {client.op}"
)
else:
# We don't know how this RNG should be updated (e.g., Scan).
# The user should provide an update manually
return None

# Recurse until we find final update for RNG
return find_default_update(clients, next_rng)

outputs = makeiter(outputs)
fg = FunctionGraph(outputs=outputs, clone=False)
clients = fg.clients

rng_updates = {}
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
for random_var in (
var
for var in vars_between(inputs, output_to_list)
if var.owner
and isinstance(var.owner.op, (RandomVariable, SymbolicRandomVariable))
and var not in inputs
# Iterate over input RNGs. Only consider shared RNGs if `must_be_shared==True`
for input_rng in (
inp
for inp in graph_inputs(outputs, blockers=inputs)
if (
(not must_be_shared or isinstance(inp, SharedVariable))
and isinstance(inp.type, RandomType)
)
):
# All nodes in `vars_between(inputs, outputs)` have owners.
# But mypy doesn't know, so we just assert it:
assert random_var.owner.op is not None
if isinstance(random_var.owner.op, RandomVariable):
rng = random_var.owner.inputs[0]
if getattr(rng, "default_update", None) is not None:
update_map = {rng: rng.default_update}
else:
update_map = {rng: random_var.owner.outputs[0]}
# Even if an explicit default update is provided, we still call find_default_update
# so that any warnings about invalid random graphs are issued.
default_update = find_default_update(clients, input_rng)

# Respect default update if provided
if getattr(input_rng, "default_update", None):
rng_updates[input_rng] = input_rng.default_update
else:
update_map = random_var.owner.op.update(random_var.owner)
# Check that we are not setting different update expressions for the same variables
for rng, update in update_map.items():
if rng not in rng_updates:
rng_updates[rng] = update
# When a variable has multiple outputs, it will be called twice with the same
# update expression. We don't want to raise in that case, only if the update
# expression in different from the one already registered
elif rng_updates[rng] is not update:
raise ValueError(f"Multiple update expressions found for the variable {rng}")
if default_update is not None:
rng_updates[input_rng] = default_update

return rng_updates
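
As a hedged usage sketch (mirroring the tests below, not an official example): the mapping returned by `collect_default_updates` can be passed as `updates` to `pytensor.function`, so the shared RNG state advances on every call and chained draws stay consistent.

import numpy as np
import pytensor
import pytensor.tensor as pt

from pymc.pytensorf import collect_default_updates

rng = pytensor.shared(np.random.default_rng(123))
next_rng, x = pt.random.normal(rng=rng).owner.outputs
y = pt.random.normal(x, rng=next_rng)  # second draw chained on the updated RNG

# Only the final update of the chain is returned: {rng: y.owner.outputs[0]}
updates = collect_default_updates(inputs=[], outputs=[y])
fn = pytensor.function([], y, updates=updates)
assert fn() != fn()  # the RNG state is updated in place between calls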


127 changes: 106 additions & 21 deletions tests/test_pytensorf.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from unittest import mock

import numpy as np
@@ -38,6 +40,7 @@
from pymc.exceptions import NotConstantValueError
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import (
collect_default_updates,
compile_pymc,
constant_fold,
convert_observed_data,
@@ -406,28 +409,63 @@ def test_compile_pymc_updates_inputs(self):
# Each RV adds a shared output for its rng
assert len(fn_fgraph.outputs) == 1 + rvs_in_graph

# Disable `reseed_rngs` so that we can test with simpler update rule
@mock.patch("pymc.pytensorf.reseed_rngs")
def test_compile_pymc_custom_update_op(self, _):
"""Test that custom MeasurableVariable Op updates are used by compile_pymc"""
def test_compile_pymc_symbolic_rv_update(self):
"""Test that SymbolicRandomVariable Op update methods are used by compile_pymc"""

class NonSymbolicRV(OpFromGraph):
def update(self, node):
return {node.inputs[0]: node.inputs[0] + 1}
return {node.inputs[0]: node.outputs[0]}

dummy_inputs = [pt.scalar(), pt.scalar()]
dummy_outputs = [pt.add(*dummy_inputs)]
dummy_x = NonSymbolicRV(dummy_inputs, dummy_outputs)(pytensor.shared(1.0), 1.0)
rng = pytensor.shared(np.random.default_rng())
dummy_rng = rng.type()
dummy_next_rng, dummy_x = NonSymbolicRV(
[dummy_rng], pt.random.normal(rng=dummy_rng).owner.outputs
)(rng)

# Check that there are no updates at first
fn = compile_pymc(inputs=[], outputs=dummy_x)
assert fn() == fn() == 2.0
assert fn() == fn()

# And they are enabled once the Op is registered as a SymbolicRV
SymbolicRandomVariable.register(NonSymbolicRV)
fn = compile_pymc(inputs=[], outputs=dummy_x)
assert fn() == 2.0
assert fn() == 3.0
fn = compile_pymc(inputs=[], outputs=dummy_x, random_seed=431)
assert fn() != fn()

def test_compile_pymc_symbolic_rv_missing_update(self):
"""Test that error is raised if SymbolicRandomVariable Op does not
provide rule for updating RNG"""

class SymbolicRV(OpFromGraph):
def update(self, node):
# Update is provided for rng1 but not rng2
return {node.inputs[0]: node.outputs[0]}

SymbolicRandomVariable.register(SymbolicRV)

# No problems at first, as the one RNG is given the update rule
rng1 = pytensor.shared(np.random.default_rng())
dummy_rng1 = rng1.type()
dummy_next_rng1, dummy_x1 = SymbolicRV(
[dummy_rng1],
pt.random.normal(rng=dummy_rng1).owner.outputs,
)(rng1)
fn = compile_pymc(inputs=[], outputs=dummy_x1, random_seed=433)
assert fn() != fn()

# Now there's a problem as there is no update rule for rng2
rng2 = pytensor.shared(np.random.default_rng())
dummy_rng2 = rng2.type()
dummy_next_rng1, dummy_x1, dummy_next_rng2, dummy_x2 = SymbolicRV(
[dummy_rng1, dummy_rng2],
[
*pt.random.normal(rng=dummy_rng1).owner.outputs,
*pt.random.normal(rng=dummy_rng2).owner.outputs,
],
)(rng1, rng2)
with pytest.raises(
ValueError, match="No update mapping found for RNG used in SymbolicRandomVariable"
):
compile_pymc(inputs=[], outputs=[dummy_x1, dummy_x2])
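
For comparison, a hedged sketch (not part of the PR) of an Op whose `update` method covers both RNGs, so `compile_pymc` finds an update for each of them; the output layout matches the `SymbolicRV` built above, and `CompleteSymbolicRV` is a hypothetical name.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

from pymc.distributions.distribution import SymbolicRandomVariable
from pymc.pytensorf import compile_pymc

class CompleteSymbolicRV(OpFromGraph):
    def update(self, node):
        # Outputs are [next_rng1, x1, next_rng2, x2], so each input RNG gets an update
        return {node.inputs[0]: node.outputs[0], node.inputs[1]: node.outputs[2]}

SymbolicRandomVariable.register(CompleteSymbolicRV)

rng1 = pytensor.shared(np.random.default_rng())
rng2 = pytensor.shared(np.random.default_rng())
dummy_rng1, dummy_rng2 = rng1.type(), rng2.type()
next_rng1, x1, next_rng2, x2 = CompleteSymbolicRV(
    [dummy_rng1, dummy_rng2],
    [
        *pt.random.normal(rng=dummy_rng1).owner.outputs,
        *pt.random.normal(rng=dummy_rng2).owner.outputs,
    ],
)(rng1, rng2)

fn = compile_pymc(inputs=[], outputs=[x1, x2], random_seed=435)
first, second = fn(), fn()
assert first[0] != second[0] and first[1] != second[1]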

def test_random_seed(self):
seedx = pytensor.shared(np.random.default_rng(1))
@@ -457,15 +495,62 @@ def test_random_seed(self):
assert y3_eval == y2_eval

def test_multiple_updates_same_variable(self):
rng = pytensor.shared(np.random.default_rng(), name="rng")
x = pt.random.normal(rng=rng)
y = pt.random.normal(rng=rng)

assert compile_pymc([], [x])
assert compile_pymc([], [y])
msg = "Multiple update expressions found for the variable rng"
with pytest.raises(ValueError, match=msg):
compile_pymc([], [x, y])
# Raise if unexpected warning is issued
with warnings.catch_warnings():
warnings.simplefilter("error")

rng = pytensor.shared(np.random.default_rng(), name="rng")
x = pt.random.normal(rng=rng)
y = pt.random.normal(rng=rng)

# No warnings if only one variable is used
assert compile_pymc([], [x])
assert compile_pymc([], [y])

user_warn_msg = "RNG Variable rng has multiple clients"
with pytest.warns(UserWarning, match=user_warn_msg):
f = compile_pymc([], [x, y], random_seed=456)
assert f() == f()

# The user can provide an explicit update, but we will still issue a warning
with pytest.warns(UserWarning, match=user_warn_msg):
f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
assert f() != f()

# Same with default update
rng.default_update = x.owner.outputs[0]
with pytest.warns(UserWarning, match=user_warn_msg):
f = compile_pymc([], [x, y], updates={rng: y.owner.outputs[0]}, random_seed=456)
assert f() != f()

def test_nested_updates(self):
rng = pytensor.shared(np.random.default_rng())
next_rng1, x = pt.random.normal(rng=rng).owner.outputs
next_rng2, y = pt.random.normal(rng=next_rng1).owner.outputs
next_rng3, z = pt.random.normal(rng=next_rng2).owner.outputs

assert collect_default_updates([], [x, y, z]) == {rng: next_rng3}

fn = compile_pymc([], [x, y, z], random_seed=514)
assert not set(list(np.array(fn()))) & set(list(np.array(fn())))

# A local myopic rule (as PyMC used before) would not work properly
fn = pytensor.function([], [x, y, z], updates={rng: next_rng1})
assert set(list(np.array(fn()))) & set(list(np.array(fn())))


def test_collect_default_updates_must_be_shared():
shared_rng = pytensor.shared(np.random.default_rng())
nonshared_rng = shared_rng.type()

next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs
next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs

res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y])
assert res == {shared_rng: next_rng_of_shared}

res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False)
assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared}


def test_replace_rng_nodes():