Commit 540893b
Support for custom priors via Prior class (#488)
* add pymc-extras to environment
* add default_priors and support for custom priors
* get pymc_models tests to pass
* add dim to y_hat
* fix for sigma -> y_hat_sigma
* fix failing doctest
* add support for priors from data
* Add regenerated interrogate badge with updated coverage
* update pymc-extras version pin in attempt to fix failing remote tests
* Add pragma no cover to exception branches

  Added '# pragma: no cover' to NotImplementedError and ValueError branches in PyMCModel to exclude them from test coverage reporting.

* update pymc-extras version pin to match that in pyproject.toml
* add docstrings to the priors_from_data methods
* add tests
* Convert default_priors property to class attribute

---------

Co-authored-by: Benjamin T. Vincent <[email protected]>
1 parent 77f8d59 commit 540893b
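
In short, this commit gives every PyMCModel three sources of priors, all expressed as pymc-extras Prior objects: a class-level `default_priors` dict, an overridable `priors_from_data(X, y)` hook, and a `priors` argument to `__init__`. A minimal usage sketch (the override value here is illustrative, not part of the commit):

    from pymc_extras.prior import Prior
    from causalpy.pymc_models import LinearRegression

    # User-supplied priors are merged over the class-level default_priors;
    # keys that are not overridden keep their defaults.
    model = LinearRegression(
        priors={
            "beta": Prior("Normal", mu=0, sigma=10, dims=["treated_units", "coeffs"]),
        }
    )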

File tree

6 files changed: +420 -34 lines changed

causalpy/experiments/prepostnegd.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ class PrePostNEGD(BaseExperiment):
     Intercept     -0.5, 94% HDI [-1, 0.2]
     C(group)[T.1] 2, 94% HDI [2, 2]
     pre           1, 94% HDI [1, 1]
-    sigma         0.5, 94% HDI [0.5, 0.6]
+    y_hat_sigma   0.5, 94% HDI [0.5, 0.6]
     """

     supports_ols = False

causalpy/pymc_models.py

Lines changed: 162 additions & 25 deletions
@@ -23,6 +23,7 @@
 import xarray as xr
 from arviz import r2_score
 from patsy import dmatrix
+from pymc_extras.prior import Prior

 from causalpy.utils import round_num
@@ -90,7 +91,87 @@ class PyMCModel(pm.Model):
     Inference data...
     """

-    def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
+    default_priors = {}
+
+    def priors_from_data(self, X, y) -> Dict[str, Any]:
+        """
+        Generate priors dynamically based on the input data.
+
+        This method allows models to set sensible priors that adapt to the scale
+        and characteristics of the actual data being analyzed. It is called during
+        the `fit()` method before model building, allowing data-driven prior
+        specification that can improve model performance and convergence.
+
+        The priors returned by this method are merged with any user-specified
+        priors (passed via the `priors` parameter in `__init__`), with
+        user-specified priors taking precedence in case of conflicts.
+
+        Parameters
+        ----------
+        X : xarray.DataArray
+            Input features/covariates with dimensions ["obs_ind", "coeffs"].
+            Used to understand the scale and structure of predictors.
+        y : xarray.DataArray
+            Target variable with dimensions ["obs_ind", "treated_units"].
+            Used to understand the scale and structure of the outcome.
+
+        Returns
+        -------
+        Dict[str, Prior]
+            Dictionary mapping parameter names to Prior objects. The keys should
+            match parameter names used in the model's `build_model()` method.
+
+        Notes
+        -----
+        The base implementation returns an empty dictionary, meaning no
+        data-driven priors are set by default. Subclasses should override
+        this method to implement data-adaptive prior specification.
+
+        **Priority Order for Priors:**
+
+        1. User-specified priors (passed to `__init__`)
+        2. Data-driven priors (from this method)
+        3. Default priors (from the `default_priors` class attribute)
+
+        Examples
+        --------
+        A typical implementation might scale priors based on data variance:
+
+        >>> def priors_from_data(self, X, y):
+        ...     y_std = float(y.std())
+        ...     return {
+        ...         "sigma": Prior("HalfNormal", sigma=y_std, dims="treated_units"),
+        ...         "beta": Prior(
+        ...             "Normal",
+        ...             mu=0,
+        ...             sigma=2 * y_std,
+        ...             dims=["treated_units", "coeffs"],
+        ...         ),
+        ...     }
+
+        Or set shape parameters based on data dimensions:
+
+        >>> def priors_from_data(self, X, y):
+        ...     n_predictors = X.shape[1]
+        ...     return {
+        ...         "beta": Prior(
+        ...             "Dirichlet",
+        ...             a=np.ones(n_predictors),
+        ...             dims=["treated_units", "coeffs"],
+        ...         )
+        ...     }
+
+        See Also
+        --------
+        WeightedSumFitter.priors_from_data : Example implementation that sets
+            the Dirichlet prior shape based on the number of control units.
+        """
+        return {}
+
+    def __init__(
+        self,
+        sample_kwargs: Optional[Dict[str, Any]] = None,
+        priors: dict[str, Any] | None = None,
+    ):
         """
         :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
             :func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -99,9 +180,13 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
         self.idata = None
         self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}

+        self.priors = {**self.default_priors, **(priors or {})}
+
     def build_model(self, X, y, coords) -> None:
         """Build the model, must be implemented by subclass."""
-        raise NotImplementedError("This method must be implemented by a subclass")
+        raise NotImplementedError(
+            "This method must be implemented by a subclass"
+        )  # pragma: no cover

     def _data_setter(self, X: xr.DataArray) -> None:
         """
@@ -144,6 +229,10 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
         # sample_posterior_predictive() if provided in sample_kwargs.
         random_seed = self.sample_kwargs.get("random_seed", None)

+        # Merge priors with precedence: user-specified > data-driven > defaults
+        # Data-driven priors are computed first, then user-specified priors override them
+        self.priors = {**self.priors_from_data(X, y), **self.priors}
+
         self.build_model(X, y, coords)
         with self:
             self.idata = pm.sample(**self.sample_kwargs)
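
Read literally, both merges use plain dict semantics: later entries win. The `__init__` merge lets user priors override class defaults, and the fit-time merge keeps anything already in `self.priors`, with `priors_from_data` only filling missing keys. A toy illustration, with plain strings standing in for Prior objects:

    defaults = {"beta": "default", "y_hat": "default"}
    user = {"beta": "user"}
    data_driven = {"beta": "data", "sigma": "data"}

    priors = {**defaults, **user}        # __init__: user overrides defaults
    priors = {**data_driven, **priors}   # fit(): existing keys beat data-driven
    print(priors)
    # {'beta': 'user', 'sigma': 'data', 'y_hat': 'default'}

One subtlety: a key present in both `default_priors` and `priors_from_data` resolves to the class default, not the data-driven value, so the documented "data-driven > defaults" ordering holds in practice only because no model in this commit defines the same key in both places.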
@@ -239,26 +328,36 @@ def print_coefficients_for_unit(
         ) -> None:
             """Print coefficients for a single unit"""
             # Determine the width of the longest label
-            max_label_length = max(len(name) for name in labels + ["sigma"])
+            max_label_length = max(len(name) for name in labels + ["y_hat_sigma"])

             for name in labels:
                 coeff_samples = unit_coeffs.sel(coeffs=name)
                 print_row(max_label_length, name, coeff_samples, round_to)

             # Add coefficient for measurement std
-            print_row(max_label_length, "sigma", unit_sigma, round_to)
+            print_row(max_label_length, "y_hat_sigma", unit_sigma, round_to)

         print("Model coefficients:")
         coeffs = az.extract(self.idata.posterior, var_names="beta")

-        # Always has treated_units dimension - no branching needed!
+        # Check if sigma or y_hat_sigma variable exists
+        sigma_var_name = None
+        if "sigma" in self.idata.posterior:
+            sigma_var_name = "sigma"
+        elif "y_hat_sigma" in self.idata.posterior:
+            sigma_var_name = "y_hat_sigma"
+        else:
+            raise ValueError(
+                "Neither 'sigma' nor 'y_hat_sigma' found in posterior"
+            )  # pragma: no cover
+
         treated_units = coeffs.coords["treated_units"].values
         for unit in treated_units:
             if len(treated_units) > 1:
                 print(f"\nTreated unit: {unit}")

             unit_coeffs = coeffs.sel(treated_units=unit)
-            unit_sigma = az.extract(self.idata.posterior, var_names="sigma").sel(
+            unit_sigma = az.extract(self.idata.posterior, var_names=sigma_var_name).sel(
                 treated_units=unit
             )
             print_coefficients_for_unit(unit_coeffs, unit_sigma, labels, round_to or 2)
@@ -301,6 +400,15 @@ class LinearRegression(PyMCModel):
     Inference data...
     """  # noqa: W605

+    default_priors = {
+        "beta": Prior("Normal", mu=0, sigma=50, dims=["treated_units", "coeffs"]),
+        "y_hat": Prior(
+            "Normal",
+            sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
+            dims=["obs_ind", "treated_units"],
+        ),
+    }
+
     def build_model(self, X, y, coords):
         """
         Defines the PyMC model
@@ -314,12 +422,11 @@ def build_model(self, X, y, coords):
             self.add_coords(coords)
             X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
-            beta = pm.Normal("beta", 0, 50, dims=["treated_units", "coeffs"])
-            sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
+            beta = self.priors["beta"].create_variable("beta")
             mu = pm.Deterministic(
                 "mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
             )
-            pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
+            self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


 class WeightedSumFitter(PyMCModel):
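
Because `build_model` now reads both the slope prior and the likelihood from `self.priors`, either can be swapped without subclassing. A sketch widening the observation-noise prior (values illustrative); the nested Prior for sigma mirrors the `default_priors` structure above:

    from pymc_extras.prior import Prior
    from causalpy.pymc_models import LinearRegression

    # Override the "y_hat" likelihood prior; "beta" keeps its Normal(0, 50) default.
    model = LinearRegression(
        priors={
            "y_hat": Prior(
                "Normal",
                sigma=Prior("HalfNormal", sigma=5, dims=["treated_units"]),
                dims=["obs_ind", "treated_units"],
            ),
        }
    )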
@@ -362,23 +469,56 @@ class WeightedSumFitter(PyMCModel):
     Inference data...
     """  # noqa: W605

+    default_priors = {
+        "y_hat": Prior(
+            "Normal",
+            sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
+            dims=["obs_ind", "treated_units"],
+        ),
+    }
+
+    def priors_from_data(self, X, y) -> Dict[str, Any]:
+        """
+        Set the Dirichlet prior for the weights based on the number of control units.
+
+        For synthetic control models, this method sets the shape parameter of the
+        Dirichlet prior on the control unit weights (`beta`) to be uniform across
+        all available control units. This ensures that all control units have
+        equal prior probability of contributing to the synthetic control.
+
+        Parameters
+        ----------
+        X : xarray.DataArray
+            Control unit data with shape (n_obs, n_control_units).
+        y : xarray.DataArray
+            Treated unit outcome data.
+
+        Returns
+        -------
+        Dict[str, Prior]
+            Dictionary containing:
+
+            - "beta": Dirichlet prior with concentration a = (1, ..., 1) of
+              length n_control_units
+        """
+        n_predictors = X.shape[1]
+        return {
+            "beta": Prior(
+                "Dirichlet", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
+            ),
+        }
+
     def build_model(self, X, y, coords):
         """
         Defines the PyMC model
         """
         with self:
             self.add_coords(coords)
-            n_predictors = X.sizes["coeffs"]
             X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
-            beta = pm.Dirichlet(
-                "beta", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
-            )
-            sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
+            beta = self.priors["beta"].create_variable("beta")
             mu = pm.Deterministic(
                 "mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
             )
-            pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
+            self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


 class InstrumentalVariableRegression(PyMCModel):
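
Note that `n_predictors` moves out of `build_model` (where it was `X.sizes["coeffs"]`) into `priors_from_data` (as `X.shape[1]`), so the Dirichlet weights prior adapts to the number of control units at `fit()` time unless the caller pins `beta` explicitly. A sketch, assuming a hypothetical dataset with five control units:

    import numpy as np
    from pymc_extras.prior import Prior
    from causalpy.pymc_models import WeightedSumFitter

    # Default behaviour: beta ~ Dirichlet(ones(n_control_units)), set at fit time.
    wsf = WeightedSumFitter()

    # Override: a sparser Dirichlet (a < 1) concentrating mass on fewer control
    # units; the length of `a` must match the number of control units in X.
    wsf_sparse = WeightedSumFitter(
        priors={
            "beta": Prior(
                "Dirichlet", a=np.full(5, 0.5), dims=["treated_units", "coeffs"]
            ),
        }
    )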
@@ -568,21 +708,18 @@ class PropensityScore(PyMCModel):
     Inference...
     """  # noqa: W605

-    def build_model(self, X, t, coords, prior, noncentred):
+    default_priors = {
+        "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"),
+    }
+
+    def build_model(self, X, t, coords, prior=None, noncentred=True):
         "Defines the PyMC propensity model"
         with self:
             self.add_coords(coords)
             X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             t_data = pm.Data("t", t.flatten(), dims="obs_ind")
-            if noncentred:
-                mu_beta, sigma_beta = prior["b"]
-                beta_std = pm.Normal("beta_std", 0, 1, dims="coeffs")
-                b = pm.Deterministic(
-                    "beta_", mu_beta + sigma_beta * beta_std, dims="coeffs"
-                )
-            else:
-                b = pm.Normal("b", mu=prior["b"][0], sigma=prior["b"][1], dims="coeffs")
-            mu = pm.math.dot(X_data, b)
+            b = self.priors["b"].create_variable("b")
+            mu = pt.dot(X_data, b)
             p = pm.Deterministic("p", pm.math.invlogit(mu))
             pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")
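
The refactor also removes the in-model non-centred branch: `prior` and `noncentred` stay in the signature (now with defaults), but in the hunk above the coefficient prior comes solely from `self.priors["b"]`. Overriding it follows the same pattern as the other models (values illustrative):

    from pymc_extras.prior import Prior
    from causalpy.pymc_models import PropensityScore

    # Replace the default Normal(0, 1) prior on the logit-scale coefficients.
    ps = PropensityScore(
        priors={"b": Prior("Normal", mu=0, sigma=2, dims="coeffs")}
    )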
