AssertionError in aeppl.logp breaks sampling in pymc #84

@ferrine

Description of your problem or feature request

Please provide a minimal, self-contained, and reproducible example.

res = Assert("sigma > 0")(res, at.all(at.gt(sigma, 0.0)))

Please provide the full traceback of any errors.

AssertionError: sigma > 0
Apply node that caused the error: Assert{msg='sigma > 0'}(Elemwise{Composite{((i0 + (i1 * sqr(i2))) - log(i3))}}[(0, 2)].0, All.0)

Please provide any additional information below.
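Sampling with PyMC (master) fails with an AssertionError raised from the aeppl logp graph. The sampler does not catch the AssertionError and treat the offending draw as a divergent sample; the whole sampling run aborts instead.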

Expected Behaviour

The logp should return -inf when a parameter constraint such as sigma > 0 is violated, instead of raising.
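
To illustrate why -inf is the more sampler-friendly result, here is a minimal pure-Python sketch. It is not aeppl or PyMC code: `normal_logp_strict`, `normal_logp_safe`, and `accepts` are hypothetical names, and a plain Metropolis acceptance test stands in for whatever the real sampler does. The strict version mimics the Assert and raises; the safe version returns -inf, which an acceptance test simply rejects.

```python
import math


def normal_logp_strict(x, mu, sigma):
    """Normal log-density that raises when sigma is invalid (mimics the Assert)."""
    if not sigma > 0:
        raise AssertionError("sigma > 0")
    z = (x - mu) / sigma
    return -0.5 * math.log(2 * math.pi) - 0.5 * z * z - math.log(sigma)


def normal_logp_safe(x, mu, sigma):
    """Same density, but returns -inf when sigma is outside its support."""
    if not sigma > 0:
        return -math.inf
    z = (x - mu) / sigma
    return -0.5 * math.log(2 * math.pi) - 0.5 * z * z - math.log(sigma)


def accepts(logp_current, logp_proposed, u):
    """Metropolis test for a symmetric proposal; u is a uniform draw in (0, 1)."""
    return math.log(u) < logp_proposed - logp_current


current = normal_logp_safe(0.3, 0.0, 1.0)
proposed = normal_logp_safe(0.3, 0.0, -0.5)  # invalid sigma
print(proposed)                           # -inf
print(accepts(current, proposed, u=0.5))  # False: the proposal is rejected
```

With the strict version, the same invalid proposal raises and stops the loop that called it, which is the failure mode reported here.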

Possible Solution

  • provide a graph rewrite for fixing assertions
  • replace Assert with switch

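This snippet solved my issue:

import aesara
import numpy as np

# Monkey-patch: Assert(msg)(res, *cond) now yields -inf instead of raising when any condition fails
aesara.assert_op.Assert = lambda name: (lambda res, *cond: aesara.tensor.switch(
    aesara.tensor.all(aesara.tensor.stack([c.all() for c in cond])),
    res,
    -np.inf
))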

Versions and main components

  • Aesara version:
  • Aesara config (python -c "import aesara; print(aesara.config)")
  • Python version:
  • Operating system:
  • How did you install Aesara: (conda/pip)

Metadata

Assignees

No one assigned

    Labels

    question: Further information is requested

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
