Skip to content

Conversation

markusschmaus
Copy link
Contributor

@markusschmaus markusschmaus commented Aug 31, 2022

What is this PR about?
For VI the mean and std of approximations is currently only available as an unstructured flat Aesara Variable. This leads to frequent questions on how to extract these properties from the posterior. This PR creates two new properties which evaluate the Aesara Variables and transforms them into a structured xarray Dataset using the available coords.

See also:
https://discourse.pymc.io/t/quality-of-life-improvements-to-advi/10254

Checklist

Major / Breaking Changes

  • None

Bugfixes / New features

  • new feature: get mean and std as xarray data set

Docs / Maintenance

  • Included doc strings for existing mean, std, and cov properties

@ricardoV94 ricardoV94 added the VI Variational Inference label Aug 31, 2022
@markusschmaus
Copy link
Contributor Author

The error message of "Read the Docs build" don't look like they have anything to do with this PR. Is there something wrong with the doc strings?

@canyon289
Copy link
Member

@markusschmaus Thanks for your contribution! For the read the docs error yes please ignore that, sorry for the false alarm.

For your code submission, I'll review it now at a "code level" but will defer to my more VI colleagues here for the math and user questions. Which brings me to my next question, right now the PR is marked draft, did you want a review now or were you still planning on working this some more?

@codecov
Copy link

codecov bot commented Sep 1, 2022

Codecov Report

Merging #6086 (3c6af3a) into main (0b191ad) will increase coverage by 0.06%.
The diff coverage is 94.44%.

❗ Current head 3c6af3a differs from pull request most recent head 671cb9c. Consider uploading reports for the commit 671cb9c to get more accurate results

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6086      +/-   ##
==========================================
+ Coverage   89.54%   89.60%   +0.06%     
==========================================
  Files          72       72              
  Lines       12929    12947      +18     
==========================================
+ Hits        11577    11601      +24     
+ Misses       1352     1346       -6     
Impacted Files Coverage Δ
pymc/variational/opvi.py 87.24% <94.44%> (+0.20%) ⬆️
pymc/step_methods/hmc/base_hmc.py 90.55% <0.00%> (+0.78%) ⬆️
pymc/variational/approximations.py 90.14% <0.00%> (+2.81%) ⬆️

@markusschmaus
Copy link
Contributor Author

Thanks, I left it at draft level as I was still trying to debug the docs issue. I'm polishing up a few things and then I will change the status.

@markusschmaus markusschmaus marked this pull request as ready for review September 1, 2022 13:35
@fonnesbeck
Copy link
Member

Should an InferenceData object be used to store output rather than a raw xarray? Just thinking in terms of consistency with the rest of PyMC, and a natural place to put samples from the approximate posterior.

@markusschmaus
Copy link
Contributor Author

I thought about using an InferenceData object, but it doesn't really fit, as it is meant for storing samples and not the parameters of the posterior approximation.

It's already possible to get a true InferenceData object by calling sample on the approximation, including samples from the approximate posterior.

@fonnesbeck
Copy link
Member

I was thinking something along the lines of the "sample_stats" or "observed_data" entries in the InferenceData object which are not samples, but related variables from the model.

@markusschmaus
Copy link
Contributor Author

Let's go through the options:

https://python.arviz.org/en/latest/schema/schema.html#schema

  • posterior, posterior_predictive, prior, prior_predictive, predictions: No, since all of these are supposed to be samples
  • sample_stats_prior: No, the approximation isn't related to any prior samples
  • log_likelihood: No
  • observed_data: No, these is supposed to be data the posterior is conditional on
  • predictions_constant_data: No, since the approximation has nothing to do with any predictions
  • constant_data: No, since the approximations are not data included in the model
  • sample_stats: Probably not, since this is supposed to relate to the samples in the posterior group, which we don't have

So a straight forward xarray looks best to me.

@ricardoV94
Copy link
Member

ricardoV94 commented Sep 2, 2022

Can we just get a dictionary? Nevermind you are passing dims around as well.

@markusschmaus
Copy link
Contributor Author

Yeah, I find the coords just too useful not to use them. I considered returning a dict of numpy arrays when no coords are given, but this would result in an inconsistent return type, which I always find a pain to deal with when a library does this.

@fonnesbeck
Copy link
Member

Was hoping you'd be able to add custom attributes to InferenceData but I guess you can't. 😢

@ricardoV94
Copy link
Member

Was hoping you'd be able to add custom attributes to InferenceData but I guess you can't. 😢

What do you mean? You can last time I checked

@markusschmaus
Copy link
Contributor Author

markusschmaus commented Sep 12, 2022

The spec is only enforced with a warning:

for key in kwargs:
    if key not in SUPPORTED_GROUPS_ALL:
        key_list.append(key)
        warnings.warn(
            f"{key} group is not defined in the InferenceData scheme", UserWarning
        )

https://github.com/arviz-devs/arviz/blob/2a7bf0f2cb26bfe273e800406249547507d4fdd4/arviz/data/inference_data.py#L147

So I could ignore this warning and wrap the xarrays in an Inference data object which doesn't conform to the spec, though I still don't see any benefits for doing so. It wouldn't make sense to start sampling just to be able to fill any of the other fields, since the whole point of this PR is to give the user the ability to extract mean and std without sampling.

If it's about the syntax and you prefer approx.params_data["mean"] to approx.mean_data, it would be an option to wrap them them in a dictionary.

@ricardoV94
Copy link
Member

I meant you can add attributes to one of the "allowed" groups. In this case I was thinking you could add it to the posterior group.

@markusschmaus
Copy link
Contributor Author

The point of the PR is to avoid sampling just for extracting the mean and std, so there are no posterior samples and no posterior group. I could create an empty group, but I still don't see the benefit.

@ricardoV94
Copy link
Member

The point of the PR is to avoid sampling just for extracting the mean and std, so there are no posterior samples and no posterior group. I could create an empty group, but I still don't see the benefit.

Fair enough

@markusschmaus
Copy link
Contributor Author

@fonnesbeck @ricardoV94 What's the future of this PR?

@OriolAbril
Copy link
Member

Can we just get a dictionary? Nevermind you are passing dims around as well.

General PSA, datasets have a dict-like interface so to most ends you can simply ignore the fact you have a dataset and treat it as a dictionary.

@ghost
Copy link

ghost commented Nov 29, 2022

Let's rebase and merge this.

@fonnesbeck
Copy link
Member

@markusschmaus are you good with us merging this as-is, or are there any additional changes you'd like to make? Sorry its taken so long--kind of fell off the radar!

@fonnesbeck
Copy link
Member

Merge conflicts fixed by #6387

@fonnesbeck fonnesbeck closed this Dec 12, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

VI Variational Inference

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants