Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pymc/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@
setattr(sys.modules[__name__], attr, obj)

from pymc.stats.log_likelihood import compute_log_likelihood
from pymc.stats.log_prior import compute_log_prior

__all__ = ("compute_log_likelihood", *az.stats.__all__)
__all__ = ("compute_log_likelihood", "compute_log_prior", *az.stats.__all__)
159 changes: 159 additions & 0 deletions pymc/stats/log_density.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Optional, cast

from arviz import InferenceData, dict_to_dataset
from fastprogress import progress_bar

import pymc

from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata
from pymc.model import Model, modelcontext
from pymc.pytensorf import PointFunc
from pymc.util import dataset_to_point_list


def compute_log_prior(
    idata: InferenceData,
    var_names: Optional[Sequence[str]] = None,
    extend_inferencedata: bool = True,
    model: Optional[Model] = None,
    sample_dims: Sequence[str] = ("chain", "draw"),
    progressbar=True,
):
    """Compute elemwise log_prior of model given InferenceData with posterior group

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group
    var_names : sequence of str, optional
        List of free variable names for which to compute log_prior.
        Defaults to all free variables.
    extend_inferencedata : bool, default True
        Whether to extend the original InferenceData or return a new one
    model : Model, optional
    sample_dims : sequence of str, default ("chain", "draw")
    progressbar : bool, default True

    Returns
    -------
    idata : InferenceData
        InferenceData with log_prior group
    """
    # Thin wrapper around the shared implementation; only `kind` differs
    # from compute_log_likelihood.
    return compute_log_density(
        idata=idata,
        var_names=var_names,
        extend_inferencedata=extend_inferencedata,
        model=model,
        kind="prior",
        sample_dims=sample_dims,
        progressbar=progressbar,
    )


def compute_log_density(
    idata: InferenceData,
    *,
    var_names: Optional[Sequence[str]] = None,
    extend_inferencedata: bool = True,
    model: Optional[Model] = None,
    kind="likelihood",
    sample_dims: Sequence[str] = ("chain", "draw"),
    progressbar=True,
):
    """
    Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group
    var_names : sequence of str, optional
        Names of the variables whose elemwise log density is computed.
        Defaults to all observed RVs (kind="likelihood") or all
        unobserved RVs (kind="prior").
    extend_inferencedata : bool, default True
        If True, add a ``log_{kind}`` group to ``idata`` and return it;
        otherwise return only the new dataset.
    model : Model, optional
        Model to evaluate; taken from the model context if not given.
    kind : str, default "likelihood"
        Either "likelihood" or "prior".
    sample_dims : sequence of str, default ("chain", "draw")
    progressbar : bool, default True

    Returns
    -------
    InferenceData or Dataset
        ``idata`` extended with a ``log_{kind}`` group, or the standalone
        dataset when ``extend_inferencedata=False``.
    """

    posterior = idata["posterior"]

    model = modelcontext(model)

    if kind not in ("likelihood", "prior"):
        raise ValueError("kind must be either 'likelihood' or 'prior'")

    # Select which RVs are valid targets for the requested density.
    # NOTE(review): for kind="prior" the error message says "free_RVs" but the
    # check is against unobserved_RVs (which also includes Deterministics) —
    # confirm this mismatch is intentional.
    if kind == "likelihood":
        target_rvs = model.observed_RVs
        target_str = "observed_RVs"
    else:
        target_rvs = model.unobserved_RVs
        target_str = "free_RVs"

    if var_names is None:
        vars = target_rvs
        var_names = tuple(rv.name for rv in vars)
    else:
        vars = [model.named_vars[name] for name in var_names]
        if not set(vars).issubset(target_rvs):
            raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}")

    # We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values
    try:
        # Stash the original mappings so they can be restored in `finally`,
        # even if compilation fails.
        original_rvs_to_values = model.rvs_to_values
        original_rvs_to_transforms = model.rvs_to_transforms

        # Unobserved RVs get fresh value variables (clones); observed RVs keep
        # their existing value variables (the observed data).
        model.rvs_to_values = {
            rv: rv.clone() if rv not in model.observed_RVs else value
            for rv, value in model.rvs_to_values.items()
        }
        model.rvs_to_transforms = {rv: None for rv in model.basic_RVs}

        # sum=False keeps one elemwise logp term per requested variable.
        elemwise_logdens_fn = model.compile_fn(
            inputs=model.value_vars,
            outs=model.logp(vars=vars, sum=False),
            on_unused_input="ignore",
        )
        elemwise_logdens_fn = cast(PointFunc, elemwise_logdens_fn)
    finally:
        # Always restore the model's original state.
        model.rvs_to_values = original_rvs_to_values
        model.rvs_to_transforms = original_rvs_to_transforms

    # Ignore Deterministics
    posterior_values = posterior[[rv.name for rv in model.free_RVs]]
    posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)

    n_pts = len(posterior_pts)
    logdens_dict = _DefaultTrace(n_pts)
    indices = range(n_pts)
    if progressbar:
        indices = progress_bar(indices, total=n_pts, display=progressbar)

    # Evaluate the compiled function once per posterior draw; outputs are
    # aligned positionally with var_names.
    for idx in indices:
        logdenss_pts = elemwise_logdens_fn(posterior_pts[idx])
        for rv_name, rv_logdens in zip(var_names, logdenss_pts):
            logdens_dict.insert(rv_name, rv_logdens, idx)

    # Unstack the flattened sample dimension back into (chain, draw, ...).
    logdens_trace = logdens_dict.trace_dict
    for key, array in logdens_trace.items():
        logdens_trace[key] = array.reshape(
            (*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
        )

    coords, dims = coords_and_dims_for_inferencedata(model)
    logdens_dataset = dict_to_dataset(
        logdens_trace,
        library=pymc,
        dims=dims,
        coords=coords,
        default_dims=list(sample_dims),
        skip_event_dims=True,
    )

    if extend_inferencedata:
        idata.add_groups({f"log_{kind}": logdens_dataset})
        return idata
    else:
        return logdens_dataset
95 changes: 15 additions & 80 deletions pymc/stats/log_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Optional, cast
from typing import Optional

from arviz import InferenceData, dict_to_dataset
from fastprogress import progress_bar
from arviz import InferenceData

import pymc
from pymc.model import Model
from pymc.stats.log_density import compute_log_density

from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata
from pymc.model import Model, modelcontext
from pymc.pytensorf import PointFunc
from pymc.util import dataset_to_point_list

__all__ = ("compute_log_likelihood",)
__all__ = "compute_log_likelihood"


def compute_log_likelihood(
    idata: InferenceData,
    *,
    var_names: Optional[Sequence[str]] = None,
    extend_inferencedata: bool = True,
    model: Optional[Model] = None,
    sample_dims: Sequence[str] = ("chain", "draw"),
    progressbar=True,
):
    """Compute elemwise log_likelihood of model given InferenceData with posterior group

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group
    var_names : sequence of str, optional
        List of Observed variable names for which to compute log_likelihood.
        Defaults to all observed variables.
    extend_inferencedata : bool, default True
        Whether to extend the original InferenceData or return a new one
    model : Model, optional
    sample_dims : sequence of str, default ("chain", "draw")
    progressbar : bool, default True

    Returns
    -------
    idata : InferenceData
        InferenceData with log_likelihood group
    """
    # Thin wrapper around the shared implementation; only `kind` differs
    # from compute_log_prior.
    return compute_log_density(
        idata=idata,
        var_names=var_names,
        extend_inferencedata=extend_inferencedata,
        model=model,
        kind="likelihood",
        sample_dims=sample_dims,
        progressbar=progressbar,
    )
61 changes: 61 additions & 0 deletions pymc/stats/log_prior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Optional

from arviz import InferenceData

from pymc.model import Model
from pymc.stats.log_density import compute_log_density

__all__ = "compute_log_prior"


def compute_log_prior(
    idata: InferenceData,
    var_names: Optional[Sequence[str]] = None,
    extend_inferencedata: bool = True,
    model: Optional[Model] = None,
    sample_dims: Sequence[str] = ("chain", "draw"),
    progressbar=True,
):
    """Compute elemwise log_prior of model given InferenceData with posterior group

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group
    var_names : sequence of str, optional
        List of free variable names for which to compute log_prior.
        Defaults to all free variables.
    extend_inferencedata : bool, default True
        Whether to extend the original InferenceData or return a new one
    model : Model, optional
    sample_dims : sequence of str, default ("chain", "draw")
    progressbar : bool, default True

    Returns
    -------
    idata : InferenceData
        InferenceData with log_prior group
    """
    # Thin wrapper around the shared implementation; only `kind` differs
    # from compute_log_likelihood.
    return compute_log_density(
        idata=idata,
        var_names=var_names,
        extend_inferencedata=extend_inferencedata,
        model=model,
        kind="prior",
        sample_dims=sample_dims,
        progressbar=progressbar,
    )
48 changes: 48 additions & 0 deletions tests/stats/test_log_prior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import scipy.stats as st

from arviz import InferenceData, dict_to_dataset

from pymc.distributions import Normal
from pymc.distributions.transforms import log
from pymc.model import Model
from pymc.stats.log_prior import compute_log_prior


class TestComputeLogPrior:
    @pytest.mark.parametrize("transform", (False, True))
    def test_basic(self, transform):
        """log_prior values match scipy and the model's mappings are restored."""
        tr = log if transform else None
        with Model() as model:
            x = Normal("x", transform=tr)
            value_var_before = model.rvs_to_values[x]
            Normal("y", x, observed=[0, 1, 2])

        posterior = dict_to_dataset({"x": np.arange(100).reshape(4, 25)})
        idata = InferenceData(posterior=posterior)
        result = compute_log_prior(idata)

        # The temporary value-var / transform swap must be undone afterwards
        assert model.rvs_to_values[x] is value_var_before
        assert model.rvs_to_transforms[x] is tr

        # Default behavior extends idata in place and returns it
        assert result is idata
        assert result.log_prior.dims == {"chain": 4, "draw": 25}

        np.testing.assert_allclose(
            result.log_prior["x"].values,
            st.norm(0, 1).logpdf(idata.posterior["x"].values),
        )