diff --git a/aeppl/transforms.py b/aeppl/transforms.py index fa8929ad..eb6ef7dc 100644 --- a/aeppl/transforms.py +++ b/aeppl/transforms.py @@ -10,8 +10,21 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter -from aesara.scalar import Add, Exp, Log, Mul +from aesara.scalar import ( + Add, + Exp, + Log, + Mul, + Neg, + Pow, + Reciprocal, + Sqr, + Sqrt, + Sub, + TrueDiv, +) from aesara.tensor.elemwise import Elemwise +from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.rewriting.basic import ( register_specialize, register_stabilize, @@ -86,8 +99,11 @@ def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable: """Apply the transformation.""" @abc.abstractmethod - def backward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable: - """Invert the transformation.""" + def backward( + self, value: TensorVariable, *inputs: Variable + ) -> Union[TensorVariable, Tuple[TensorVariable, ...]]: + """Invert the transformation. Multible values bay be returned when the + transformation is not 1-to-1""" def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable: """Construct the log of the absolute value of the Jacobian determinant.""" @@ -217,7 +233,7 @@ def apply(self, fgraph: FunctionGraph): class MeasurableTransform(MeasurableElemwise): """A placeholder used to specify a log-likelihood for a transformed measurable variable""" - valid_scalar_types = (Exp, Log, Add, Mul) + valid_scalar_types = (Exp, Log, Add, Mul, Pow) # Cannot use `transform` as name because it would clash with the property added by # the `TransformValuesRewrite` @@ -248,13 +264,157 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa # The value variable must still be back-transformed to be on the natural support of # the respective measurable input. 
@_logprob.register(MeasurableTransform)
def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwargs):
    """Compute the log-probability graph for a `MeasurableTransform`."""
    # NOTE(review): the lines up to `backward_value` are outside the visible
    # hunk and reconstructed from the surrounding context -- confirm against
    # the full file.
    (value,) = values

    other_inputs = list(inputs)
    measurable_input = other_inputs.pop(op.measurable_input_idx)

    # The value variable must still be back-transformed to be on the natural support of
    # the respective measurable input.
    backward_value = op.transform_elemwise.backward(value, *other_inputs)

    # Some transformations, like squaring, may produce multiple backward values
    if isinstance(backward_value, tuple):
        input_logprob = at.logaddexp(
            *(
                logprob(measurable_input, backward_val, **kwargs)
                for backward_val in backward_value
            )
        )
    else:
        # BUG FIX: `**kwargs` was dropped from this call when the branch was
        # introduced; the replaced line forwarded it to `logprob`.
        input_logprob = logprob(measurable_input, backward_value, **kwargs)

    jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)

    return input_logprob + jacobian


@node_rewriter([Elemwise])
def measurable_reciprocal_to_product(fgraph, node):
    """Convert reciprocal of `MeasurableVariable`s to a power (``x**-1``)."""
    if isinstance(node.op.scalar_op, Reciprocal):
        inp = node.inputs[0]
        if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
            return None

        rv_map_feature: Optional[PreserveRVMappings] = getattr(
            fgraph, "preserve_rv_mappings", None
        )
        if rv_map_feature is None:
            return None  # pragma: no cover

        # Only apply this rewrite if the variable is unvalued
        if inp in rv_map_feature.rv_values:
            return None  # pragma: no cover

        return [at.pow(inp, -1.0)]


@node_rewriter([Elemwise])
def measurable_div_to_power_product(fgraph, node):
    """Convert divisions involving `MeasurableVariable`s to a power product."""
    if isinstance(node.op.scalar_op, TrueDiv):
        measurable_vars = [
            var
            for var in node.inputs
            if (var.owner and isinstance(var.owner.op, MeasurableVariable))
        ]
        if not measurable_vars:
            return None  # pragma: no cover

        rv_map_feature: Optional[PreserveRVMappings] = getattr(
            fgraph, "preserve_rv_mappings", None
        )
        if rv_map_feature is None:
            return None  # pragma: no cover

        # Only apply this rewrite if there is one unvalued MeasurableVariable involved
        if all(
            measurable_var in rv_map_feature.rv_values
            for measurable_var in measurable_vars
        ):
            return None  # pragma: no cover

        numerator, denominator = node.inputs

        # When the numerator is the constant 1, the division is a plain
        # reciprocal, which can be expressed directly as a power transform.
        try:
            if at.get_scalar_constant_value(numerator) == 1:
                # We convert the denominator directly to a power transform as this
                # must be the measurable input
                return [at.pow(denominator, -1.0)]
        except NotScalarConstantError:
            pass
        # General case: rewrite x / y as x * y**-1 so the downstream scale /
        # power rewrites can pick the measurable factor apart.
        return [at.mul(numerator, at.reciprocal(denominator))]


@node_rewriter([Elemwise])
def measurable_sqrt_sqr_to_power(fgraph, node):
    """Convert square root or square of `MeasurableVariable`s to power form."""
    if isinstance(node.op.scalar_op, (Sqr, Sqrt)):
        inp = node.inputs[0]
        if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
            return None

        rv_map_feature: Optional[PreserveRVMappings] = getattr(
            fgraph, "preserve_rv_mappings", None
        )
        if rv_map_feature is None:
            return None  # pragma: no cover

        # Only apply this rewrite if the variable is unvalued
        if inp in rv_map_feature.rv_values:
            return None  # pragma: no cover

        if isinstance(node.op.scalar_op, Sqr):
            return [at.pow(inp, 2)]

        if isinstance(node.op.scalar_op, Sqrt):
            return [at.pow(inp, 1 / 2)]


@node_rewriter([Elemwise])
def measurable_neg_to_product(fgraph, node):
    """Convert negation of `MeasurableVariable`s to a product with ``-1``."""
    if isinstance(node.op.scalar_op, Neg):
        inp = node.inputs[0]
        if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
            return None

        rv_map_feature: Optional[PreserveRVMappings] = getattr(
            fgraph, "preserve_rv_mappings", None
        )
        if rv_map_feature is None:
            return None  # pragma: no cover

        # Only apply this rewrite if the variable is unvalued
        if inp in rv_map_feature.rv_values:
            return None  # pragma: no cover

        return [at.mul(inp, -1.0)]


@node_rewriter([Elemwise])
def measurable_sub_to_neg(fgraph, node):
    """Convert subtraction involving `MeasurableVariable`s to addition with neg."""
    if isinstance(node.op.scalar_op, Sub):
        measurable_vars = [
            var
            for var in node.inputs
            if (var.owner and isinstance(var.owner.op, MeasurableVariable))
        ]
        if not measurable_vars:
            return None  # pragma: no cover

        rv_map_feature: Optional[PreserveRVMappings] = getattr(
            fgraph, "preserve_rv_mappings", None
        )
        if rv_map_feature is None:
            return None  # pragma: no cover

        # Only apply this rewrite if there is one unvalued MeasurableVariable involved
        if all(
            measurable_var in rv_map_feature.rv_values
            for measurable_var in measurable_vars
        ):
            return None  # pragma: no cover

        minuend, subtrahend = node.inputs
        return [at.add(minuend, at.neg(subtrahend))]


# BUG FIX: `measurable_reciprocal_to_product` was defined but never registered,
# so the reciprocal rewrite could never fire.
measurable_ir_rewrites_db.register(
    "measurable_reciprocal_to_product",
    measurable_reciprocal_to_product,
    -5,
    "basic",
    "transform",
)

measurable_ir_rewrites_db.register(
    "measurable_div_to_power_product",
    measurable_div_to_power_product,
    -5,
    "basic",
    "transform",
)

measurable_ir_rewrites_db.register(
    "measurable_sqrt_sqr_to_power",
    measurable_sqrt_sqr_to_power,
    -5,
    "basic",
    "transform",
)

measurable_ir_rewrites_db.register(
    "measurable_neg_to_product",
    measurable_neg_to_product,
    -5,
    "basic",
    "transform",
)

measurable_ir_rewrites_db.register(
    "measurable_sub_to_neg",
    measurable_sub_to_neg,
    -5,
    "basic",
    "transform",
)
class PowerTransform(RVTransform):
    """Transform a variable through a constant power: ``forward(x) = x**power``."""

    name = "power"

    def __init__(self, power=None):
        # `power` is expected to be a scalar constant (enforced by the caller
        # via `at.get_scalar_constant_value` before this transform is built).
        self.power = power
        super().__init__()

    def forward(self, value, *inputs):
        # BUG FIX: the power was computed but never returned, so `forward`
        # silently returned `None`.
        return at.power(value, self.power)

    def backward(self, value, *inputs):
        backward_value = at.power(value, at.reciprocal(self.power))

        # For positive even powers the transform is not 1-to-1: both the
        # negative and the positive root map onto `value`, so return both
        # branches (consumed via `logaddexp` in the logprob).
        # NOTE(review): negative even powers (e.g. x**-2) are also 2-to-1 but
        # are not handled here -- confirm upstream never produces them.
        if (self.power > 1) and (self.power % 2 == 0):
            return -backward_value, backward_value

        return backward_value

    def log_jac_det(self, value, *inputs):
        # d/dy y**(1/p) = (1/p) * y**(1/p - 1)
        # => log|jac| = log|1/p| + (1/p - 1) * log(y)
        inv_power = at.reciprocal(self.power)
        return at.log(at.abs(inv_power)) + (inv_power - 1) * at.log(value)
@pytest.mark.parametrize("numerator", (1.0, 2.0))
def test_reciprocal_rv_transform(numerator):
    """`c / Gamma(a, b)` should yield an inverse-gamma log-density."""
    shape = 3
    scale = 5
    x_rv = numerator / at.random.gamma(shape, scale)
    x_rv.name = "x"

    x_vv = x_rv.clone()
    x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv}))

    x_test_val = 1.5
    assert np.isclose(
        x_logp_fn(x_test_val),
        sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val),
    )


def test_sqr_transform():
    """The square of a standard normal follows a chi-squared with one dof."""
    x_rv = at.random.normal(0, 1, size=(3,)) ** 2
    x_rv.name = "x"

    x_vv = x_rv.clone()
    x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

    x_test_val = np.r_[0.5, 1, 2.5]
    assert np.allclose(
        x_logp_fn(x_test_val),
        sp.stats.chi2(df=1).logpdf(x_test_val),
    )


def test_sqrt_transform():
    """The square root of a chi-squared variable follows a chi distribution."""
    x_rv = at.sqrt(at.random.chisquare(df=3, size=(3,)))
    x_rv.name = "x"

    x_vv = x_rv.clone()
    x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

    x_test_val = np.r_[0.5, 1, 2.5]
    assert np.allclose(
        x_logp_fn(x_test_val),
        sp.stats.chi(df=3).logpdf(x_test_val),
    )


def test_negated_rv_transform():
    """Negating a half-normal mirrors its density onto the negative axis."""
    x_rv = -at.random.halfnormal()
    x_rv.name = "x"

    x_vv = x_rv.clone()
    x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv}))

    assert np.isclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5))


def test_subtracted_rv_transform():
    # Choose a base RV that is asymmetric around zero so a sign error in the
    # rewrite would be caught.
    x_rv = 5.0 - at.random.normal(1.0)
    x_rv.name = "x"

    x_vv = x_rv.clone()
    x_logp_fn = aesara.function([x_vv], joint_logprob({x_rv: x_vv}))

    assert np.isclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0))