From b9337e6b6bd063486fd971e12c75797271518b15 Mon Sep 17 00:00:00 2001 From: Manul Patel Date: Sat, 15 Apr 2023 13:52:13 +0530 Subject: [PATCH 1/3] Added alternative scale parameterization with unit tests to exponential --- pymc/distributions/continuous.py | 13 +++++++++-- tests/distributions/test_continuous.py | 32 ++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 91d2193504..a353e7bb2c 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -1347,13 +1347,22 @@ class Exponential(PositiveContinuous): ---------- lam : tensor_like of float Rate or inverse scale (``lam`` > 0). + scale: tensor_like of float + Alternative parameter (scale = 1/lam). """ rv_op = exponential @classmethod - def dist(cls, lam: DIST_PARAMETER_TYPES, *args, **kwargs): - lam = pt.as_tensor_variable(floatX(lam)) + def dist(cls, lam=None, scale=None, *args, **kwargs): + if lam is not None and scale is not None: + raise ValueError("Incompatible parametrization. Can't specify both lam and scale.") + elif lam is None and scale is None: + raise ValueError("Incompatible parametrization. Must specify either lam or scale.")
+ + if scale is not None: + lam = pt.reciprocal(scale) + lam = pt.as_tensor_variable(floatX(lam)) # PyTensor exponential op is parametrized in terms of mu (1/lam) return super().dist([pt.reciprocal(lam)], **kwargs) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 1f673eb285..24cbbeabb8 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -432,18 +432,43 @@ def test_exponential(self): {"lam": Rplus}, lambda value, lam: st.expon.logpdf(value, 0, 1 / lam), ) + check_logp( + pm.Exponential, + Rplus, + {"scale": Rplus}, + lambda value, scale: st.expon.logpdf(value, 0, scale), + ) check_logcdf( pm.Exponential, Rplus, {"lam": Rplus}, lambda value, lam: st.expon.logcdf(value, 0, 1 / lam), ) + check_logcdf( + pm.Exponential, + Rplus, + {"scale": Rplus}, + lambda value, scale: st.expon.logcdf(value, 0, scale), + ) check_icdf( pm.Exponential, {"lam": Rplus}, lambda q, lam: st.expon.ppf(q, loc=0, scale=1 / lam), ) + def test_exponential_wrong_arguments(self): + m = pm.Model() + + msg = "Incompatible parametrization. Can't specify both lam and scale" + with m: + with pytest.raises(ValueError, match=msg): + pm.Exponential("x", lam=0.5, scale=5) + + msg = "Incompatible parametrization. Must specify either lam or scale"
+ with m: + with pytest.raises(ValueError, match=msg): + pm.Exponential("x") + def test_laplace(self): check_logp( pm.Laplace, @@ -2091,6 +2116,13 @@ class TestExponential(BaseTestDistributionRandom): ] +class TestExponentialScale(BaseTestDistributionRandom): + pymc_dist = pm.Exponential + pymc_dist_params = {"scale": 5.0} + expected_rv_op_params = {"mu": pymc_dist_params["scale"]} + checks_to_run = ["check_pymc_params_match_rv_op"] + + class TestCauchy(BaseTestDistributionRandom): pymc_dist = pm.Cauchy pymc_dist_params = {"alpha": 2.0, "beta": 5.0} From f4bd7d41c1179dee5fc2f78aa6c303b6da1ec5c7 Mon Sep 17 00:00:00 2001 From: Manul Patel Date: Thu, 20 Apr 2023 17:33:46 +0530 Subject: [PATCH 2/3] Scale used in exponential and logp, logcdf checks not required. --- pymc/distributions/continuous.py | 8 ++++---- tests/distributions/test_continuous.py | 12 ------------ 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index a353e7bb2c..fe881983e0 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -1359,12 +1359,12 @@ def dist(cls, lam=None, scale=None, *args, **kwargs): elif lam is None and scale is None: raise ValueError("Incompatible parametrization. Must specify either lam or scale.")
- if scale is not None: - lam = pt.reciprocal(scale) + if scale is None: + scale = pt.reciprocal(lam) - lam = pt.as_tensor_variable(floatX(lam)) + scale = pt.as_tensor_variable(floatX(scale)) # PyTensor exponential op is parametrized in terms of mu (1/lam) - return super().dist([pt.reciprocal(lam)], **kwargs) + return super().dist([scale], **kwargs) def moment(rv, size, mu): if not rv_size_is_none(size): diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 24cbbeabb8..2fe455898f 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -432,24 +432,12 @@ def test_exponential(self): {"lam": Rplus}, lambda value, lam: st.expon.logpdf(value, 0, 1 / lam), ) - check_logp( - pm.Exponential, - Rplus, - {"scale": Rplus}, - lambda value, scale: st.expon.logpdf(value, 0, scale), - ) check_logcdf( pm.Exponential, Rplus, {"lam": Rplus}, lambda value, lam: st.expon.logcdf(value, 0, 1 / lam), ) - check_logcdf( - pm.Exponential, - Rplus, - {"scale": Rplus}, - lambda value, scale: st.expon.logcdf(value, 0, scale), - ) check_icdf( pm.Exponential, {"lam": Rplus}, From 2e413cfb1976081a6774b891efbb27f5fc46a050 Mon Sep 17 00:00:00 2001 From: Manul Patel Date: Thu, 20 Apr 2023 23:22:20 +0530 Subject: [PATCH 3/3] Simplified test case --- tests/distributions/test_continuous.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 2fe455898f..7ca5eb53a0 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -445,17 +445,13 @@ def test_exponential(self): ) def test_exponential_wrong_arguments(self): - m = pm.Model() - msg = "Incompatible parametrization. Can't specify both lam and scale"
- with m: - with pytest.raises(ValueError, match=msg): - pm.Exponential("x", lam=0.5, scale=5) + with pytest.raises(ValueError, match=msg): + pm.Exponential.dist(lam=0.5, scale=5) msg = "Incompatible parametrization. Must specify either lam or scale" - with m: - with pytest.raises(ValueError, match=msg): - pm.Exponential("x") + with pytest.raises(ValueError, match=msg): + pm.Exponential.dist() def test_laplace(self): check_logp(