Implement censored log-probabilities via the Clip Op
#22
Conversation
Codecov Report
```diff
@@            Coverage Diff             @@
##             main      #22      +/-   ##
==========================================
- Coverage   94.92%   94.84%   -0.08%
==========================================
  Files           9        8       -1
  Lines        1260     1106     -154
  Branches      164      133      -31
==========================================
- Hits         1196     1049     -147
+ Misses         31       27       -4
+ Partials       33       30       -3
```

Continue to review the full report at Codecov.
To make broadcasting work, you should be able to use something like …
That's a very interesting idea! Can it be combined with other indices (e.g. …)? Regardless, don't make this PR conditional on such extensions. Let's get …
Tip: if you assign names to your test … Also, don't forget about test values! They will cause errors to arise during graph construction (i.e. where the symbolic objects themselves are defined). That combined with …
Still couldn't fix my …

```python
lb_rv = at.random.uniform(0, 1, name="lb_rv")
x_rv = at.random.normal(0, 2, name="x_rv")
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
cens_x_rv.name = "cens_x_rv"

lb = lb_rv.type()
lb.name = "lb"
cens_x = cens_x_rv.type()
cens_x.name = "cens_x"

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
assert_no_rvs(logp)
```

These are the fgraph printouts before and after the optimization phase in … That `lower_bound`:

```python
lower_bound
# uniform_rv.out
lower_bound in rv_map_feature.rv_values
# False
lower_bound.owner.tag
# scratchpad{'imported_by': ['local_dimshuffle_rv_lift']}
```

Also, that opt seems to not propagate the variable name:

```python
lower_bound.name
# None
```
If I set the flag to "warn", even a simple uniform logp raises a lot of "Cannot compute test value..." warnings, one for every node in the logp. Is this something we need to address?

```python
@aesara.config.change_flags(compute_test_value='warn')
def test_compute_test_value():
    x_rv = at.random.uniform(-1, 1)
    x = x_rv.type()
    logp = joint_logprob(x_rv, {x_rv: x})
```

```
=============================== warnings summary ===============================
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{ge,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{-1}) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{le,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{1}) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{ge,no_inplace}.0) of Op Elemwise{and_,no_inplace}(Elemwise{ge,no_inplace}.0, Elemwise{le,no_inplace}.0) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{second,no_inplace}(<TensorType(float64, scalar)>, Elemwise{neg,no_inplace}.0) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{and_,no_inplace}.0) of Op Elemwise{switch,no_inplace}(Elemwise{and_,no_inplace}.0, Elemwise{second,no_inplace}.0, TensorConstant{-inf}) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{second,no_inplace}(<TensorType(float64, scalar)>, Elemwise{neg,no_inplace}.0) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{le,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{1}) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{ge,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{-1}) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{ge,no_inplace}.0) of Op Elemwise{and_,no_inplace}(Elemwise{ge,no_inplace}.0, Elemwise{le,no_inplace}.0) missing default value
compute_test_value(node)
tests/test_truncation.py::test_compute_test_value
/home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{and_,no_inplace}.0) of Op Elemwise{switch,no_inplace}(Elemwise{and_,no_inplace}.0, Elemwise{second,no_inplace}.0, TensorConstant{-inf}) missing default value
compute_test_value(node)
-- Docs: https://docs.pytest.org/en/stable/warnings.html
======================== 1 passed, 10 warnings in 0.97s ========================
Process finished with exit code 0
PASSED [100%]
```
It looks like you need to set a test value for …
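For example, a minimal sketch of Aesara's test-value mechanism (variable names reused from the snippet above):

```python
import aesara.tensor as at

x_rv = at.random.uniform(-1, 1)
x = x_rv.type()
# Attaching a test value lets compute_test_value propagate through the
# logprob graph instead of warning about missing defaults
x.tag.test_value = 0.0
```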
I added tests for the logcdf methods and checked whether the new opts work with … The only thing missing is the broadcasting / …
aeppl/truncation.py (Outdated)

```python
def censor_logprob(op, value, *inputs, name=None, **kwargs):
    *rv_params, lower_bound, upper_bound = inputs
    logprob = _logprob(op.base_op, value, *rv_params, name=name, **kwargs)
```
You could alternatively leave the original RandomVariable in the graph and make it an input of a CensoredRV, then this log-probability term would be taken care of automatically. The extra log-probability terms introduced by censoring can then be introduced by this dispatch function—like they currently are.
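For concreteness, this suggestion could look something like the following minimal sketch (hypothetical, not the PR's actual implementation) of a wrapper Op that keeps the base RV in the graph:

```python
import aesara.tensor as at
import numpy as np
from aesara.graph.basic import Apply
from aesara.graph.op import Op


class CensoredRV(Op):
    """Hypothetical Op representing `clip(base_rv, lower, upper)` while
    keeping `base_rv` itself as a graph input."""

    def make_node(self, base_rv, lower, upper):
        lower = at.as_tensor_variable(lower)
        upper = at.as_tensor_variable(upper)
        # The output has the same type as the base draw (scalar bounds assumed)
        return Apply(self, [base_rv, lower, upper], [base_rv.type()])

    def perform(self, node, inputs, outputs):
        base, lower, upper = inputs
        outputs[0][0] = np.asarray(
            np.clip(base, lower, upper), dtype=node.outputs[0].dtype
        )
```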
In that case I will have to direct the value var to the original variable in the rewrite phase (while keeping it pointing to the censor_logprob of course). I can give it a try, but I am not sure it will result in simpler code.
After #19 is merged, you could also use .tag.ignore_logprob. With that, we can write a general RandomVariable wrapper/view Op that can be used as the base class for something like CensoredRV.
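A tiny hedged example of that, reusing the hypothetical `CensoredRV` from the sketch above:

```python
# Assuming #19's tag is available: no logprob term is generated for base_rv;
# the censoring dispatch contributes its own terms instead.
base_rv = at.random.normal(0, 2)
base_rv.tag.ignore_logprob = True
cens_rv = CensoredRV()(base_rv, -1.0, 1.0)
```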
Ah, yeah, that would definitely cause such a problem; however, …
aeppl/truncation.py (Outdated)

```python
logccdf = at.log(1 - at.exp(logcdf))
# For right censored discrete RVs, we need to add an extra term
# corresponding to the pmf at the upper bound
if op.base_op.dtype == "int64":
```
Should check against a collection of "discrete_dtypes".
Alternatively, could we add an attribute to RandomVariable to distinguish between discrete, continuous, and mixed variables? This could come handy in the future.
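For instance (a sketch, assuming Aesara exposes its `discrete_dtypes` collection in `aesara.tensor.type`):

```python
from aesara.tensor.type import discrete_dtypes

# Covers all integer/bool dtypes instead of only "int64"
if op.base_op.dtype in discrete_dtypes:
    ...  # add the pmf term at the upper bound
```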
We can create a metatype that does these kinds of checks. You can also use TensorType.numpy_dtype.kind as an easy way to check that something is a float, int, or bool.
In this case, I'm not sure such a check should be performed on the Op itself. Remember, RandomVariable Ops can have "indeterminate" dtypes: i.e. the dtype is determined within RandomVariable.make_node and is not a fixed class-level value. You might need to get the dtype from the RandomVariable's output instead.
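A small hedged example of the `numpy_dtype.kind` check, reading the dtype off the RV's output rather than the Op class:

```python
import aesara.tensor as at

x_rv = at.random.poisson(3.0)
# The dtype is decided in RandomVariable.make_node, so read it from the output
kind = x_rv.type.numpy_dtype.kind
print(kind)                     # 'i'
print(kind in ("i", "u", "b"))  # True -> discrete
```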
Thinking a bit more about the generating models in the broadcasted test cases I wrote, I think it only makes sense to create a logprob graph when the base random variable has the same shape as the clipped output.

```python
# Scalar base, vector random bound - SUCCEEDS currently
lb_rv = at.random.normal(0, 1, size=2)
x_rv = at.random.normal(0, 2)
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

# Scalar base, vector bound - FAILS currently
lb_rv = at.random.uniform(0, 1, name="lb_rv")
x_rv = at.random.normal(0, 2, name="x_rv")
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

# Vector base RV, vector bound - FAILS currently
lb_rv = at.random.normal(0, 1, name="lb_rv")
x_rv = at.random.normal(0, 2, size=2, name="x_rv")
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
```

The logprob extraction for the first two graphs should fail because the two censored values in … Is there a way to distinguish between these during the rewrite phase?
Perhaps we could check the …
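The suggestion above is cut off, so the following is only a guess at the flavor of check: a hedged sketch comparing the static broadcastable patterns of the base RV and the clipped output inside the rewrite, bailing out when they disagree.

```python
# Inside a hypothetical node rewriter matched on Clip nodes:
base_rv, lower_bound, upper_bound = node.inputs
(clipped_var,) = node.outputs
if base_rv.type.broadcastable != clipped_var.type.broadcastable:
    return None  # shapes differ, so this rewrite does not apply
```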
Is the broadcasting still the blocking issue/change here?
No, I was trying an alternative that did not involve subclassing from RandomVariable. I'll try to get this back on track soon.
Aside from the use of tag.ignore_logprob, this looks great.
Let's find a way to avoid using that feature, especially since nothing should be depending on it now and it's slated to be removed entirely.
Otherwise, if you want, submit the logcdf additions as a separate PR and we can push those through sooner.
aeppl/joint_logprob.py (Outdated)

```python
# Filter out missing terms of variables with ignore_logprob
value_rvs = {v: k for k, v in updated_rv_values.items()}
for missing in tuple(missing_value_terms):
    rv_of_missing = value_rvs.get(missing, None)
    if rv_of_missing and getattr(rv_of_missing.tag, "ignore_logprob", False):
        missing_value_terms.remove(missing)
```
If/when we abandon ignore_logprob this section can be removed. I had to add it for backwards compatibility with some tests in test_joint_logprob.
It looks like you added two tests in a different commit that explicitly require this functionality: test_fail_multiple_censored_single_base and test_fail_base_and_censored_have_values. If you remove those, nothing will depend on this functionality and the commit can be removed.
This needs to be done before merging, if only because the changes are unrelated to censoring.
When I said something could be removed later, I meant the specific ignore_logprob flag logic, not the raising of a RuntimeError if a variable is missing.
While developing the censored variables it would often fail silently and just return a graph with aesara clips unchanged and/or fewer terms than requested. This seems like a good way to catch such failures.
Is there a reason why you don't want this type of check at the end of factorized_joint_logprob?
It's trivial to manually do the same check in those new tests I added, but this explicit check might be valuable enough to have as a default.
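A hedged sketch of what such a default check might look like; `rv_values` and `logprob_vars` are assumed names for the requested mapping and the returned terms:

```python
# Hypothetical final check: every requested value variable should have a term
missing_value_terms = set(rv_values.values()) - set(logprob_vars)
if missing_value_terms:
    raise RuntimeError(
        f"The logprob terms for {missing_value_terms} could not be derived"
    )
```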
> When I said something could be removed later, I meant the specific ignore_logprob flag logic, not the raising of a RuntimeError if a variable is missing.
> While developing the censored variables it would often fail silently and just return a graph with aesara clips unchanged and/or fewer terms than requested. This seems like a good way to catch such failures.
From a simple development and design perspective, it sounds like you're addressing a testing-specific issue within a feature implementation, and that's generally not good.
Otherwise, if something is failing silently the first question is "What's failing?". Is it the factorized_joint_logprob loop? If not, the failure should be addressed closer to where its primary logic/code resides, and that doesn't appear to be here.
> Is there a reason why you don't want this type of check at the end of factorized_joint_logprob?
> It's trivial to manually do the same check in those new tests I added, but this explicit check might be valuable enough to have as a default.
The reason why I don't want these kinds of unrelated changes is that their inclusion makes a PR contingent on additional review work and discussions.
It takes extra time and effort to go through logic like this and determine its relevance, risk, etc. These are things that need to be done within issues and/or at the outset of a PR (e.g. the premise/description of a PR) in order to avoid delaying the inclusion of any agreed upon changes.
> From a simple development and design perspective, it sounds like you're addressing a testing-specific issue within a feature implementation, and that's generally not good.
What happened was that the test revealed that such a call to factorized_joint_logprob would return with a missing logp term and zero complaints, so I decided to add an explicit check there.
It was not for the sake of the test, as that had already been solved and could be tested explicitly inside the test itself.
It was meant for further development, when we introduce rewrites for Ops that are otherwise valid in logp graphs. It's also a conceptually obvious check for me: a user requested a dictionary of rv_values and we make sure we are returning a dictionary with an item for each original pair.
I don't mind splitting this into another PR
I removed the commit, but I am not very happy with the fact that this does not raise an error inside factorized_joint_logprob:

```python
def test_fail_base_and_censored_have_values():
    """Test failure when both base_rv and clipped_rv are given value vars"""
    x_rv = at.random.normal(0, 1)
    cens_x_rv = at.clip(x_rv, x_rv, 1)
    cens_x_rv.name = "cens_x"

    x_vv = x_rv.clone()
    cens_x_vv = cens_x_rv.clone()
    logp_terms = factorized_joint_logprob({cens_x_rv: cens_x_vv, x_rv: x_vv})
    assert cens_x_vv not in logp_terms
```
Again, why should the conditions be checked and the error be raised in factorized_joint_logprob specifically?
The two terms involved are very specific to the censored variable logic, and it looks like the error could've been initiated in find_censored_rvs, i.e. where all the relevant terms are identified and used. This approach could also short-circuit all the unnecessary downstream logic, no?
You already have a warning there to that effect, so what do we gain by having an exception in factorized_joint_logprob?
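For concreteness, a hedged sketch of what raising inside find_censored_rvs could look like; its real signature and internals are not shown in this thread, so everything below is an assumption:

```python
# Hypothetical short-circuit inside the censoring-specific logic:
def find_censored_rvs(fgraph, rv_values):
    ...
    # Both the base RV and its clipped version were given value variables,
    # so the censoring rewrite cannot apply; fail early instead of silently
    # skipping the term downstream.
    if base_rv in rv_values and clipped_rv in rv_values:
        raise ValueError(
            f"Both {base_rv} and its censored version {clipped_rv} "
            "were assigned value variables"
        )
```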
I don't think it should be the rewrite's responsibility to raise a failure. It may be that there is another rewrite (e.g., added by users of the library) that can handle the conversion.
The bigger problem is that nothing happens if you ask for a graph that we don't know how to handle. That is not specific to censored RVs; we just haven't tested it. For instance, this snippet does not complain at all:

```python
import aesara.tensor as at
import aeppl

x_rv = at.random.normal(name='x')
y_rv = at.cos(x_rv)

x_vv = x_rv.clone()
y_vv = y_rv.clone()

logprob_dict = aeppl.factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
logprob_dict
# {x: x_logprob}
```

This snippet would be more realistic about what a user may try, but it is now failing for a different reason (#87):

```python
logprob_dict = aeppl.factorized_joint_logprob({y_rv: y_vv})
```
> The bigger problem is that nothing happens if you ask for a graph that we don't know how to handle. That is not specific to censored RVs; we just haven't tested it. For instance, this snippet does not complain at all:
If our rewrites don't know how to handle something, that's not necessarily a problem. As a matter of fact, we expect that they won't know how to handle more things than they do.
The assumption underlying your statements and example seems to be that you know what should be done relative to specific rewrites, and this is what makes it reasonable to handle rewrite-relevant errors in the rewrite logic.
In other words, if you know an error/warning should be raised because a value variable specification is redundant, you only really know that because you also know that there's a specific rewrite that determines which value variables are and aren't relevant.
Otherwise, a generic warning for "unused" variable/value mappings is simply an interface choice that might help inform people of issues elsewhere and/or bad assumptions (e.g. that the resulting graph will depend on certain terms), but that's all.
All this relates directly to #85.
aeppl/joint_logprob.py (Outdated)
You can use the same tests as before, but change the pytest.raises to look for the warnings instead.
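For example, a sketch of that change (the warning class is an assumption):

```python
import aesara.tensor as at
import pytest

from aeppl import factorized_joint_logprob


def test_fail_base_and_censored_have_values():
    x_rv = at.random.normal(0, 1)
    cens_x_rv = at.clip(x_rv, x_rv, 1)
    x_vv = x_rv.clone()
    cens_x_vv = cens_x_rv.clone()
    # Expect a warning rather than an exception
    with pytest.warns(UserWarning):
        factorized_joint_logprob({cens_x_rv: cens_x_vv, x_rv: x_vv})
```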
Force-pushed 958ec44 to 248d9f5
Force-pushed 248d9f5 to 4830635
This PR implements logprob for censored (clipped) RVs.
I placed the new methods and tests inside truncation.py, expecting this file will also contain the methods for truncated RVs in the future.

Some things are still not working well / missing:

- Canonicalize set_subtensors to clip, `x[x>ub] = ub -> clip(x, x, ub)`. Will do in another PR (a sketch follows this list).
- `logcdf` methods
- Compute test values for new nodes. Seems to not be necessary; tests pass with `compute_test_value="raise"`.
- Explore if CensoredRVs should be created even when they don't have a direct value variable (e.g., so that they can work as input to other derived RVs). Postponed.
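A hedged illustration of the set_subtensor-to-clip canonicalization mentioned in the first item:

```python
import aesara.tensor as at

x = at.vector("x")
ub = 1.0

# The graph built by `x[x > ub] = ub` ...
y = at.set_subtensor(x[x > ub], ub)
# ... is equivalent to clipping from above only (the lower bound is x itself)
y_clip = at.clip(x, x, ub)
```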