
Random variables present in logp graph from Join dispatch #149

@larryshamalama


With the logprob dispatch for `CumOp` in place, there is ongoing work to redefine the Gaussian Random Walk distribution in PyMC as a cumulative sum of distributions, so that AePPL can retrieve its logp graph automatically (see PR 5814). PyMC requires that no random variables remain in the logp graph, as per lines here. However, the logprob dispatch for the `Join` Op reads the shapes of the joined tensors, and those shape expressions still refer to the random variables themselves. As a consequence, the random variables in the logp graph are not all replaced with their value variable counterparts.

One way to address this issue would be to apply constant folding to the variable shapes before passing them as the `splits_size` argument:

aeppl/aeppl/tensor.py

Lines 112 to 117 in cc78f30

split_values = at.split(
    value,
    splits_size=[base_var.shape[axis] for base_var in base_vars],
    n_splits=len(base_vars),
    axis=axis,
)
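
To make the idea concrete, here is a minimal pure-Python sketch of what folding the shapes would achieve. It is a toy model, not Aesara's or AePPL's API: `RandomVar`, `ShapeAt`, `Const`, `fold_shape` and `references_rv` are hypothetical names made up for illustration. The assumption is that each base variable may carry a statically known length along the join axis.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class RandomVar:
    """Toy stand-in for a random variable with a (possibly unknown) static shape."""
    name: str
    static_shape: Tuple[Optional[int], ...]


@dataclass(frozen=True)
class ShapeAt:
    """Toy stand-in for the symbolic expression `var.shape[axis]`."""
    var: RandomVar
    axis: int


@dataclass(frozen=True)
class Const:
    """Toy stand-in for a constant node in the graph."""
    value: int


def fold_shape(expr: ShapeAt) -> Union[Const, ShapeAt]:
    """Replace `var.shape[axis]` with a constant when the length is statically known."""
    length = expr.var.static_shape[expr.axis]
    return Const(length) if length is not None else expr


def references_rv(expr: Union[Const, ShapeAt]) -> bool:
    """Return True if the expression still depends on a random variable."""
    return isinstance(expr, ShapeAt)


x = RandomVar("x", (3,))     # length known ahead of time
y = RandomVar("y", (None,))  # length only known at run time

# Analogue of the `splits_size` list above, with folding applied first.
splits_size = [fold_shape(ShapeAt(base_var, 0)) for base_var in (x, y)]

for size in splits_size:
    print(size, "-> still references a random variable:", references_rv(size))
```

In this toy model the entry for `x` becomes a plain constant, so nothing in it points back at a random variable, while the entry for `y` is left untouched because its length is unknown. If the real graph behaves similarly, folding would only help for base variables whose shapes can be determined statically.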

Any thoughts about this approach?

CC @ricardoV94 @brandonwillard


Labels: graph rewriting (Involves the implementation of rewrites to Aesara graphs), question (Further information is requested)
