From c0a45d4d5b01423723d34f3e5c9278bf93e33a66 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Thu, 20 Apr 2023 18:43:02 +0530 Subject: [PATCH 1/2] Make logprob inference for binary ops independent of order of inputs --- pymc/logprob/binary.py | 42 +++++++++++++++++++++++++++--------- tests/logprob/test_binary.py | 39 ++++++++++++++++++++++----------- 2 files changed, 58 insertions(+), 23 deletions(-) diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index 72224d394..9a2d73fcb 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -20,6 +20,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter from pytensor.scalar.basic import GE, GT, LE, LT +from pytensor.tensor import TensorVariable from pytensor.tensor.math import ge, gt, le, lt from pymc.logprob.abstract import ( @@ -50,26 +51,47 @@ def find_measurable_comparisons( if isinstance(node.op, MeasurableComparison): return None # pragma: no cover - (compared_var,) = node.outputs - base_var, const = node.inputs + measurable_inputs = [ + (inp, idx) + for idx, inp in enumerate(node.inputs) + if inp.owner + and isinstance(inp.owner.op, MeasurableVariable) + and inp not in rv_map_feature.rv_values + ] - if not ( - base_var.owner - and isinstance(base_var.owner.op, MeasurableVariable) - and base_var not in rv_map_feature.rv_values - ): + if len(measurable_inputs) != 1: return None + base_var: TensorVariable = measurable_inputs[0][0] + + # Check that the other input is not potentially measurable, in which case this rewrite + # would be invalid + const = tuple(inp for inp in node.inputs if inp is not base_var) + # check for potential measurability of const - if not check_potential_measurability((const,), rv_map_feature): + if not check_potential_measurability(const, rv_map_feature): return None + const = const[0] + # Make base_var unmeasurable unmeasurable_base_var = ignore_logprob(base_var) - compared_op = MeasurableComparison(node.op.scalar_op) + node_scalar_op = node.op.scalar_op + + if measurable_inputs[0][1] == 1: + if isinstance(node_scalar_op, LT): + node_scalar_op = GT() + elif isinstance(node_scalar_op, GT): + node_scalar_op = LT() + elif isinstance(node_scalar_op, GE): + node_scalar_op = LE() + elif isinstance(node_scalar_op, LE): + node_scalar_op = GE() + + compared_op = MeasurableComparison(node_scalar_op) compared_rv = compared_op.make_node(unmeasurable_base_var, const).default_output() - compared_rv.name = compared_var.name + compared_rv.name = node.outputs[0].name return [compared_rv] diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py index 0780dcf09..6f2248075 100644 --- a/tests/logprob/test_binary.py +++ b/tests/logprob/test_binary.py @@ -25,16 +25,17 @@ @pytest.mark.parametrize( - "comparison_op, exp_logp_true, exp_logp_false", + "comparison_op, exp_logp_true, exp_logp_false, inputs", [ - ((pt.lt, pt.le), "logcdf", "logsf"), - ((pt.gt, pt.ge), "logsf", "logcdf"), + ((pt.lt, pt.le), "logcdf", "logsf", (pt.random.normal(0, 1), 0.5)), + ((pt.gt, pt.ge), "logsf", "logcdf", (pt.random.normal(0, 1), 0.5)), + ((pt.lt, pt.le), "logsf", "logcdf", (0.5, pt.random.normal(0, 1))), + ((pt.gt, pt.ge), "logcdf", "logsf", (0.5, pt.random.normal(0, 1))), ], ) -def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false): - x_rv = pt.random.normal(0, 1) +def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false, inputs): for op in comparison_op: - comp_x_rv = op(x_rv, 0.5) + comp_x_rv = op(*inputs) comp_x_vv = comp_x_rv.clone() @@ -49,33 +50,45 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false): @pytest.mark.parametrize( - "comparison_op, exp_logp_true, exp_logp_false", + "comparison_op, exp_logp_true, exp_logp_false, inputs", [ ( pt.lt, lambda x: st.poisson(2).logcdf(x - 1), lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)), + (pt.random.poisson(2), 3), ), ( pt.ge, lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)), lambda x: st.poisson(2).logcdf(x - 1), + (pt.random.poisson(2), 3), ), + (pt.gt, st.poisson(2).logsf, st.poisson(2).logcdf, (pt.random.poisson(2), 3)), + (pt.le, st.poisson(2).logcdf, st.poisson(2).logsf, (pt.random.poisson(2), 3)), ( - pt.gt, + pt.lt, st.poisson(2).logsf, st.poisson(2).logcdf, + (3, pt.random.poisson(2)), + ), + (pt.ge, st.poisson(2).logcdf, st.poisson(2).logsf, (3, pt.random.poisson(2))), + ( + pt.gt, + lambda x: st.poisson(2).logcdf(x - 1), + lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)), + (3, pt.random.poisson(2)), ), ( pt.le, - st.poisson(2).logcdf, - st.poisson(2).logsf, + lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)), + lambda x: st.poisson(2).logcdf(x - 1), + (3, pt.random.poisson(2)), ), ], ) -def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false): - x_rv = pt.random.poisson(2) - cens_x_rv = comparison_op(x_rv, 3) +def test_discrete_rv_comparison(inputs, comparison_op, exp_logp_true, exp_logp_false): + cens_x_rv = comparison_op(*inputs) cens_x_vv = cens_x_rv.clone() From 3009cde5bccc5c0111d5209f73f81e0e82bde095 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Fri, 21 Apr 2023 14:24:12 +0530 Subject: [PATCH 2/2] Add comment explaining change of op based on the order of inputs. --- pymc/logprob/binary.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index 9a2d73fcb..d9aeb3b57 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -62,6 +62,7 @@ def find_measurable_comparisons( if len(measurable_inputs) != 1: return None + # Make the measurable base_var always be the first input to the MeasurableComparison node base_var: TensorVariable = measurable_inputs[0][0] # Check that the other input is not potentially measurable, in which case this rewrite @@ -79,6 +80,7 @@ def find_measurable_comparisons( node_scalar_op = node.op.scalar_op + # Change the Op if the base_var is the second input in node.inputs. e.g. pt.lt(const, dist) -> pt.gt(dist, const) if measurable_inputs[0][1] == 1: if isinstance(node_scalar_op, LT): node_scalar_op = GT()