Merged
Changes from 7 commits
2 changes: 1 addition & 1 deletion causalpy/experiments/prepostnegd.py
@@ -82,7 +82,7 @@ class PrePostNEGD(BaseExperiment):
Intercept -0.5, 94% HDI [-1, 0.2]
C(group)[T.1] 2, 94% HDI [2, 2]
pre 1, 94% HDI [1, 1]
-sigma 0.5, 94% HDI [0.5, 0.6]
+y_hat_sigma 0.5, 94% HDI [0.5, 0.6]
"""

supports_ols = False
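This docstring rename mirrors the model changes below: the likelihood's standard deviation is no longer a standalone pm.HalfNormal("sigma", ...) but a nested prior on the "y_hat" likelihood, and pymc-extras registers nested parameters under "<name>_<param>" (an assumption based on its Prior API, consistent with the print_coefficients change below). A minimal sketch of where the new variable name comes from:

import numpy as np
import pymc as pm
from pymc_extras.prior import Prior

y_hat_prior = Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind")
with pm.Model(coords={"obs_ind": range(3)}) as model:
    # creates the likelihood "y_hat" plus a free RV named "y_hat_sigma"
    y_hat_prior.create_likelihood_variable("y_hat", mu=0.0, observed=np.zeros(3))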
57 changes: 45 additions & 12 deletions causalpy/pymc_models.py
@@ -22,6 +22,7 @@
import pytensor.tensor as pt
import xarray as xr
from arviz import r2_score
+from pymc_extras.prior import Prior

from causalpy.utils import round_num

@@ -68,7 +69,18 @@ class PyMCModel(pm.Model):
Inference data...
"""

-    def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
+    @property
+    def default_priors(self):
+        return {}
+
+    def priors_from_data(self, X, y) -> Dict[str, Any]:
+        return {}
+
+    def __init__(
+        self,
+        sample_kwargs: Optional[Dict[str, Any]] = None,
+        priors: dict[str, Any] | None = None,
+    ):
"""
:param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
:func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -77,6 +89,8 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
self.idata = None
self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}

+        self.priors = {**self.default_priors, **(priors or {})}

def build_model(self, X, y, coords) -> None:
"""Build the model, must be implemented by subclass."""
raise NotImplementedError("This method must be implemented by a subclass")
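These two hooks define the new extension contract: a subclass declares static defaults via default_priors and data-dependent priors via priors_from_data, and callers can override either through the new priors argument. A sketch of a hypothetical subclass under that contract (the class name, the "intercept" key, and the hyperparameter choices are illustrative, not part of this diff):

from typing import Any, Dict
from pymc_extras.prior import Prior
from causalpy.pymc_models import PyMCModel

class MyModel(PyMCModel):
    # class attribute mirrors LinearRegression below (see the @property discussion there)
    default_priors = {
        "beta": Prior("Normal", mu=0, sigma=10, dims="coeffs"),
    }

    def priors_from_data(self, X, y) -> Dict[str, Any]:
        # hypothetical: centre the intercept prior on the observed outcome mean
        return {"intercept": Prior("Normal", mu=float(y.mean()), sigma=float(y.std()))}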
@@ -111,6 +125,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
# sample_posterior_predictive() if provided in sample_kwargs.
random_seed = self.sample_kwargs.get("random_seed", None)

+        self.priors = {**self.priors_from_data(X, y), **self.priors}

self.build_model(X, y, coords)
with self:
self.idata = pm.sample(**self.sample_kwargs)
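One subtlety worth flagging in review: by the time fit() runs, self.priors already holds the merged defaults and user overrides, so entries returned by priors_from_data lose to both. A minimal illustration of the two merge steps as written in this diff:

# __init__: self.priors = {**self.default_priors, **(priors or {})}
# fit:      self.priors = {**self.priors_from_data(X, y), **self.priors}
defaults = {"a": "default", "b": "default"}
user = {"b": "user"}
from_data = {"a": "from_data", "c": "from_data"}

priors = {**defaults, **user}     # user-supplied priors override defaults
priors = {**from_data, **priors}  # existing keys beat data-derived ones
assert priors == {"a": "default", "b": "user", "c": "from_data"}

This works for WeightedSumFitter below because its "beta" prior appears only in priors_from_data, never in default_priors.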
@@ -188,15 +204,15 @@ def print_row(
coeffs = az.extract(self.idata.posterior, var_names="beta")

# Determine the width of the longest label
-        max_label_length = max(len(name) for name in labels + ["sigma"])
+        max_label_length = max(len(name) for name in labels + ["y_hat_sigma"])

for name in labels:
coeff_samples = coeffs.sel(coeffs=name)
print_row(max_label_length, name, coeff_samples, round_to)

# Add coefficient for measurement std
-        coeff_samples = az.extract(self.idata.posterior, var_names="sigma")
-        name = "sigma"
+        coeff_samples = az.extract(self.idata.posterior, var_names="y_hat_sigma")
+        name = "y_hat_sigma"
print_row(max_label_length, name, coeff_samples, round_to)
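A quick post-fit sanity check for the renamed variable; "model" stands in for any fitted PyMCModel instance:

import arviz as az

def extract_sigma(model):
    # "y_hat_sigma" is the nested sigma prior registered under the "y_hat" likelihood
    return az.extract(model.idata.posterior, var_names="y_hat_sigma")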


@@ -237,6 +253,11 @@ class LinearRegression(PyMCModel):
Inference data...
""" # noqa: W605

+    default_priors = {
+        "beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"),
+        "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
+    }

Collaborator: Do we need to add the @property decorator here, or is that inherited from its use in the PyMCModel base class? I'm getting a Pylance warning: Type "dict[str, Prior]" is not assignable to declared type "property".

Contributor Author: Which line of code brings that on? Maybe having a setter would help?

Collaborator: Agreed: if you want a property, maybe we can have a setter method? (Not a blocker for now; maybe create an issue?)

def build_model(self, X, y, coords):
"""
Defines the PyMC model
@@ -245,10 +266,9 @@ def build_model(self, X, y, coords):
self.add_coords(coords)
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
y = pm.Data("y", y, dims="obs_ind")
-            beta = pm.Normal("beta", 0, 50, dims="coeffs")
-            sigma = pm.HalfNormal("sigma", 1)
+            beta = self.priors["beta"].create_variable("beta")
             mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
-            pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
+            self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


class WeightedSumFitter(PyMCModel):
@@ -276,19 +296,28 @@ class WeightedSumFitter(PyMCModel):
Inference data...
""" # noqa: W605

+    default_priors = {
+        "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
+    }
+
+    def priors_from_data(self, X, y) -> Dict[str, Any]:
+        n_predictors = X.shape[1]
+
+        return {
+            "beta": Prior("Dirichlet", a=np.ones(n_predictors), dims="coeffs"),
+        }

Collaborator: Just realised that n_predictors will equal the length of the "coeffs" dim. Does that have to be the case, in fact? If so, do we need priors_from_data? Not saying I don't want it; it could be really cool. But I'm just wondering more generally whether it's needed. I'll try to think more with a fresh head in the morning, but does this spark off any thoughts?
def build_model(self, X, y, coords):
"""
Defines the PyMC model
"""
with self:
self.add_coords(coords)
-            n_predictors = X.shape[1]
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
y = pm.Data("y", y[:, 0], dims="obs_ind")
-            beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
-            sigma = pm.HalfNormal("sigma", 1)
+            beta = self.priors["beta"].create_variable("beta")
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
-            pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
+            self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


class InstrumentalVariableRegression(PyMCModel):
@@ -477,13 +506,17 @@ class PropensityScore(PyMCModel):
Inference...
""" # noqa: W605

+    default_priors = {
+        "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"),
+    }

def build_model(self, X, t, coords):
"Defines the PyMC propensity model"
with self:
self.add_coords(coords)
X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
t_data = pm.Data("t", t.flatten(), dims="obs_ind")
-            b = pm.Normal("b", mu=0, sigma=1, dims="coeffs")
+            b = self.priors["b"].create_variable("b")
mu = pm.math.dot(X_data, b)
p = pm.Deterministic("p", pm.math.invlogit(mu))
pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")
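With the same mechanism, a user can now tighten or loosen the propensity coefficient prior without touching the model code; a usage sketch (the sigma=0.5 value is arbitrary, chosen only to contrast with the default sigma=1):

from pymc_extras.prior import Prior
from causalpy.pymc_models import PropensityScore

# shrink the propensity coefficients harder than the default prior
ps_model = PropensityScore(priors={"b": Prior("Normal", mu=0, sigma=0.5, dims="coeffs")})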
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
1 change: 1 addition & 0 deletions environment.yml
@@ -15,3 +15,4 @@ dependencies:
- seaborn>=0.11.2
- statsmodels
- xarray>=v2022.11.0
+  - pymc-extras>=0.2.7
1 change: 1 addition & 0 deletions pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
"seaborn>=0.11.2",
"statsmodels",
"xarray>=v2022.11.0",
"pymc-extras>=0.2.7",
]

# List additional groups of dependencies here (e.g. development dependencies). Users