Support for custom priors via Prior class #488
@@ -22,6 +22,7 @@
 import pytensor.tensor as pt
 import xarray as xr
 from arviz import r2_score
+from pymc_extras.prior import Prior

 from causalpy.utils import round_num
@@ -68,7 +69,15 @@ class PyMCModel(pm.Model):
     Inference data...
     """

-    def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
+    @property
+    def default_priors(self):
+        return {}
+
+    def __init__(
+        self,
+        sample_kwargs: Optional[Dict[str, Any]] = None,
+        priors: dict[str, Any] | None = None,
+    ):
         """
         :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
             :func:`pymc.sample` function. Defaults to an empty dictionary.
@@ -77,6 +86,8 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
         self.idata = None
         self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}

+        self.priors = {**self.default_priors, **(priors or {})}
+
     def build_model(self, X, y, coords) -> None:
         """Build the model, must be implemented by subclass."""
         raise NotImplementedError("This method must be implemented by a subclass")
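Because the user-supplied dict is unpacked second, its keys take precedence while anything not specified falls back to `default_priors`. A minimal sketch of that merge, reusing the `LinearRegression` defaults from later in this diff (the StudentT override is purely illustrative):

```python
from pymc_extras.prior import Prior

# Defaults as declared on LinearRegression later in this diff.
default_priors = {
    "beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"),
    "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
}

# Hypothetical user override for a single key.
user_priors = {"beta": Prior("StudentT", nu=3, mu=0, sigma=10, dims="coeffs")}

# Same expression as in PyMCModel.__init__: later keys win, missing keys fall back.
priors = {**default_priors, **(user_priors or {})}
# priors["beta"] is now the user's StudentT prior; priors["y_hat"] keeps its default.
```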
@@ -237,6 +248,11 @@ class LinearRegression(PyMCModel):
     Inference data...
     """  # noqa: W605

+    default_priors = {
+        "beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"),
+        "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
+    }
+
     def build_model(self, X, y, coords):
         """
         Defines the PyMC model
@@ -245,10 +261,9 @@ def build_model(self, X, y, coords):
             self.add_coords(coords)
             X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             y = pm.Data("y", y, dims="obs_ind")
-            beta = pm.Normal("beta", 0, 50, dims="coeffs")
-            sigma = pm.HalfNormal("sigma", 1)
+            beta = self.priors["beta"].create_variable("beta")
             mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
-            pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
+            self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


 class WeightedSumFitter(PyMCModel):
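For reference, this is roughly what the new argument looks like from the user's side, assuming the existing `PyMCModel.fit(X, y, coords)` API. The synthetic data, coords, and sampler settings below are purely illustrative:

```python
import numpy as np
from pymc_extras.prior import Prior
from causalpy.pymc_models import LinearRegression

# Synthetic data, only for illustration.
rng = np.random.default_rng(seed=42)
X = rng.normal(size=(100, 2))
y = X @ np.array([1.0, -2.0]) + rng.normal(scale=0.5, size=100)
coords = {"coeffs": ["x1", "x2"], "obs_ind": np.arange(100)}

# Override only the coefficient prior; "y_hat" falls back to the class default.
model = LinearRegression(
    sample_kwargs={"progressbar": False},
    priors={"beta": Prior("Normal", mu=0, sigma=5, dims="coeffs")},
)
model.fit(X, y, coords=coords)
```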
@@ -276,6 +291,10 @@ class WeightedSumFitter(PyMCModel):
     Inference data...
     """  # noqa: W605

+    default_priors = {
+        "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
+    }
+
     def build_model(self, X, y, coords):
         """
         Defines the PyMC model
@@ -286,9 +305,8 @@ def build_model(self, X, y, coords):
             X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             y = pm.Data("y", y[:, 0], dims="obs_ind")
             beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
-            sigma = pm.HalfNormal("sigma", 1)
             mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
-            pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
+            self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


 class InstrumentalVariableRegression(PyMCModel):
@@ -477,13 +495,17 @@ class PropensityScore(PyMCModel):
     Inference...
     """  # noqa: W605

+    default_priors = {
+        "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"),
+    }
+
     def build_model(self, X, t, coords):
         "Defines the PyMC propensity model"
         with self:
             self.add_coords(coords)
             X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             t_data = pm.Data("t", t.flatten(), dims="obs_ind")
-            b = pm.Normal("b", mu=0, sigma=1, dims="coeffs")
+            b = self.priors["b"].create_variable("b")
             mu = pm.math.dot(X_data, b)
             p = pm.Deterministic("p", pm.math.invlogit(mu))
             pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")
The PR also adds the new dependency to the environment specification:

@@ -15,3 +15,4 @@ dependencies:
   - seaborn>=0.11.2
   - statsmodels
   - xarray>=v2022.11.0
+  - pymc-extras>=0.2.7
Review discussion

Comment: Do we need to add the @property decorator here, or is that inherited from it being defined on the PyMCModel base class? Getting a Pylance warning: Type "dict[str, Prior]" is not assignable to declared type "property".

Reply: What line of code brings that on? Maybe having a setter will help?

Reply: Agree. If you want a property, maybe we can have a setter method? (Not a blocker for now; maybe create an issue?)
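To make the trade-off concrete, here is a small sketch of the pattern being discussed, plus one possible typing-friendly alternative. The first two classes mirror the structure in this PR; the third is only a suggestion and not something the diff implements.

```python
from typing import Any, Dict

from pymc_extras.prior import Prior


class PyMCModelSketch:
    # As in the PR: the base class exposes empty defaults via a read-only property...
    @property
    def default_priors(self) -> Dict[str, Any]:
        return {}


class LinearRegressionSketch(PyMCModelSketch):
    # ...while subclasses shadow it with a plain class attribute. This works at
    # runtime (the subclass's class attribute is found before the inherited
    # property), but static checkers flag a dict[str, Prior] assigned where a
    # property was declared, which is the Pylance warning quoted above.
    default_priors = {
        "beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"),
    }


class PyMCModelAlternative:
    # Possible alternative: declare an annotated plain class attribute in the
    # base class, so subclass overrides type-check cleanly without a setter.
    default_priors: Dict[str, Prior] = {}
```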