diff --git a/causalpy/experiments/interrupted_time_series.py b/causalpy/experiments/interrupted_time_series.py index 95c6d886..a84327e8 100644 --- a/causalpy/experiments/interrupted_time_series.py +++ b/causalpy/experiments/interrupted_time_series.py @@ -17,16 +17,15 @@ from typing import List, Union -import arviz as az import numpy as np import pandas as pd import xarray as xr from matplotlib import pyplot as plt -from patsy import build_design_matrices, dmatrices +from patsy import dmatrices from sklearn.base import RegressorMixin from causalpy.custom_exceptions import BadIndexException -from causalpy.plot_utils import get_hdi_to_df, plot_xY +from causalpy.plot_utils import plot_xY from causalpy.pymc_models import PyMCModel from causalpy.utils import round_num @@ -85,99 +84,138 @@ def __init__( **kwargs, ) -> None: super().__init__(model=model) - # rename the index to "obs_ind" data.index.name = "obs_ind" self.input_validation(data, treatment_time) self.treatment_time = treatment_time - # set experiment type - usually done in subclasses self.expt_type = "Pre-Post Fit" - # split data in to pre and post intervention - self.datapre = data[data.index < self.treatment_time] - self.datapost = data[data.index >= self.treatment_time] - self.formula = formula + self.data = self._build_data(data) + self.algorithm() - # set things up with pre-intervention data - y, X = dmatrices(formula, self.datapre) - self.outcome_variable_name = y.design_info.column_names[0] - self._y_design_info = y.design_info - self._x_design_info = X.design_info - self.labels = X.design_info.column_names - self.pre_y, self.pre_X = np.asarray(y), np.asarray(X) - # process post-intervention data - (new_y, new_x) = build_design_matrices( - [self._y_design_info, self._x_design_info], self.datapost - ) - self.post_X = np.asarray(new_x) - self.post_y = np.asarray(new_y) - # turn into xarray.DataArray's - self.pre_X = xr.DataArray( - self.pre_X, - dims=["obs_ind", "coeffs"], - coords={ - "obs_ind": self.datapre.index, - "coeffs": self.labels, - }, - ) - self.pre_y = xr.DataArray( - self.pre_y, # Keep 2D shape - dims=["obs_ind", "treated_units"], - coords={"obs_ind": self.datapre.index, "treated_units": ["unit_0"]}, - ) - self.post_X = xr.DataArray( - self.post_X, - dims=["obs_ind", "coeffs"], - coords={ - "obs_ind": self.datapost.index, - "coeffs": self.labels, - }, - ) - self.post_y = xr.DataArray( - self.post_y, # Keep 2D shape - dims=["obs_ind", "treated_units"], - coords={"obs_ind": self.datapost.index, "treated_units": ["unit_0"]}, - ) + def algorithm(self) -> None: + """Execute the core interrupted time series algorithm. - # fit the model to the observed (pre-intervention) data + This method implements the standard interrupted time series analysis workflow: + 1. Fit model on pre-intervention data + 2. Score model goodness of fit + 3. Generate predictions for pre and post periods + 4. Calculate causal impact and cumulative impact + """ + # 1. Fit the model to the observed (pre-intervention) data if isinstance(self.model, PyMCModel): COORDS = { "coeffs": self.labels, - "obs_ind": np.arange(self.pre_X.shape[0]), + "obs_ind": np.arange(self.data.X.sel(period="pre").shape[0]), "treated_units": ["unit_0"], } - self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS) + self.model.fit( + X=self.data.X.sel(period="pre"), + y=self.data.y.sel(period="pre"), + coords=COORDS, + ) elif isinstance(self.model, RegressorMixin): # For OLS models, use 1D y data - self.model.fit(X=self.pre_X, y=self.pre_y.isel(treated_units=0)) + self.model.fit( + X=self.data.X.sel(period="pre"), + y=self.data.y.sel(period="pre").isel(treated_units=0), + ) else: raise ValueError("Model type not recognized") - # score the goodness of fit to the pre-intervention data - self.score = self.model.score(X=self.pre_X, y=self.pre_y) - - # get the model predictions of the observed (pre-intervention) data - self.pre_pred = self.model.predict(X=self.pre_X) - - # calculate the counterfactual - self.post_pred = self.model.predict(X=self.post_X) + # 2. Score the goodness of fit to the pre-intervention data + self.score = self.model.score( + X=self.data.X.sel(period="pre"), y=self.data.y.sel(period="pre") + ) - # calculate impact - use appropriate y data format for each model type + # 3. Generate predictions for the full dataset using unified approach + # This creates predictions aligned with our complete time series if isinstance(self.model, PyMCModel): - # PyMC models work with 2D data - self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred) - self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred) - elif isinstance(self.model, RegressorMixin): - # SKL models work with 1D data - self.pre_impact = self.model.calculate_impact( - self.pre_y.isel(treated_units=0), self.pre_pred + # PyMC models expect xarray DataArrays + self.predictions = self.model.predict(X=self.data.X) + # Add period coordinate to predictions - InferenceData handles multiple data arrays + self.predictions = self.predictions.assign_coords( + period=("obs_ind", self.data.period.data) ) - self.post_impact = self.model.calculate_impact( - self.post_y.isel(treated_units=0), self.post_pred + else: + # Sklearn models expect numpy arrays + pred_array = self.model.predict(X=self.data.X.values) + # Create xarray DataArray with period coordinate + self.predictions = xr.DataArray( + pred_array, + dims=["obs_ind"], + coords={ + "obs_ind": self.data.obs_ind, + "period": ("obs_ind", self.data.period.data), + }, + ).set_xindex("period") + + # 4. Calculate impact + if isinstance(self.model, PyMCModel): + # Calculate impact for the entire time series at once + self.impact = self.model.calculate_impact(self.data.y, self.predictions) + # Assign period coordinate to unified impact and set index + self.impact = self.impact.assign_coords( + period=("obs_ind", self.data.period.data) + ).set_xindex("period") + else: + # For sklearn: calculate unified impact as DataArray + observed_values = self.data.y.isel(treated_units=0).values + predicted_values = self.predictions.values + impact_values = observed_values - predicted_values + + self.impact = xr.DataArray( + impact_values, + dims=["obs_ind"], + coords={ + "obs_ind": self.data.obs_ind, + "period": ("obs_ind", self.data.period.data), + }, + ).set_xindex("period") + + # 5. Calculate cumulative impact (only on post-intervention period) + post_impact = self.impact.sel(period="post") + if isinstance(self.model, PyMCModel): + self.post_impact_cumulative = self.model.calculate_cumulative_impact( + post_impact ) - - self.post_impact_cumulative = self.model.calculate_cumulative_impact( - self.post_impact - ) + else: + # For sklearn: simple cumulative sum + self.post_impact_cumulative = post_impact.cumsum() + + def _build_data(self, data: pd.DataFrame) -> xr.Dataset: + """Build the experiment dataset as unified time series with period coordinate.""" + # Build design matrices for the complete dataset directly + y_full, X_full = dmatrices(self.formula, data) + + # Store metadata from the design matrices + self.outcome_variable_name = y_full.design_info.column_names[0] + self._y_design_info = y_full.design_info + self._x_design_info = X_full.design_info + self.labels = X_full.design_info.column_names + + # Create period coordinate based on treatment time + period_coord = xr.where(data.index < self.treatment_time, "pre", "post") + + # Return as a xarray.Dataset + common_coords = { + "obs_ind": data.index, + "period": ("obs_ind", period_coord), + } + + return xr.Dataset( + { + "X": xr.DataArray( + np.asarray(X_full), + dims=["obs_ind", "coeffs"], + coords={**common_coords, "coeffs": self.labels}, + ), + "y": xr.DataArray( + np.asarray(y_full), + dims=["obs_ind", "treated_units"], + coords={**common_coords, "treated_units": ["unit_0"]}, + ), + } + ).set_xindex("period") def input_validation(self, data, treatment_time): """Validate the input data and model formula for correctness""" @@ -208,7 +246,7 @@ def _bayesian_plot( self, round_to=None, **kwargs ) -> tuple[plt.Figure, List[plt.Axes]]: """ - Plot the results + Plot the results using unified predictions with period coordinates. :param round_to: Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers. @@ -216,11 +254,21 @@ def _bayesian_plot( counterfactual_label = "Counterfactual" fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8)) + + # Extract pre/post predictions - InferenceData doesn't support .sel() with period + # but .where() works fine with coordinates + pre_pred = self.predictions["posterior_predictive"].where( + self.predictions["posterior_predictive"].period == "pre", drop=True + ) + post_pred = self.predictions["posterior_predictive"].where( + self.predictions["posterior_predictive"].period == "post", drop=True + ) + # TOP PLOT -------------------------------------------------- # pre-intervention period h_line, h_patch = plot_xY( - self.datapre.index, - self.pre_pred["posterior_predictive"].mu.isel(treated_units=0), + self.data.X.sel(period="pre").obs_ind, + pre_pred.mu.isel(treated_units=0), ax=ax[0], plot_hdi_kwargs={"color": "C0"}, ) @@ -228,10 +276,8 @@ def _bayesian_plot( labels = ["Pre-intervention period"] (h,) = ax[0].plot( - self.datapre.index, - self.pre_y.isel(treated_units=0) - if hasattr(self.pre_y, "isel") - else self.pre_y[:, 0], + self.data.X.sel(period="pre").obs_ind, + self.data.y.sel(period="pre").isel(treated_units=0), "k.", label="Observations", ) @@ -240,8 +286,8 @@ def _bayesian_plot( # post intervention period h_line, h_patch = plot_xY( - self.datapost.index, - self.post_pred["posterior_predictive"].mu.isel(treated_units=0), + self.data.X.sel(period="post").obs_ind, + post_pred.mu.isel(treated_units=0), ax=ax[0], plot_hdi_kwargs={"color": "C1"}, ) @@ -249,24 +295,17 @@ def _bayesian_plot( labels.append(counterfactual_label) ax[0].plot( - self.datapost.index, - self.post_y.isel(treated_units=0) - if hasattr(self.post_y, "isel") - else self.post_y[:, 0], + self.data.X.sel(period="post").obs_ind, + self.data.y.sel(period="post").isel(treated_units=0), "k.", ) - # Shaded causal effect - post_pred_mu = ( - az.extract(self.post_pred, group="posterior_predictive", var_names="mu") - .isel(treated_units=0) - .mean("sample") - ) # Add .mean("sample") to get 1D array + + # Shaded causal effect - use direct calculation + post_pred_mu = post_pred.mu.mean(dim=["chain", "draw"]).isel(treated_units=0) h = ax[0].fill_between( - self.datapost.index, + self.data.X.sel(period="post").obs_ind, y1=post_pred_mu, - y2=self.post_y.isel(treated_units=0) - if hasattr(self.post_y, "isel") - else self.post_y[:, 0], + y2=self.data.y.sel(period="post").isel(treated_units=0), color="C0", alpha=0.25, ) @@ -282,21 +321,23 @@ def _bayesian_plot( # MIDDLE PLOT ----------------------------------------------- plot_xY( - self.datapre.index, - self.pre_impact.isel(treated_units=0), + self.data.X.sel(period="pre").obs_ind, + self.impact.sel(period="pre").isel(treated_units=0), ax=ax[1], plot_hdi_kwargs={"color": "C0"}, ) plot_xY( - self.datapost.index, - self.post_impact.isel(treated_units=0), + self.data.X.sel(period="post").obs_ind, + self.impact.sel(period="post").isel(treated_units=0), ax=ax[1], plot_hdi_kwargs={"color": "C1"}, ) ax[1].axhline(y=0, c="k") ax[1].fill_between( - self.datapost.index, - y1=self.post_impact.mean(["chain", "draw"]).isel(treated_units=0), + self.data.X.sel(period="post").obs_ind, + y1=self.impact.sel(period="post") + .mean(["chain", "draw"]) + .isel(treated_units=0), color="C0", alpha=0.25, label="Causal impact", @@ -306,7 +347,7 @@ def _bayesian_plot( # BOTTOM PLOT ----------------------------------------------- ax[2].set(title="Cumulative Causal Impact") plot_xY( - self.datapost.index, + self.data.X.sel(period="post").obs_ind, self.post_impact_cumulative.isel(treated_units=0), ax=ax[2], plot_hdi_kwargs={"color": "C1"}, @@ -332,7 +373,7 @@ def _bayesian_plot( def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]: """ - Plot the results + Plot the results using unified predictions with period coordinates. :param round_to: Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers. @@ -341,13 +382,43 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes] fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8)) - ax[0].plot(self.datapre.index, self.pre_y, "k.") - ax[0].plot(self.datapost.index, self.post_y, "k.") + # Extract pre/post predictions - handle PyMC vs sklearn differently + if isinstance(self.model, PyMCModel): + # For PyMC models, predictions is InferenceData - use .where() with coordinates + pre_pred = ( + self.predictions["posterior_predictive"] + .where( + self.predictions["posterior_predictive"].period == "pre", drop=True + ) + .mu.mean(dim=["chain", "draw"]) + .isel(treated_units=0) + ) + post_pred = ( + self.predictions["posterior_predictive"] + .where( + self.predictions["posterior_predictive"].period == "post", drop=True + ) + .mu.mean(dim=["chain", "draw"]) + .isel(treated_units=0) + ) + else: + # For sklearn models, predictions is DataArray - use .sel() with indexed coordinates + pre_pred = self.predictions.sel(period="pre") + post_pred = self.predictions.sel(period="post") - ax[0].plot(self.datapre.index, self.pre_pred, c="k", label="model fit") ax[0].plot( - self.datapost.index, - self.post_pred, + self.data.X.sel(period="pre").obs_ind, self.data.y.sel(period="pre"), "k." + ) + ax[0].plot( + self.data.X.sel(period="post").obs_ind, self.data.y.sel(period="post"), "k." + ) + + ax[0].plot( + self.data.X.sel(period="pre").obs_ind, pre_pred, c="k", label="model fit" + ) + ax[0].plot( + self.data.X.sel(period="post").obs_ind, + post_pred, label=counterfactual_label, ls=":", c="k", @@ -356,32 +427,36 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes] title=f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}" ) - ax[1].plot(self.datapre.index, self.pre_impact, "k.") ax[1].plot( - self.datapost.index, - self.post_impact, + self.data.X.sel(period="pre").obs_ind, self.impact.sel(period="pre"), "k." + ) + ax[1].plot( + self.data.X.sel(period="post").obs_ind, + self.impact.sel(period="post"), "k.", label=counterfactual_label, ) ax[1].axhline(y=0, c="k") ax[1].set(title="Causal Impact") - ax[2].plot(self.datapost.index, self.post_impact_cumulative, c="k") + ax[2].plot( + self.data.X.sel(period="post").obs_ind, self.post_impact_cumulative, c="k" + ) ax[2].axhline(y=0, c="k") ax[2].set(title="Cumulative Causal Impact") # Shaded causal effect ax[0].fill_between( - self.datapost.index, - y1=np.squeeze(self.post_pred), - y2=np.squeeze(self.post_y), + self.data.X.sel(period="post").obs_ind, + y1=np.squeeze(post_pred), + y2=np.squeeze(self.data.y.sel(period="post")), color="C0", alpha=0.25, label="Causal impact", ) ax[1].fill_between( - self.datapost.index, - y1=np.squeeze(self.post_impact), + self.data.X.sel(period="post").obs_ind, + y1=np.squeeze(self.impact.sel(period="post")), color="C0", alpha=0.25, label="Causal impact", @@ -409,77 +484,60 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame: :param hdi_prob: Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function. """ - if isinstance(self.model, PyMCModel): - hdi_pct = int(round(hdi_prob * 100)) + if not isinstance(self.model, PyMCModel): + raise ValueError("Unsupported model type") - pred_lower_col = f"pred_hdi_lower_{hdi_pct}" - pred_upper_col = f"pred_hdi_upper_{hdi_pct}" - impact_lower_col = f"impact_hdi_lower_{hdi_pct}" - impact_upper_col = f"impact_hdi_upper_{hdi_pct}" + hdi_pct = int(round(hdi_prob * 100)) - pre_data = self.datapre.copy() - post_data = self.datapost.copy() + # Start with the outcome data from our unified dataset + plot_data = pd.DataFrame( + {self.outcome_variable_name: self.data.y.isel(treated_units=0).values}, + index=self.data.y.obs_ind.values, + ) - pre_data["prediction"] = ( - az.extract(self.pre_pred, group="posterior_predictive", var_names="mu") - .mean("sample") - .isel(treated_units=0) - .values - ) - post_data["prediction"] = ( - az.extract(self.post_pred, group="posterior_predictive", var_names="mu") - .mean("sample") - .isel(treated_units=0) - .values - ) - hdi_pre_pred = get_hdi_to_df( - self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob - ) - hdi_post_pred = get_hdi_to_df( - self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob - ) - # Select the single unit from the MultiIndex results - pre_data[[pred_lower_col, pred_upper_col]] = hdi_pre_pred.xs( - "unit_0", level="treated_units" - ).set_index(pre_data.index) - post_data[[pred_lower_col, pred_upper_col]] = hdi_post_pred.xs( - "unit_0", level="treated_units" - ).set_index(post_data.index) - - pre_data["impact"] = ( - self.pre_impact.mean(dim=["chain", "draw"]).isel(treated_units=0).values - ) - post_data["impact"] = ( - self.post_impact.mean(dim=["chain", "draw"]) - .isel(treated_units=0) - .values - ) - hdi_pre_impact = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob) - hdi_post_impact = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob) - # Select the single unit from the MultiIndex results - pre_data[[impact_lower_col, impact_upper_col]] = hdi_pre_impact.xs( - "unit_0", level="treated_units" - ).set_index(pre_data.index) - post_data[[impact_lower_col, impact_upper_col]] = hdi_post_impact.xs( - "unit_0", level="treated_units" - ).set_index(post_data.index) - - self.plot_data = pd.concat([pre_data, post_data]) - - return self.plot_data - else: - raise ValueError("Unsupported model type") + # Extract predictions directly from unified predictions object + pred_mu = self.predictions["posterior_predictive"].mu.isel(treated_units=0) + plot_data["prediction"] = pred_mu.mean(dim=["chain", "draw"]).values + + # Extract impact directly from unified impact - no more calculation needed! + plot_data["impact"] = ( + self.impact.mean(dim=["chain", "draw"]).isel(treated_units=0).values + ) + + # Calculate HDI bounds directly using arviz + import arviz as az + + pred_hdi = az.hdi(pred_mu, hdi_prob=hdi_prob) + impact_hdi = az.hdi(self.impact.isel(treated_units=0), hdi_prob=hdi_prob) + + # Extract HDI bounds from xarray Dataset results + pred_var_name = list(pred_hdi.data_vars.keys())[0] + impact_var_name = list(impact_hdi.data_vars.keys())[0] + + pred_hdi_data = pred_hdi[pred_var_name] + impact_hdi_data = impact_hdi[impact_var_name] + + plot_data[f"pred_hdi_lower_{hdi_pct}"] = pred_hdi_data.isel(hdi=0).values + plot_data[f"pred_hdi_upper_{hdi_pct}"] = pred_hdi_data.isel(hdi=1).values + plot_data[f"impact_hdi_lower_{hdi_pct}"] = impact_hdi_data.isel(hdi=0).values + plot_data[f"impact_hdi_upper_{hdi_pct}"] = impact_hdi_data.isel(hdi=1).values + + self.plot_data = plot_data + return plot_data def get_plot_data_ols(self) -> pd.DataFrame: """ Recover the data of the experiment along with the prediction and causal impact information. """ - pre_data = self.datapre.copy() - post_data = self.datapost.copy() - pre_data["prediction"] = self.pre_pred - post_data["prediction"] = self.post_pred - pre_data["impact"] = self.pre_impact - post_data["impact"] = self.post_impact - self.plot_data = pd.concat([pre_data, post_data]) + # Create unified DataFrame from our xarray data + plot_data = pd.DataFrame( + {self.outcome_variable_name: self.data.y.isel(treated_units=0).values}, + index=self.data.y.obs_ind.values, + ) + + # Extract directly from unified data structures - ultimate simplification! + plot_data["prediction"] = self.predictions.values + plot_data["impact"] = self.impact.values + self.plot_data = plot_data return self.plot_data diff --git a/docs/source/_static/classes.png b/docs/source/_static/classes.png index 2dda20e6..ad3834f8 100644 Binary files a/docs/source/_static/classes.png and b/docs/source/_static/classes.png differ diff --git a/docs/source/_static/packages.png b/docs/source/_static/packages.png index 5a537cd0..65e70f8a 100644 Binary files a/docs/source/_static/packages.png and b/docs/source/_static/packages.png differ