Merged
50 changes: 43 additions & 7 deletions pymc/logprob/transforms.py
@@ -49,10 +49,11 @@
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.scalar import Add, Exp, Log, Mul, Pow, Sqr, Sqrt
from pytensor.scalar import Abs, Add, Exp, Log, Mul, Pow, Sqr, Sqrt
from pytensor.scan.op import Scan
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import (
abs,
add,
exp,
log,
@@ -336,7 +337,7 @@ def apply(self, fgraph: FunctionGraph):
class MeasurableTransform(MeasurableElemwise):
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""

valid_scalar_types = (Exp, Log, Add, Mul, Pow)
valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs)

# Cannot use `transform` as name because it would clash with the property added by
# the `TransformValuesRewrite`
@@ -374,9 +375,14 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwargs):
else:
input_logprob = logprob(measurable_input, backward_value)

if input_logprob.ndim < value.ndim:
# Do we just need to sum the jacobian terms across the support dims?
raise NotImplementedError("Transform of multivariate RVs not implemented")

jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)

return input_logprob + jacobian
# The jacobian is used to ensure a value in the supported domain was provided
return at.switch(at.isnan(jacobian), -np.inf, input_logprob + jacobian)


@node_rewriter([reciprocal])
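
[Editor's note] The switch added to measurable_transform_logprob above converts a NaN log-Jacobian into a logp of -inf instead of letting NaN propagate. A minimal sketch of that mechanism, not part of the diff, using only pytensor and numpy; the `value >= 0` domain here stands in for AbsTransform.log_jac_det:

    import numpy as np
    import pytensor
    import pytensor.tensor as at

    value = at.dscalar("value")
    input_logprob = at.dscalar("input_logprob")

    # A log-Jacobian that is NaN outside the transform's domain,
    # as AbsTransform.log_jac_det is for negative values.
    jacobian = at.switch(value >= 0, 0.0, np.nan)

    # The pattern from the diff: a NaN jacobian marks an out-of-domain
    # value, so the whole logp term becomes -inf.
    logp = at.switch(at.isnan(jacobian), -np.inf, input_logprob + jacobian)

    fn = pytensor.function([value, input_logprob], logp)
    fn(1.0, -0.5)   # array(-0.5): value inside the support
    fn(-1.0, -0.5)  # array(-inf): value outside the support
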
@@ -493,7 +499,7 @@ def measurable_sub_to_neg(fgraph, node):
return [at.add(minuend, at.neg(subtrahend))]


@node_rewriter([exp, log, add, mul, pow])
@node_rewriter([exp, log, add, mul, pow, abs])
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
"""Find measurable transformations from Elemwise operators."""

@@ -553,6 +559,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
transform = ExpTransform()
elif isinstance(scalar_op, Log):
transform = LogTransform()
elif isinstance(scalar_op, Abs):
transform = AbsTransform()
elif isinstance(scalar_op, Pow):
# We only allow for the base to be measurable
if measurable_input_idx != 0:
@@ -696,6 +704,20 @@ def log_jac_det(self, value, *inputs):
return -at.log(value)


class AbsTransform(RVTransform):
name = "abs"

def forward(self, value, *inputs):
return at.abs(value)

def backward(self, value, *inputs):
value = at.switch(value >= 0, value, np.nan)
return -value, value

def log_jac_det(self, value, *inputs):
return at.switch(value >= 0, 0, np.nan)
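
[Editor's note] AbsTransform is not one-to-one, which is why backward returns the pair (-value, value): the logp machinery combines the contributions of both preimages, and the tests below confirm the result matches the half-normal. A quick numerical check of the underlying identity, assuming SciPy is available (this sketch is not part of the diff):

    import numpy as np
    from scipy import stats

    # For Y = |X| with X ~ Normal(0, 1), the density sums both branches:
    # p_Y(y) = p_X(y) + p_X(-y), which is exactly the half-normal density.
    y = np.array([0.5, 1.0, 2.5])
    two_branch = stats.norm.pdf(y) + stats.norm.pdf(-y)
    np.testing.assert_allclose(two_branch, stats.halfnorm.pdf(y))
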


class PowerTransform(RVTransform):
name = "power"

@@ -711,18 +733,32 @@ def forward(self, value, *inputs):
at.power(value, self.power)

def backward(self, value, *inputs):
backward_value = at.power(value, (1 / self.power))
inv_power = 1 / self.power

# Powers that don't admit negative values
if (np.abs(self.power) < 1) or (self.power % 2 == 0):
backward_value = at.switch(value >= 0, at.power(value, inv_power), np.nan)
# Powers that admit negative values require special logic, because (-1)**(1/3) returns `nan` in PyTensor
else:
backward_value = at.power(at.abs(value), inv_power) * at.switch(value >= 0, 1, -1)

# In this case the transform is not 1-to-1
if (self.power > 1) and (self.power % 2 == 0):
if self.power % 2 == 0:
return -backward_value, backward_value
else:
return backward_value

def log_jac_det(self, value, *inputs):
inv_power = 1 / self.power

# Note: This fails for value==0
return np.log(np.abs(inv_power)) + (inv_power - 1) * at.log(value)
res = np.log(np.abs(inv_power)) + (inv_power - 1) * at.log(at.abs(value))

# Powers that don't admit negative values
if (np.abs(self.power) < 1) or (self.power % 2 == 0):
res = at.switch(value >= 0, res, np.nan)

return res


class IntervalTransform(RVTransform):
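[Editor's note] The new PowerTransform.backward branches on the power. A worked check of both branches with NumPy and SciPy (a sketch under those assumptions, not part of the diff):

    import numpy as np
    from scipy import stats

    # Odd powers are invertible on all of R, but the inverse must route
    # through abs() because (-1) ** (1 / 3) evaluates to nan in PyTensor
    # (and NumPy), as the diff's comment notes:
    y = np.array([-8.0, 27.0])
    np.testing.assert_allclose(np.sign(y) * np.abs(y) ** (1 / 3), [-2.0, 3.0])

    # Even powers fold two preimages, +sqrt(y) and -sqrt(y). Summing both
    # branches, weighted by |d sqrt(y) / dy|, recovers the chi-square
    # density with one degree of freedom, as test_sqr_transform expects.
    y = np.array([0.5, 1.0, 2.5])
    jac = 0.5 / np.sqrt(y)
    two_branch = (stats.norm.pdf(np.sqrt(y)) + stats.norm.pdf(-np.sqrt(y))) * jac
    np.testing.assert_allclose(two_branch, stats.chi2(df=1).pdf(y))
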
72 changes: 62 additions & 10 deletions pymc/tests/logprob/test_transforms.py
@@ -632,15 +632,15 @@ def test_chained_transform():


def test_exp_transform_rv():
base_rv = at.random.normal(0, 1, size=2, name="base_rv")
base_rv = at.random.normal(0, 1, size=3, name="base_rv")
y_rv = at.exp(base_rv)
y_rv.name = "y"

y_vv = y_rv.clone()
logp = joint_logprob({y_rv: y_vv}, sum=False)
logp_fn = pytensor.function([y_vv], logp)

y_val = [0.1, 0.3]
y_val = [-2.0, 0.1, 0.3]
np.testing.assert_allclose(
logp_fn(y_val),
sp.stats.lognorm(s=1).logpdf(y_val),
@@ -794,28 +794,28 @@ def test_invalid_broadcasted_transform_rv_fails():
def test_reciprocal_rv_transform(numerator):
shape = 3
scale = 5
x_rv = numerator / at.random.gamma(shape, scale)
x_rv = numerator / at.random.gamma(shape, scale, size=(2,))
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}))
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

x_test_val = 1.5
assert np.isclose(
x_test_val = np.r_[-0.5, 1.5]
assert np.allclose(
x_logp_fn(x_test_val),
sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val),
)


def test_sqr_transform():
# The square of a unit normal is a chi-square with 1 df
x_rv = at.random.normal(0, 1, size=(3,)) ** 2
x_rv = at.random.normal(0, 1, size=(4,)) ** 2
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

x_test_val = np.r_[0.5, 1, 2.5]
x_test_val = np.r_[-0.5, 0.5, 1, 2.5]
assert np.allclose(
x_logp_fn(x_test_val),
sp.stats.chi2(df=1).logpdf(x_test_val),
@@ -824,19 +824,71 @@ def test_sqrt_transform():

def test_sqrt_transform():
# The sqrt of a chisquare with n df is a chi distribution with n df
x_rv = at.sqrt(at.random.chisquare(df=3, size=(3,)))
x_rv = at.sqrt(at.random.chisquare(df=3, size=(4,)))
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

x_test_val = np.r_[0.5, 1, 2.5]
x_test_val = np.r_[-2.5, 0.5, 1, 2.5]
assert np.allclose(
x_logp_fn(x_test_val),
sp.stats.chi(df=3).logpdf(x_test_val),
)


@pytest.mark.parametrize("power", (-3, -1, 1, 5, 7))
def test_negative_value_odd_power_transform(power):
# check that negative values and odd powers evaluate to a finite logp
x_rv = at.random.normal() ** power
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

assert np.isfinite(x_logp_fn(1))
assert np.isfinite(x_logp_fn(-1))


@pytest.mark.parametrize("power", (-2, 2, 4, 6, 8))
def test_negative_value_even_power_transform(power):
# check that negative values and even powers evaluate to -inf logp
x_rv = at.random.normal() ** power
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

assert np.isfinite(x_logp_fn(1))
assert np.isneginf(x_logp_fn(-1))


@pytest.mark.parametrize("power", (-1 / 3, -1 / 2, 1 / 2, 1 / 3))
def test_negative_value_frac_power_transform(power):
# check that negative values and fractional powers evaluate to -inf logp
x_rv = at.random.normal() ** power
x_rv.name = "x"

x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

assert np.isfinite(x_logp_fn(2.5))
assert np.isneginf(x_logp_fn(-2.5))


@pytest.mark.parametrize("test_val", (2.5, -2.5))
def test_absolute_transform(test_val):
x_rv = at.abs(at.random.normal())
y_rv = at.random.halfnormal()

x_vv = x_rv.clone()
y_vv = y_rv.clone()
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))
y_logp_fn = pytensor.function([y_vv], joint_logprob({y_rv: y_vv}, sum=False))

assert np.allclose(x_logp_fn(test_val), y_logp_fn(test_val))


def test_negated_rv_transform():
x_rv = -at.random.halfnormal()
x_rv.name = "x"
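
[Editor's note] Taken together, the changes let the rewrites derive logps like the one below. A minimal end-to-end sketch mirroring the new tests; the import path for the `joint_logprob` test helper is an assumption, so adjust it to wherever the helper lives in your checkout:

    import numpy as np
    import pytensor
    import pytensor.tensor as at
    from pymc.tests.logprob.utils import joint_logprob  # assumed location

    # abs of a normal is now measurable via AbsTransform
    x_rv = at.abs(at.random.normal())
    x_vv = x_rv.clone()
    logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

    logp_fn(2.5)   # finite, equal to the half-normal logpdf at 2.5
    logp_fn(-2.5)  # -inf: negative values are outside the support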