Implement chained transforms #26
Conversation
Codecov Report
```diff
@@            Coverage Diff             @@
##             main      #26      +/-   ##
==========================================
+ Coverage   94.36%   94.74%   +0.38%
==========================================
  Files          11       11
  Lines        1614     1733     +119
  Branches      230      251      +21
==========================================
+ Hits         1523     1642     +119
  Misses         51       51
  Partials       40       40
```
Continue to review full report at Codecov.
Can you add a somewhat high-level description of the approach and design of this feature?
I'm trying to reconcile this with the convolution and scaling rewrites we've been planning, and I want to make sure that we can do both without conflict and/or complications, or—better yet—that we can focus on only one approach/framework.
To be clear, I'm talking about RandomVariable-specific rewrites like sums of normals and location/scale properties. They're limited in that they're not as general as a change-of-variables approach, but they have the big advantage that they result in known distributions and they're rather simple to implement.
It seems like we can always add such rewrites, but that they would need to be performed before these transforms. If that's all there is to it, then I think we're good; if not, then we don't want to dig ourselves into a hole.
Also, we might want to start focusing on those simple rewrites because of their ease and impact, so don't feel rushed to push this through before considering any of those.
The new local rewrite basically checks if there is a single chain of operations that links an "orphan" value variable to an unused base RV, e.g.:

```python
y_rv = at.log(at.random.beta(1, 1) * scale)
# becomes
y_rv = ChainTransformRV(beta(1, 1), [ScaleTransform(scale), LogTransform()])
```
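For context, here is a minimal sketch (not code from this PR) of how such a graph would be exercised end to end, using the `joint_logprob` call pattern that appears in the examples later in this thread; the rewrite itself runs internally when the log-probability is derived:

```python
import aesara.tensor as at
from aeppl import joint_logprob  # call pattern as used later in this thread

scale = at.scalar("scale")

# A chain of measurable operations on top of a single base RV.
y_rv = at.log(at.random.beta(1, 1) * scale)
y = y_rv.clone()
y.name = "y"

# The chained-transform rewrite should recognize the scale + log chain and
# return the base beta log-density plus the appropriate Jacobian terms.
logp = joint_logprob(y_rv, {y_rv: y})
```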
The idea would be to apply the loc-scale reparametrization whenever possible, as that results in more succinct graphs. Otherwise, this rewrite is more general in that it applies to RVs that cannot be reparametrized that way, and it is not limited to loc/scale transforms. The reason why I did not limit it to this scenario is that once we have the …

By the way, the convolutions are not addressed at all in the local rewrite (and I don't think they should be). I actually have an explicit check to make sure the graph does not correspond to a potential convolution.
Yes, this should not interfere with that at all. Actually, those rewrites can equally easily be applied before or after this one, as the following shows:

```python
y_rv = ChainTransformRV(Normal(0, 1), [ScaleTransform(scale), LocTransform(loc)])
# could be rewritten later (or before) as
y_rv = Normal(0 + loc, 1 * scale)
```

In that sense, we could think of the smarter rewrites as a specialization of this more general one.
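As a sanity check on the equivalence being claimed here (this is not code from the PR, just the standard change-of-variables identity), both forms assign the same density:

```python
import numpy as np
from scipy import stats

loc, scale, y = 1.5, 2.0, 0.7

# Generic chain: base logpdf at the inverse-transformed value plus the
# log-Jacobian of the inverse (d/dy of (y - loc) / scale is 1 / scale).
chain_logp = stats.norm(0, 1).logpdf((y - loc) / scale) - np.log(scale)

# Specialized rewrite: a Normal with shifted and scaled parameters.
direct_logp = stats.norm(loc, scale).logpdf(y)

assert np.isclose(chain_logp, direct_logp)
```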
Okay, so my last roadblock is this edge case:

```python
# loc = at.scalar("loc")  # works
# loc = at.vector("loc")  # fails
loc = at.TensorType("floatX", (True,))("loc")  # works
y_rv = loc + at.random.normal(0, 1, size=1)
y = y_rv.clone()
logp = joint_logprob(y_rv, {y_rv: y})
```

That example only works when loc has a broadcastable type; the plain vector case fails with:

```
TypeError: Cannot convert Type TensorType(float64, vector) (of Variable y) into Type TensorType(float64, (True,)). You can try to manually convert y into a TensorType(float64, (True,)).
```

Regardless of the broadcast question in #51, it seems the (1D) vector case should work. After all, this is valid:

```python
loc = at.vector("loc")
base_rv = at.random.normal(loc, 1, size=1)
base_rv.eval({loc: [1]})
```

The rewrite fails similarly with row or column sizes, such as:

```python
loc = at.matrix("loc")
y_rv = loc + at.random.normal(0, 1, size=(1, 3))
y = y_rv.clone()
logp = joint_logprob(y_rv, {y_rv: y})
```

Should this be handled by the new rewrite?
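By analogy with the 1-D workaround above, giving loc an explicitly row-broadcastable type should let the derived RV and the value variable agree on their broadcast pattern; this is a hedged sketch, not something tested in the PR:

```python
import aesara.tensor as at
from aeppl import joint_logprob

# loc's first dimension is declared broadcastable to match size=(1, 3).
loc = at.TensorType("floatX", (True, False))("loc")
y_rv = loc + at.random.normal(0, 1, size=(1, 3))
y = y_rv.clone()
logp = joint_logprob(y_rv, {y_rv: y})
```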
My last commit provides a possible solution to the issue described in #26 (comment). I am not very happy with it, but couldn't come up with a more satisfying solution yet. Another option is to simply not support the cases where the original RV graph and the derived RV yield different broadcastable patterns.
If I'm correctly understanding the purpose of MeasurableRebroadcast, then it would appear that we need another canonicalization that lifts Rebroadcast Ops.
Can you provide an explicit example in which this issue arises, and show the relevant aesara.dprint output?
It appears in the edge case described above. Note that the Rebroadcast is something I had to introduce in the rewrite phase so that I could replace the old generic vector node with what would otherwise be a new length-1 vector TransformedRV. Relevant code: line 435 in 02e6c13.
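To make the rebroadcasting step concrete, here is a small illustration that is not code from the PR; it assumes Aesara's `at.unbroadcast` helper, which inserts a rebroadcasting node on top of an existing variable:

```python
import aesara.tensor as at

# The candidate replacement RV has a length-1, broadcastable output...
new_rv = at.random.normal(at.vector("loc"), 1, size=1)
print(new_rv.broadcastable)       # (True,)

# ...but the node it replaces was a generic (non-broadcastable) vector, so a
# rebroadcasting node is needed to give the replacement the original type.
replacement = at.unbroadcast(new_rv, 0)
print(replacement.broadcastable)  # (False,)
```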
Broadcasting is just a subset of shape inference and, just like that, an … The first thing that needs to be done in a situation like this is to determine whether or not the broadcasting inference of the …
Also, for comparison, I set up the two similar expressions:

```python
(at.vector() + at.random.normal(size=1)).type
# TensorType(float64, vector)

(at.random.normal(loc=at.vector(), size=1)).type
# TensorType(float64, (True,))
```

It feels like the first expression should be replaceable by the second, but they have different broadcastable patterns. See aeppl/tests/test_transforms.py, lines 408 to 423 in 02e6c13.
Yes, that definitely seems like the/a problem! We should create an issue for this in Aesara. Here's a little more background on the situation:

```python
import aesara
import aesara.tensor as at

at.vector().broadcastable
# (False,)
at.random.normal(size=1).broadcastable
# (True,)
(at.vector() + at.random.normal(size=1)).broadcastable
# (False,)
(at.random.normal(loc=at.vector(), size=1)).broadcastable
# (True,)
```

Apparently, the last example is a bit contradictory, because the mean/loc parameter is a non-broadcastable vector while the output is still marked as broadcastable.
Actually, the last example is inconsistent; the … The second-to-last example makes sense only because the broadcastable … Simply put, in the last example, … @ricardoV94, we need to address whatever it is that's giving rise to that situation.
So the question in this PR is whether we want to support the derivation for this kind of graph:

```python
loc = at.vector()
rv = at.random.normal(0, 1, size=1) + loc
vv = rv.type()
```

In theory … If we want to support this, we would need to somehow make the derived RV have a vector type (one way could be through the Rebroadcast). If we don't want to support this, then there is no further roadblock to this PR; I would just revert the Rebroadcast. Users can still make this work by providing a valid …
No, the question is "Why are we getting inconsistent broadcastable patterns in the first place?" As far as solutions go, we could do multiple things that don't necessarily restrict the cases covered by AePPL: e.g. rebroadcast the parameters when a size is specified.
Alright, I think I am gonna leave the graph …
This one is ready for review.
This one is ready for review. I reverted the Rebroadcast for now. In a subsequent PR we could add a rewrite that "lifts" a Rebroadcast to the RV inputs or changes the RV size, so that we can output rebroadcasted MeasurableVariables from our rewrites (or apply it directly in the rewrites).
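As a rough sketch of the two options named above (this is just an illustration, not aeppl code, and it assumes loc has length 1 at runtime):

```python
import aesara.tensor as at

loc = at.vector("loc")

# Output-level fix: wrap the size=1 RV's output in a rebroadcasting node.
wrapped = at.unbroadcast(at.random.normal(loc, 1, size=1), 0)

# Input-level fix ("change the RV size"): drop the restrictive size and let
# the RV infer its shape from loc, so no rebroadcast is needed at all.
lifted = at.random.normal(loc, 1)

print(wrapped.broadcastable, lifted.broadcastable)  # (False,) (False,)
```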
Looks like this needs to be rebased.
With #129 this PR is now much more straightforward! Ready for review |
Hmm, does the IR "reverter" work in this case where nodes are removed instead of replaced?
I had to change strategy away from the transform machinery, because I couldn't see a way to incorporate inputs that are not part of the base measurable variable, which is needed for the loc/scale transforms.
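To illustrate the distinction being drawn (these classes are purely illustrative sketches, not aeppl's API): a log transform is fully determined by the measurable base variable, while a loc transform also depends on an extra tensor (loc) that lives outside the base RV's graph.

```python
import aesara.tensor as at

class LogTransformSketch:
    """Needs nothing beyond the value being transformed."""
    def backward(self, value):
        return at.exp(value)

class LocTransformSketch:
    """Carries an extra, non-measurable input (loc) into the transform."""
    def __init__(self, loc):
        self.loc = loc

    def backward(self, value):
        return value - self.loc
```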
Closes #18.
TODO:
- RV and the following ops: Add, Mul, Exp, Log, and generic loc / shift tensors (see Implement chained transforms #26 (comment))
- Deal properly with broadcasting as discussed in Implement censored log-probabilities via the Clip Op #22 (comment). Deferred to Decide how to deal with automatic broadcast in derived RVs #51.
- Deal with discrete variables (either do not allow or do not add jacobian term). Not allowed for now.
- Jacobian of ChainTransform is wonky for non-scalar variables. Loc transform was returning a scalar det.
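Regarding the last item, a tiny illustration (not PR code) of why the loc transform's log-Jacobian term should be elementwise rather than a single scalar: for an elementwise shift y = x + loc, the log-Jacobian contribution is 0 per element, so its shape should match the value's shape.

```python
import numpy as np

y = np.array([0.1, 0.2, 0.3])
log_jac = np.zeros_like(y)  # one zero per element, not a single scalar 0.0
```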