2323import  xarray  as  xr 
2424from  arviz  import  r2_score 
2525from  patsy  import  dmatrix 
26+ from  pymc_extras .prior  import  Prior 
2627
2728from  causalpy .utils  import  round_num 
2829
@@ -90,7 +91,87 @@ class PyMCModel(pm.Model):
9091    Inference data... 
9192    """ 
9293
93-     def  __init__ (self , sample_kwargs : Optional [Dict [str , Any ]] =  None ):
94+     default_priors  =  {}
95+ 
96+     def  priors_from_data (self , X , y ) ->  Dict [str , Any ]:
97+         """ 
98+         Generate priors dynamically based on the input data. 
99+ 
100+         This method allows models to set sensible priors that adapt to the scale 
101+         and characteristics of the actual data being analyzed. It's called during 
102+         the `fit()` method before model building, allowing data-driven prior 
103+         specification that can improve model performance and convergence. 
104+ 
105+         The priors returned by this method are merged with any user-specified 
106+         priors (passed via the `priors` parameter in `__init__`), with 
107+         user-specified priors taking precedence in case of conflicts. 
108+ 
109+         Parameters 
110+         ---------- 
111+         X : xarray.DataArray 
112+             Input features/covariates with dimensions ["obs_ind", "coeffs"]. 
113+             Used to understand the scale and structure of predictors. 
114+         y : xarray.DataArray 
115+             Target variable with dimensions ["obs_ind", "treated_units"]. 
116+             Used to understand the scale and structure of the outcome. 
117+ 
118+         Returns 
119+         ------- 
120+         Dict[str, Prior] 
121+             Dictionary mapping parameter names to Prior objects. The keys should 
122+             match parameter names used in the model's `build_model()` method. 
123+ 
124+         Notes 
125+         ----- 
126+         The base implementation returns an empty dictionary, meaning no 
127+         data-driven priors are set by default. Subclasses should override 
128+         this method to implement data-adaptive prior specification. 
129+ 
130+         **Priority Order for Priors:** 
131+         1. User-specified priors (passed to `__init__`) 
132+         2. Data-driven priors (from this method) 
133+         3. Default priors (from `default_priors` property) 
134+ 
135+         Examples 
136+         -------- 
137+         A typical implementation might scale priors based on data variance: 
138+ 
139+         >>> def priors_from_data(self, X, y): 
140+         ...     y_std = float(y.std()) 
141+         ...     return { 
142+         ...         "sigma": Prior("HalfNormal", sigma=y_std, dims="treated_units"), 
143+         ...         "beta": Prior( 
144+         ...             "Normal", 
145+         ...             mu=0, 
146+         ...             sigma=2 * y_std, 
147+         ...             dims=["treated_units", "coeffs"], 
148+         ...         ), 
149+         ...     } 
150+ 
151+         Or set shape parameters based on data dimensions: 
152+ 
153+         >>> def priors_from_data(self, X, y): 
154+         ...     n_predictors = X.shape[1] 
155+         ...     return { 
156+         ...         "beta": Prior( 
157+         ...             "Dirichlet", 
158+         ...             a=np.ones(n_predictors), 
159+         ...             dims=["treated_units", "coeffs"], 
160+         ...         ) 
161+         ...     } 
162+ 
163+         See Also 
164+         -------- 
165+         WeightedSumFitter.priors_from_data : Example implementation that sets 
166+             Dirichlet prior shape based on number of control units. 
167+         """ 
168+         return  {}
169+ 
170+     def  __init__ (
171+         self ,
172+         sample_kwargs : Optional [Dict [str , Any ]] =  None ,
173+         priors : dict [str , Any ] |  None  =  None ,
174+     ):
94175        """ 
95176        :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the 
96177            :func:`pymc.sample` function. Defaults to an empty dictionary. 
@@ -99,9 +180,13 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
99180        self .idata  =  None 
100181        self .sample_kwargs  =  sample_kwargs  if  sample_kwargs  is  not None  else  {}
101182
183+         self .priors  =  {** self .default_priors , ** (priors  or  {})}
184+ 
102185    def  build_model (self , X , y , coords ) ->  None :
103186        """Build the model, must be implemented by subclass.""" 
104-         raise  NotImplementedError ("This method must be implemented by a subclass" )
187+         raise  NotImplementedError (
188+             "This method must be implemented by a subclass" 
189+         )  # pragma: no cover 
105190
106191    def  _data_setter (self , X : xr .DataArray ) ->  None :
107192        """ 
@@ -144,6 +229,10 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
144229        # sample_posterior_predictive() if provided in sample_kwargs. 
145230        random_seed  =  self .sample_kwargs .get ("random_seed" , None )
146231
232+         # Merge priors with precedence: user-specified > data-driven > defaults 
233+         # Data-driven priors are computed first, then user-specified priors override them 
234+         self .priors  =  {** self .priors_from_data (X , y ), ** self .priors }
235+ 
147236        self .build_model (X , y , coords )
148237        with  self :
149238            self .idata  =  pm .sample (** self .sample_kwargs )
@@ -239,26 +328,36 @@ def print_coefficients_for_unit(
239328        ) ->  None :
240329            """Print coefficients for a single unit""" 
241330            # Determine the width of the longest label 
242-             max_label_length  =  max (len (name ) for  name  in  labels  +  ["sigma " ])
331+             max_label_length  =  max (len (name ) for  name  in  labels  +  ["y_hat_sigma " ])
243332
244333            for  name  in  labels :
245334                coeff_samples  =  unit_coeffs .sel (coeffs = name )
246335                print_row (max_label_length , name , coeff_samples , round_to )
247336
248337            # Add coefficient for measurement std 
249-             print_row (max_label_length , "sigma " , unit_sigma , round_to )
338+             print_row (max_label_length , "y_hat_sigma " , unit_sigma , round_to )
250339
251340        print ("Model coefficients:" )
252341        coeffs  =  az .extract (self .idata .posterior , var_names = "beta" )
253342
254-         # Always has treated_units dimension - no branching needed! 
343+         # Check if sigma or y_hat_sigma variable exists 
344+         sigma_var_name  =  None 
345+         if  "sigma"  in  self .idata .posterior :
346+             sigma_var_name  =  "sigma" 
347+         elif  "y_hat_sigma"  in  self .idata .posterior :
348+             sigma_var_name  =  "y_hat_sigma" 
349+         else :
350+             raise  ValueError (
351+                 "Neither 'sigma' nor 'y_hat_sigma' found in posterior" 
352+             )  # pragma: no cover 
353+ 
255354        treated_units  =  coeffs .coords ["treated_units" ].values 
256355        for  unit  in  treated_units :
257356            if  len (treated_units ) >  1 :
258357                print (f"\n Treated unit: { unit }  )
259358
260359            unit_coeffs  =  coeffs .sel (treated_units = unit )
261-             unit_sigma  =  az .extract (self .idata .posterior , var_names = "sigma" ).sel (
360+             unit_sigma  =  az .extract (self .idata .posterior , var_names = sigma_var_name ).sel (
262361                treated_units = unit 
263362            )
264363            print_coefficients_for_unit (unit_coeffs , unit_sigma , labels , round_to  or  2 )
@@ -301,6 +400,15 @@ class LinearRegression(PyMCModel):
301400    Inference data... 
302401    """   # noqa: W605 
303402
403+     default_priors  =  {
404+         "beta" : Prior ("Normal" , mu = 0 , sigma = 50 , dims = ["treated_units" , "coeffs" ]),
405+         "y_hat" : Prior (
406+             "Normal" ,
407+             sigma = Prior ("HalfNormal" , sigma = 1 , dims = ["treated_units" ]),
408+             dims = ["obs_ind" , "treated_units" ],
409+         ),
410+     }
411+ 
304412    def  build_model (self , X , y , coords ):
305413        """ 
306414        Defines the PyMC model 
@@ -314,12 +422,11 @@ def build_model(self, X, y, coords):
314422            self .add_coords (coords )
315423            X  =  pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
316424            y  =  pm .Data ("y" , y , dims = ["obs_ind" , "treated_units" ])
317-             beta  =  pm .Normal ("beta" , 0 , 50 , dims = ["treated_units" , "coeffs" ])
318-             sigma  =  pm .HalfNormal ("sigma" , 1 , dims = "treated_units" )
425+             beta  =  self .priors ["beta" ].create_variable ("beta" )
319426            mu  =  pm .Deterministic (
320427                "mu" , pt .dot (X , beta .T ), dims = ["obs_ind" , "treated_units" ]
321428            )
322-             pm . Normal ("y_hat" , mu ,  sigma ,  observed = y ,  dims = [ "obs_ind" ,  "treated_units" ] )
429+             self . priors [ "y_hat" ]. create_likelihood_variable ("y_hat" , mu = mu ,  observed = y )
323430
324431
325432class  WeightedSumFitter (PyMCModel ):
@@ -362,23 +469,56 @@ class WeightedSumFitter(PyMCModel):
362469    Inference data... 
363470    """   # noqa: W605 
364471
472+     default_priors  =  {
473+         "y_hat" : Prior (
474+             "Normal" ,
475+             sigma = Prior ("HalfNormal" , sigma = 1 , dims = ["treated_units" ]),
476+             dims = ["obs_ind" , "treated_units" ],
477+         ),
478+     }
479+ 
480+     def  priors_from_data (self , X , y ) ->  Dict [str , Any ]:
481+         """ 
482+         Set Dirichlet prior for weights based on number of control units. 
483+ 
484+         For synthetic control models, this method sets the shape parameter of the 
485+         Dirichlet prior on the control unit weights (`beta`) to be uniform across 
486+         all available control units. This ensures that all control units have 
487+         equal prior probability of contributing to the synthetic control. 
488+ 
489+         Parameters 
490+         ---------- 
491+         X : xarray.DataArray 
492+             Control unit data with shape (n_obs, n_control_units). 
493+         y : xarray.DataArray 
494+             Treated unit outcome data. 
495+ 
496+         Returns 
497+         ------- 
498+         Dict[str, Prior] 
499+             Dictionary containing: 
500+             - "beta": Dirichlet prior with shape=(1,...,1) for n_control_units 
501+         """ 
502+         n_predictors  =  X .shape [1 ]
503+         return  {
504+             "beta" : Prior (
505+                 "Dirichlet" , a = np .ones (n_predictors ), dims = ["treated_units" , "coeffs" ]
506+             ),
507+         }
508+ 
365509    def  build_model (self , X , y , coords ):
366510        """ 
367511        Defines the PyMC model 
368512        """ 
369513        with  self :
370514            self .add_coords (coords )
371-             n_predictors  =  X .sizes ["coeffs" ]
372515            X  =  pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
373516            y  =  pm .Data ("y" , y , dims = ["obs_ind" , "treated_units" ])
374-             beta  =  pm .Dirichlet (
375-                 "beta" , a = np .ones (n_predictors ), dims = ["treated_units" , "coeffs" ]
376-             )
377-             sigma  =  pm .HalfNormal ("sigma" , 1 , dims = "treated_units" )
517+             beta  =  self .priors ["beta" ].create_variable ("beta" )
378518            mu  =  pm .Deterministic (
379519                "mu" , pt .dot (X , beta .T ), dims = ["obs_ind" , "treated_units" ]
380520            )
381-             pm . Normal ("y_hat" , mu ,  sigma ,  observed = y ,  dims = [ "obs_ind" ,  "treated_units" ] )
521+             self . priors [ "y_hat" ]. create_likelihood_variable ("y_hat" , mu = mu ,  observed = y )
382522
383523
384524class  InstrumentalVariableRegression (PyMCModel ):
@@ -568,21 +708,18 @@ class PropensityScore(PyMCModel):
568708    Inference... 
569709    """   # noqa: W605 
570710
571-     def  build_model (self , X , t , coords , prior , noncentred ):
711+     default_priors  =  {
712+         "b" : Prior ("Normal" , mu = 0 , sigma = 1 , dims = "coeffs" ),
713+     }
714+ 
715+     def  build_model (self , X , t , coords , prior = None , noncentred = True ):
572716        "Defines the PyMC propensity model" 
573717        with  self :
574718            self .add_coords (coords )
575719            X_data  =  pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
576720            t_data  =  pm .Data ("t" , t .flatten (), dims = "obs_ind" )
577-             if  noncentred :
578-                 mu_beta , sigma_beta  =  prior ["b" ]
579-                 beta_std  =  pm .Normal ("beta_std" , 0 , 1 , dims = "coeffs" )
580-                 b  =  pm .Deterministic (
581-                     "beta_" , mu_beta  +  sigma_beta  *  beta_std , dims = "coeffs" 
582-                 )
583-             else :
584-                 b  =  pm .Normal ("b" , mu = prior ["b" ][0 ], sigma = prior ["b" ][1 ], dims = "coeffs" )
585-             mu  =  pm .math .dot (X_data , b )
721+             b  =  self .priors ["b" ].create_variable ("b" )
722+             mu  =  pt .dot (X_data , b )
586723            p  =  pm .Deterministic ("p" , pm .math .invlogit (mu ))
587724            pm .Bernoulli ("t_pred" , p = p , observed = t_data , dims = "obs_ind" )
588725
0 commit comments