From 214d0c678eeb33e9d71506f981611e121bdd4793 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 12 Feb 2024 11:27:50 -0300 Subject: [PATCH 1/4] add compute_log_prior --- pymc/stats/__init__.py | 3 +- pymc/stats/log_density.py | 159 ++++++++++++++++++++++++++++++++++ pymc/stats/log_likelihood.py | 95 ++++---------------- pymc/stats/log_prior.py | 61 +++++++++++++ tests/stats/test_log_prior.py | 48 ++++++++++ 5 files changed, 285 insertions(+), 81 deletions(-) create mode 100644 pymc/stats/log_density.py create mode 100644 pymc/stats/log_prior.py create mode 100644 tests/stats/test_log_prior.py diff --git a/pymc/stats/__init__.py b/pymc/stats/__init__.py index e27b5265a0..7609142a24 100644 --- a/pymc/stats/__init__.py +++ b/pymc/stats/__init__.py @@ -28,5 +28,6 @@ setattr(sys.modules[__name__], attr, obj) from pymc.stats.log_likelihood import compute_log_likelihood +from pymc.stats.log_prior import compute_log_prior -__all__ = ("compute_log_likelihood", *az.stats.__all__) +__all__ = ("compute_log_likelihood", "compute_log_prior", *az.stats.__all__) diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py new file mode 100644 index 0000000000..ee03be3d73 --- /dev/null +++ b/pymc/stats/log_density.py @@ -0,0 +1,159 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Sequence +from typing import Optional, cast + +from arviz import InferenceData, dict_to_dataset +from fastprogress import progress_bar + +import pymc + +from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata +from pymc.model import Model, modelcontext +from pymc.pytensorf import PointFunc +from pymc.util import dataset_to_point_list + + +def compute_log_prior( + idata: InferenceData, + var_names: Optional[Sequence[str]] = None, + extend_inferencedata: bool = True, + model: Optional[Model] = None, + sample_dims: Sequence[str] = ("chain", "draw"), + progressbar=True, +): + """Compute elemwise log_prior of model given InferenceData with posterior group + + Parameters + ---------- + idata : InferenceData + InferenceData with posterior group + var_names : sequence of str, optional + List of Observed variable names for which to compute log_prior. + Defaults to all all free variables. + extend_inferencedata : bool, default True + Whether to extend the original InferenceData or return a new one + model : Model, optional + sample_dims : sequence of str, default ("chain", "draw") + progressbar : bool, default True + + Returns + ------- + idata : InferenceData + InferenceData with log_prior group + """ + return compute_log_density( + idata=idata, + var_names=var_names, + extend_inferencedata=extend_inferencedata, + model=model, + kind="prior", + sample_dims=sample_dims, + progressbar=progressbar, + ) + + +def compute_log_density( + idata: InferenceData, + *, + var_names: Optional[Sequence[str]] = None, + extend_inferencedata: bool = True, + model: Optional[Model] = None, + kind="likelihood", + sample_dims: Sequence[str] = ("chain", "draw"), + progressbar=True, +): + """ + Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group + """ + + posterior = idata["posterior"] + + model = modelcontext(model) + + if kind not in ("likelihood", "prior"): + raise ValueError("kind must be either 'likelihood' or 'prior'") + + if kind == "likelihood": + target_rvs = model.observed_RVs + target_str = "observed_RVs" + else: + target_rvs = model.unobserved_RVs + target_str = "free_RVs" + + if var_names is None: + vars = target_rvs + var_names = tuple(rv.name for rv in vars) + else: + vars = [model.named_vars[name] for name in var_names] + if not set(vars).issubset(target_rvs): + raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}") + + # We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values + try: + original_rvs_to_values = model.rvs_to_values + original_rvs_to_transforms = model.rvs_to_transforms + + model.rvs_to_values = { + rv: rv.clone() if rv not in model.observed_RVs else value + for rv, value in model.rvs_to_values.items() + } + model.rvs_to_transforms = {rv: None for rv in model.basic_RVs} + + elemwise_logdens_fn = model.compile_fn( + inputs=model.value_vars, + outs=model.logp(vars=vars, sum=False), + on_unused_input="ignore", + ) + elemwise_logdens_fn = cast(PointFunc, elemwise_logdens_fn) + finally: + model.rvs_to_values = original_rvs_to_values + model.rvs_to_transforms = original_rvs_to_transforms + + # Ignore Deterministics + posterior_values = posterior[[rv.name for rv in model.free_RVs]] + posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims) + + n_pts = len(posterior_pts) + logdens_dict = _DefaultTrace(n_pts) + indices = range(n_pts) + if progressbar: + indices = progress_bar(indices, total=n_pts, display=progressbar) + + for idx in indices: + logdenss_pts = elemwise_logdens_fn(posterior_pts[idx]) + for rv_name, rv_logdens in zip(var_names, logdenss_pts): + logdens_dict.insert(rv_name, rv_logdens, idx) + + logdens_trace = logdens_dict.trace_dict + for key, array in logdens_trace.items(): + logdens_trace[key] = array.reshape( + (*[len(coord) for coord in stacked_dims.values()], *array.shape[1:]) + ) + + coords, dims = coords_and_dims_for_inferencedata(model) + logdens_dataset = dict_to_dataset( + logdens_trace, + library=pymc, + dims=dims, + coords=coords, + default_dims=list(sample_dims), + skip_event_dims=True, + ) + + if extend_inferencedata: + idata.add_groups({f"log_{kind}": logdens_dataset}) + return idata + else: + return logdens_dataset diff --git a/pymc/stats/log_likelihood.py b/pymc/stats/log_likelihood.py index 26874d5a4e..a267b97847 100644 --- a/pymc/stats/log_likelihood.py +++ b/pymc/stats/log_likelihood.py @@ -12,19 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Optional, cast +from typing import Optional -from arviz import InferenceData, dict_to_dataset -from fastprogress import progress_bar +from arviz import InferenceData -import pymc +from pymc.model import Model +from pymc.stats.log_density import compute_log_density -from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata -from pymc.model import Model, modelcontext -from pymc.pytensorf import PointFunc -from pymc.util import dataset_to_point_list - -__all__ = ("compute_log_likelihood",) +__all__ = "compute_log_likelihood" def compute_log_likelihood( @@ -43,7 +38,8 @@ def compute_log_likelihood( idata : InferenceData InferenceData with posterior group var_names : sequence of str, optional - List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables + List of Observed variable names for which to compute log_likelihood. + Defaults to all observed variables. extend_inferencedata : bool, default True Whether to extend the original InferenceData or return a new one model : Model, optional @@ -54,74 +50,13 @@ def compute_log_likelihood( ------- idata : InferenceData InferenceData with log_likelihood group - """ - - posterior = idata["posterior"] - - model = modelcontext(model) - - if var_names is None: - observed_vars = model.observed_RVs - var_names = tuple(rv.name for rv in observed_vars) - else: - observed_vars = [model.named_vars[name] for name in var_names] - if not set(observed_vars).issubset(model.observed_RVs): - raise ValueError(f"var_names must refer to observed_RVs in the model. Got: {var_names}") - - # We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values - try: - original_rvs_to_values = model.rvs_to_values - original_rvs_to_transforms = model.rvs_to_transforms - - model.rvs_to_values = { - rv: rv.clone() if rv not in model.observed_RVs else value - for rv, value in model.rvs_to_values.items() - } - model.rvs_to_transforms = {rv: None for rv in model.basic_RVs} - - elemwise_loglike_fn = model.compile_fn( - inputs=model.value_vars, - outs=model.logp(vars=observed_vars, sum=False), - on_unused_input="ignore", - ) - elemwise_loglike_fn = cast(PointFunc, elemwise_loglike_fn) - finally: - model.rvs_to_values = original_rvs_to_values - model.rvs_to_transforms = original_rvs_to_transforms - - # Ignore Deterministics - posterior_values = posterior[[rv.name for rv in model.free_RVs]] - posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims) - n_pts = len(posterior_pts) - loglike_dict = _DefaultTrace(n_pts) - indices = range(n_pts) - if progressbar: - indices = progress_bar(indices, total=n_pts, display=progressbar) - - for idx in indices: - loglikes_pts = elemwise_loglike_fn(posterior_pts[idx]) - for rv_name, rv_loglike in zip(var_names, loglikes_pts): - loglike_dict.insert(rv_name, rv_loglike, idx) - - loglike_trace = loglike_dict.trace_dict - for key, array in loglike_trace.items(): - loglike_trace[key] = array.reshape( - (*[len(coord) for coord in stacked_dims.values()], *array.shape[1:]) - ) - - coords, dims = coords_and_dims_for_inferencedata(model) - loglike_dataset = dict_to_dataset( - loglike_trace, - library=pymc, - dims=dims, - coords=coords, - default_dims=list(sample_dims), - skip_event_dims=True, + return compute_log_density( + idata=idata, + var_names=var_names, + extend_inferencedata=extend_inferencedata, + model=model, + kind="likelihood", + sample_dims=sample_dims, + progressbar=progressbar, ) - - if extend_inferencedata: - idata.add_groups(dict(log_likelihood=loglike_dataset)) - return idata - else: - return loglike_dataset diff --git a/pymc/stats/log_prior.py b/pymc/stats/log_prior.py new file mode 100644 index 0000000000..c2d9245b52 --- /dev/null +++ b/pymc/stats/log_prior.py @@ -0,0 +1,61 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Sequence +from typing import Optional + +from arviz import InferenceData + +from pymc.model import Model +from pymc.stats.log_density import compute_log_density + +__all__ = "compute_log_prior" + + +def compute_log_prior( + idata: InferenceData, + var_names: Optional[Sequence[str]] = None, + extend_inferencedata: bool = True, + model: Optional[Model] = None, + sample_dims: Sequence[str] = ("chain", "draw"), + progressbar=True, +): + """Compute elemwise log_prior of model given InferenceData with posterior group + + Parameters + ---------- + idata : InferenceData + InferenceData with posterior group + var_names : sequence of str, optional + List of Observed variable names for which to compute log_prior. + Defaults to all all free variables. + extend_inferencedata : bool, default True + Whether to extend the original InferenceData or return a new one + model : Model, optional + sample_dims : sequence of str, default ("chain", "draw") + progressbar : bool, default True + + Returns + ------- + idata : InferenceData + InferenceData with log_prior group + """ + return compute_log_density( + idata=idata, + var_names=var_names, + extend_inferencedata=extend_inferencedata, + model=model, + kind="prior", + sample_dims=sample_dims, + progressbar=progressbar, + ) diff --git a/tests/stats/test_log_prior.py b/tests/stats/test_log_prior.py new file mode 100644 index 0000000000..aad48187ae --- /dev/null +++ b/tests/stats/test_log_prior.py @@ -0,0 +1,48 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest +import scipy.stats as st + +from arviz import InferenceData, dict_to_dataset + +from pymc.distributions import Normal +from pymc.distributions.transforms import log +from pymc.model import Model +from pymc.stats.log_prior import compute_log_prior + + +class TestComputeLogPrior: + @pytest.mark.parametrize("transform", (False, True)) + def test_basic(self, transform): + transform = log if transform else None + with Model() as m: + x = Normal("x", transform=transform) + x_value_var = m.rvs_to_values[x] + Normal("y", x, observed=[0, 1, 2]) + + idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)})) + res = compute_log_prior(idata) + + # Check we didn't erase the original mappings + assert m.rvs_to_values[x] is x_value_var + assert m.rvs_to_transforms[x] is transform + + assert res is idata + assert res.log_prior.dims == {"chain": 4, "draw": 25} + + np.testing.assert_allclose( + res.log_prior["x"].values, + st.norm(0, 1).logpdf(idata.posterior["x"].values), + ) From a95354eaf5e4844a508f80fd8bcff0353c4de667 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 12 Feb 2024 12:00:43 -0300 Subject: [PATCH 2/4] unify --- docs/source/api/misc.rst | 1 + pymc/backends/arviz.py | 2 +- pymc/stats/__init__.py | 3 +- pymc/stats/log_density.py | 42 +++++++++++++ pymc/stats/log_likelihood.py | 62 ------------------- pymc/stats/log_prior.py | 61 ------------------ ..._log_likelihood.py => test_log_density.py} | 25 +++++++- tests/stats/test_log_prior.py | 48 -------------- 8 files changed, 69 insertions(+), 175 deletions(-) delete mode 100644 pymc/stats/log_likelihood.py delete mode 100644 pymc/stats/log_prior.py rename tests/stats/{test_log_likelihood.py => test_log_density.py} (84%) delete mode 100644 tests/stats/test_log_prior.py diff --git a/docs/source/api/misc.rst b/docs/source/api/misc.rst index c40c7c4aa1..095d262179 100644 --- a/docs/source/api/misc.rst +++ b/docs/source/api/misc.rst @@ -7,5 +7,6 @@ Other utils :toctree: generated/ compute_log_likelihood + compute_log_prior find_constrained_prior DictToArrayBijection diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 874be7bd09..24c618ecf5 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -436,7 +436,7 @@ def to_inference_data(self): id_dict["constant_data"] = self.constant_data_to_xarray() idata = InferenceData(save_warmup=self.save_warmup, **id_dict) if self.log_likelihood: - from pymc.stats.log_likelihood import compute_log_likelihood + from pymc.stats.log_density import compute_log_likelihood idata = compute_log_likelihood( idata, diff --git a/pymc/stats/__init__.py b/pymc/stats/__init__.py index 7609142a24..23e5ba7abc 100644 --- a/pymc/stats/__init__.py +++ b/pymc/stats/__init__.py @@ -27,7 +27,6 @@ if not attr.startswith("__"): setattr(sys.modules[__name__], attr, obj) -from pymc.stats.log_likelihood import compute_log_likelihood -from pymc.stats.log_prior import compute_log_prior +from pymc.stats.log_density import compute_log_likelihood, compute_log_prior __all__ = ("compute_log_likelihood", "compute_log_prior", *az.stats.__all__) diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py index ee03be3d73..973e53a289 100644 --- a/pymc/stats/log_density.py +++ b/pymc/stats/log_density.py @@ -24,6 +24,48 @@ from pymc.pytensorf import PointFunc from pymc.util import dataset_to_point_list +__all__ = ("compute_log_likelihood", "compute_log_prior") + + +def compute_log_likelihood( + idata: InferenceData, + *, + var_names: Optional[Sequence[str]] = None, + extend_inferencedata: bool = True, + model: Optional[Model] = None, + sample_dims: Sequence[str] = ("chain", "draw"), + progressbar=True, +): + """Compute elemwise log_likelihood of model given InferenceData with posterior group + + Parameters + ---------- + idata : InferenceData + InferenceData with posterior group + var_names : sequence of str, optional + List of Observed variable names for which to compute log_likelihood. + Defaults to all observed variables. + extend_inferencedata : bool, default True + Whether to extend the original InferenceData or return a new one + model : Model, optional + sample_dims : sequence of str, default ("chain", "draw") + progressbar : bool, default True + + Returns + ------- + idata : InferenceData + InferenceData with log_likelihood group + """ + return compute_log_density( + idata=idata, + var_names=var_names, + extend_inferencedata=extend_inferencedata, + model=model, + kind="likelihood", + sample_dims=sample_dims, + progressbar=progressbar, + ) + def compute_log_prior( idata: InferenceData, diff --git a/pymc/stats/log_likelihood.py b/pymc/stats/log_likelihood.py deleted file mode 100644 index a267b97847..0000000000 --- a/pymc/stats/log_likelihood.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2024 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections.abc import Sequence -from typing import Optional - -from arviz import InferenceData - -from pymc.model import Model -from pymc.stats.log_density import compute_log_density - -__all__ = "compute_log_likelihood" - - -def compute_log_likelihood( - idata: InferenceData, - *, - var_names: Optional[Sequence[str]] = None, - extend_inferencedata: bool = True, - model: Optional[Model] = None, - sample_dims: Sequence[str] = ("chain", "draw"), - progressbar=True, -): - """Compute elemwise log_likelihood of model given InferenceData with posterior group - - Parameters - ---------- - idata : InferenceData - InferenceData with posterior group - var_names : sequence of str, optional - List of Observed variable names for which to compute log_likelihood. - Defaults to all observed variables. - extend_inferencedata : bool, default True - Whether to extend the original InferenceData or return a new one - model : Model, optional - sample_dims : sequence of str, default ("chain", "draw") - progressbar : bool, default True - - Returns - ------- - idata : InferenceData - InferenceData with log_likelihood group - """ - return compute_log_density( - idata=idata, - var_names=var_names, - extend_inferencedata=extend_inferencedata, - model=model, - kind="likelihood", - sample_dims=sample_dims, - progressbar=progressbar, - ) diff --git a/pymc/stats/log_prior.py b/pymc/stats/log_prior.py deleted file mode 100644 index c2d9245b52..0000000000 --- a/pymc/stats/log_prior.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2024 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections.abc import Sequence -from typing import Optional - -from arviz import InferenceData - -from pymc.model import Model -from pymc.stats.log_density import compute_log_density - -__all__ = "compute_log_prior" - - -def compute_log_prior( - idata: InferenceData, - var_names: Optional[Sequence[str]] = None, - extend_inferencedata: bool = True, - model: Optional[Model] = None, - sample_dims: Sequence[str] = ("chain", "draw"), - progressbar=True, -): - """Compute elemwise log_prior of model given InferenceData with posterior group - - Parameters - ---------- - idata : InferenceData - InferenceData with posterior group - var_names : sequence of str, optional - List of Observed variable names for which to compute log_prior. - Defaults to all all free variables. - extend_inferencedata : bool, default True - Whether to extend the original InferenceData or return a new one - model : Model, optional - sample_dims : sequence of str, default ("chain", "draw") - progressbar : bool, default True - - Returns - ------- - idata : InferenceData - InferenceData with log_prior group - """ - return compute_log_density( - idata=idata, - var_names=var_names, - extend_inferencedata=extend_inferencedata, - model=model, - kind="prior", - sample_dims=sample_dims, - progressbar=progressbar, - ) diff --git a/tests/stats/test_log_likelihood.py b/tests/stats/test_log_density.py similarity index 84% rename from tests/stats/test_log_likelihood.py rename to tests/stats/test_log_density.py index a62da5c5e2..3de1fa5ec8 100644 --- a/tests/stats/test_log_likelihood.py +++ b/tests/stats/test_log_density.py @@ -20,7 +20,7 @@ from pymc.distributions import Dirichlet, Normal from pymc.distributions.transforms import log from pymc.model import Model -from pymc.stats.log_likelihood import compute_log_likelihood +from pymc.stats.log_density import compute_log_likelihood, compute_log_prior from tests.distributions.test_multivariate import dirichlet_logpdf @@ -132,3 +132,26 @@ def test_dims_without_coords(self): llike.log_likelihood["y"].values, st.norm.logpdf([[[0, 0, 0], [1, 1, 1]]]), ) + + @pytest.mark.parametrize("transform", (False, True)) + def test_basic_log_prior(self, transform): + transform = log if transform else None + with Model() as m: + x = Normal("x", transform=transform) + x_value_var = m.rvs_to_values[x] + Normal("y", x, observed=[0, 1, 2]) + + idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)})) + res = compute_log_prior(idata) + + # Check we didn't erase the original mappings + assert m.rvs_to_values[x] is x_value_var + assert m.rvs_to_transforms[x] is transform + + assert res is idata + assert res.log_prior.dims == {"chain": 4, "draw": 25} + + np.testing.assert_allclose( + res.log_prior["x"].values, + st.norm(0, 1).logpdf(idata.posterior["x"].values), + ) diff --git a/tests/stats/test_log_prior.py b/tests/stats/test_log_prior.py deleted file mode 100644 index aad48187ae..0000000000 --- a/tests/stats/test_log_prior.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2024 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -import pytest -import scipy.stats as st - -from arviz import InferenceData, dict_to_dataset - -from pymc.distributions import Normal -from pymc.distributions.transforms import log -from pymc.model import Model -from pymc.stats.log_prior import compute_log_prior - - -class TestComputeLogPrior: - @pytest.mark.parametrize("transform", (False, True)) - def test_basic(self, transform): - transform = log if transform else None - with Model() as m: - x = Normal("x", transform=transform) - x_value_var = m.rvs_to_values[x] - Normal("y", x, observed=[0, 1, 2]) - - idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)})) - res = compute_log_prior(idata) - - # Check we didn't erase the original mappings - assert m.rvs_to_values[x] is x_value_var - assert m.rvs_to_transforms[x] is transform - - assert res is idata - assert res.log_prior.dims == {"chain": 4, "draw": 25} - - np.testing.assert_allclose( - res.log_prior["x"].values, - st.norm(0, 1).logpdf(idata.posterior["x"].values), - ) From 3bc1c3361785d16a59b4a26b597b33e05f0bdd32 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 12 Feb 2024 12:18:19 -0300 Subject: [PATCH 3/4] update name workflow --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1e9c942ae8..1eaf91bd30 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -73,7 +73,7 @@ jobs: tests/sampling/test_forward.py tests/sampling/test_population.py tests/stats/test_convergence.py - tests/stats/test_log_likelihood.py + tests/stats/test_log_density.py tests/distributions/test_distribution.py tests/distributions/test_discrete.py From 47025f0457fb49e7d19f0b5086036d9d5508b44f Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 12 Feb 2024 13:43:57 -0300 Subject: [PATCH 4/4] add log_prior to arviz converter --- pymc/backends/arviz.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 24c618ecf5..f555c882cf 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -174,6 +174,7 @@ def __init__( prior=None, posterior_predictive=None, log_likelihood=False, + log_prior=False, predictions=None, coords: Optional[CoordSpec] = None, dims: Optional[DimSpec] = None, @@ -215,6 +216,7 @@ def __init__( self.prior = prior self.posterior_predictive = posterior_predictive self.log_likelihood = log_likelihood + self.log_prior = log_prior self.predictions = predictions if all(elem is None for elem in (trace, predictions, posterior_predictive, prior)): @@ -446,6 +448,17 @@ def to_inference_data(self): sample_dims=self.sample_dims, progressbar=False, ) + if self.log_prior: + from pymc.stats.log_density import compute_log_prior + + idata = compute_log_prior( + idata, + var_names=None if self.log_prior is True else self.log_prior, + extend_inferencedata=True, + model=self.model, + sample_dims=self.sample_dims, + progressbar=False, + ) return idata @@ -455,6 +468,7 @@ def to_inference_data( prior: Optional[Mapping[str, Any]] = None, posterior_predictive: Optional[Mapping[str, Any]] = None, log_likelihood: Union[bool, Iterable[str]] = False, + log_prior: Union[bool, Iterable[str]] = False, coords: Optional[CoordSpec] = None, dims: Optional[DimSpec] = None, sample_dims: Optional[list] = None, @@ -481,8 +495,11 @@ def to_inference_data( Dictionary with the variable names as keys, and values numpy arrays containing posterior predictive samples. log_likelihood : bool or array_like of str, optional - List of variables to calculate `log_likelihood`. Defaults to True which calculates - `log_likelihood` for all observed variables. If set to False, log_likelihood is skipped. + List of variables to calculate `log_likelihood`. Defaults to False. + If set to True, computes `log_likelihood` for all observed variables. + log_prior : bool or array_like of str, optional + List of variables to calculate `log_prior`. Defaults to False. + If set to True, computes `log_prior` for all unobserved variables. coords : dict of {str: array-like}, optional Map of coordinate names to coordinate values dims : dict of {str: list of str}, optional @@ -509,6 +526,7 @@ def to_inference_data( prior=prior, posterior_predictive=posterior_predictive, log_likelihood=log_likelihood, + log_prior=log_prior, coords=coords, dims=dims, sample_dims=sample_dims,