-
Notifications
You must be signed in to change notification settings - Fork 248
Sampling with likelihoods #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
mhavasi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, thanks for the PR. Can you also add a test that compares the output of the forward compute_likelihood and the backward compute_likelihood to make sure they match?
1172ef7 to
0e04528
Compare
| torch.allclose(forward_log_likelihood, backward_log_likelihood, atol=1e-2), | ||
| ) | ||
|
|
||
| def test_forward_backward_likelihoods(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, thanks for the PR. Can you also add a test that compares the output of the forward compute_likelihood and the backward compute_likelihood to make sure they match?
Is this sufficient? It computes the forward likelihood, while generating samples. Then using those samples to compute the backward likelihood and compare them.
The odesolver.compute_likelihood does not allow a forward computation, as it also would't make sense without x1 samples?!
Maybe I am misunderstanding sth. here ...
|
Hey @timonpalm , thanks for the reply. Can you rebase your branch to main so we can run the unit tests? We just added the option to run unit tests on external PRs recently. |
1323874 to
5d778c9
Compare
5d778c9 to
4404acb
Compare
Issue #62
I expanded the
ODESolver.sample()method to compute likelihoods along the way when sampling, alleviating the two step integration.You just have to pass the
log_p0parameter to the function.