Closed
Description
The `joint_logprob` function is not very helpful and is only used in tests. We should be able to remove it and use either `logp` or `factorized_joint_logprob` in the tests instead.
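In a test, a call to `joint_logprob(..., sum=True)` could be replaced by reducing each factor returned by `factorized_joint_logprob` to a scalar and adding the results. The sketch below is a NumPy analogue, not PyMC code. It assumes the factors arrive as a dict of per-variable log-probability arrays, and uses NumPy arrays in place of PyTensor graphs so it runs without PyMC. The names `normal_logpdf`, `factors`, `"x_vv"` and `"y_vv"` are hypothetical.

```python
import math

import numpy as np


def normal_logpdf(value, mu=0.0, sigma=1.0):
    # Elementwise log-density of a Normal(mu, sigma) distribution.
    value = np.asarray(value, dtype=float)
    z = (value - mu) / sigma
    return -0.5 * z**2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


# Hypothetical stand-in for the dict returned by `factorized_joint_logprob`:
# one log-probability factor per value variable.
factors = {
    "x_vv": normal_logpdf([0.0, 1.0, -1.0]),                # shape (3,)
    "y_vv": normal_logpdf([0.5, 2.0], mu=1.0, sigma=2.0),  # shape (2,)
}

# Inline replacement for `joint_logprob(..., sum=True)` in a test:
# collapse each factor to a scalar, then add the scalars.
total = sum(np.sum(factor) for factor in factors.values())
```

Because the factors are independent, `total` is simply the sum of the per-variable log-densities, so a test can compare it directly against a reference value.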
Lines 50 to 74 in 83cd926
```python
def joint_logprob(*args, sum: bool = True, **kwargs) -> Optional[TensorVariable]:
    """Create a graph representing the joint log-probability/measure of a graph.

    This function calls `factorized_joint_logprob` and returns the combined
    log-probability factors as a single graph.

    Parameters
    ----------
    sum: bool
        If ``True`` each factor is collapsed to a scalar via ``sum`` before
        being joined with the remaining factors. This may be necessary to
        avoid incorrect broadcasting among independent factors.
    """
    logprob = factorized_joint_logprob(*args, **kwargs)
    if not logprob:
        return None
    if len(logprob) == 1:
        logprob = tuple(logprob.values())[0]
        if sum:
            return pt.sum(logprob)
        return logprob
    if sum:
        return pt.sum([pt.sum(factor) for factor in logprob.values()])
    return pt.add(*logprob.values())
```
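The docstring's warning about "incorrect broadcasting among independent factors" can be illustrated with plain NumPy, which follows the same broadcasting rules that elementwise addition such as `pt.add` applies. This is a sketch under that assumption; the factor values and shapes are made up for illustration and do not come from PyMC.

```python
import numpy as np

# Two independent log-probability factors with different shapes
# (made-up values for illustration).
logp_a = np.array([[-1.0], [-2.0], [-3.0]])  # shape (3, 1)
logp_b = np.array([-0.5, -1.5])              # shape (2,)

# Adding the factors directly (analogous to the ``sum=False`` path,
# ``pt.add(*logprob.values())``) broadcasts them against each other:
# every entry of logp_a is paired with every entry of logp_b,
# giving a (3, 2) array rather than a joint total.
broadcast = logp_a + logp_b

# Collapsing each factor to a scalar first (analogous to the ``sum=True``
# path) gives the joint log-probability of all independent values.
joint = np.sum([np.sum(logp_a), np.sum(logp_b)])
```

Here `joint` is -8.0, while summing the broadcast result gives -18.0 because each factor is counted once per entry of the other. If the shapes were incompatible, for example (3,) and (2,), the direct addition would raise an error instead.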