Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 48 additions & 34 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2394,7 +2394,7 @@ class ZeroSumNormal(Distribution):
ZeroSumNormal distribution, i.e. a Normal distribution where one or
several axes are constrained to sum to zero.
By default, the last axis is constrained to sum to zero.
See `zerosum_axes` kwarg for more details.
See `n_zerosum_axes` kwarg for more details.

.. math::

Expand All @@ -2411,9 +2411,10 @@ class ZeroSumNormal(Distribution):
It's actually the standard deviation of the underlying, unconstrained Normal distribution.
Defaults to 1 if not specified.
For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
zerosum_axes: int, defaults to 1
n_zerosum_axes: int, defaults to 1
Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
Defaults to 1, i.e. the rightmost axis.
zerosum_axes: int, deprecated. Use ``n_zerosum_axes`` instead.
dims: sequence of strings, optional
Dimension names of the distribution. Works the same as for other PyMC distributions.
Necessary if ``shape`` is not passed.
Expand Down Expand Up @@ -2452,25 +2453,38 @@ class ZeroSumNormal(Distribution):
"""
rv_type = ZeroSumNormalRV

def __new__(
    cls, *args, zerosum_axes=None, n_zerosum_axes=None, support_shape=None, dims=None, **kwargs
):
    """Resolve the zero-sum arguments, then delegate to the parent ``__new__``.

    Parameters
    ----------
    zerosum_axes : int, optional
        Deprecated spelling of ``n_zerosum_axes``. When given, its value is
        used as ``n_zerosum_axes`` and a ``DeprecationWarning`` is emitted.
    n_zerosum_axes : int, optional
        Number of zero-sum axes; validated by ``cls.check_zerosum_axes`` when
        ``dims`` or ``observed`` is supplied.
    support_shape : optional
        Passed through ``get_support_shape`` when ``dims`` or ``observed`` is
        supplied, otherwise forwarded unchanged.
    dims : optional
        Forwarded to ``get_support_shape`` and to the parent ``__new__``.
    **kwargs
        Forwarded to the parent ``__new__``; ``observed`` is read from here.
    """
    if zerosum_axes is not None:
        # Forward the deprecated keyword to its replacement so that the value
        # the caller supplied is actually honoured (it used to be assigned to
        # a misspelled, never-read name and was therefore silently dropped).
        n_zerosum_axes = zerosum_axes
        warnings.warn(
            "The 'zerosum_axes' parameter is deprecated. Use 'n_zerosum_axes' instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )

    # NOTE(review): nesting of the support-shape lookup under this guard is
    # assumed from the statement order — confirm against the upstream file.
    if dims is not None or kwargs.get("observed") is not None:
        n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes)

        support_shape = get_support_shape(
            support_shape=support_shape,
            shape=None,  # Shape will be checked in `cls.dist`
            dims=dims,
            observed=kwargs.get("observed", None),
            ndim_supp=n_zerosum_axes,
        )

    return super().__new__(
        cls,
        *args,
        n_zerosum_axes=n_zerosum_axes,
        support_shape=support_shape,
        dims=dims,
        **kwargs,
    )

@classmethod
def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
zerosum_axes = cls.check_zerosum_axes(zerosum_axes)
def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes)

sigma = at.as_tensor_variable(floatX(sigma))
if sigma.ndim > 0:
Expand All @@ -2479,41 +2493,41 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
support_shape = get_support_shape(
support_shape=support_shape,
shape=kwargs.get("shape"),
ndim_supp=zerosum_axes,
ndim_supp=n_zerosum_axes,
)

if support_shape is None:
if zerosum_axes > 0:
if n_zerosum_axes > 0:
raise ValueError("You must specify dims, shape or support_shape parameter")
# TODO: edge-case doesn't work for now, because at.stack in get_support_shape fails
# else:
# support_shape = () # because it's just a Normal in that case
support_shape = at.as_tensor_variable(intX(support_shape))

assert zerosum_axes == at.get_vector_length(
assert n_zerosum_axes == at.get_vector_length(
support_shape
), "support_shape has to be as long as zerosum_axes"
), "support_shape has to be as long as n_zerosum_axes"

return super().dist(
[sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs
[sigma], n_zerosum_axes=n_zerosum_axes, support_shape=support_shape, **kwargs
)

@classmethod
def check_zerosum_axes(cls, n_zerosum_axes: Optional[int]) -> int:
    """Validate the number of zero-sum axes and return it.

    ``None`` is replaced by the default of 1.

    Raises
    ------
    TypeError
        If the value is not an ``int``.
    ValueError
        If the value is not strictly positive.
    """
    n_axes = 1 if n_zerosum_axes is None else n_zerosum_axes
    if not isinstance(n_axes, int):
        raise TypeError("n_zerosum_axes has to be an integer")
    if n_axes <= 0:
        raise ValueError("n_zerosum_axes has to be > 0")
    return n_axes

@classmethod
def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
shape = to_tuple(size) + tuple(support_shape)
normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))

if zerosum_axes > normal_dist.ndim:
if n_zerosum_axes > normal_dist.ndim:
raise ValueError("Shape of distribution is too small for the number of zerosum axes")

normal_dist_, sigma_, support_shape_ = (
Expand All @@ -2522,15 +2536,15 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
support_shape.type(),
)

# Zerosum-normaling is achieved by subtracting the mean along the given zerosum_axes
# Zerosum-normaling is achieved by subtracting the mean along the given n_zerosum_axes
zerosum_rv_ = normal_dist_
for axis in range(zerosum_axes):
for axis in range(n_zerosum_axes):
zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)

return ZeroSumNormalRV(
inputs=[normal_dist_, sigma_, support_shape_],
outputs=[zerosum_rv_, support_shape_],
ndim_supp=zerosum_axes,
ndim_supp=n_zerosum_axes,
)(normal_dist, sigma, support_shape)


Expand All @@ -2544,7 +2558,7 @@ def change_zerosum_size(op, normal_dist, new_size, expand=False):
new_size = tuple(new_size) + old_size

return ZeroSumNormal.rv_op(
sigma=sigma, zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size
sigma=sigma, n_zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size
)


Expand All @@ -2555,28 +2569,28 @@ def zerosumnormal_moment(op, rv, *rv_inputs):

@_default_transform.register(ZeroSumNormalRV)
def zerosum_default_transform(op, rv):
    """Return the default transform for a ``ZeroSumNormalRV``.

    Builds a ``ZeroSumTransform`` over the last ``op.ndim_supp`` axes,
    expressed as negative axis indices (``-ndim_supp, ..., -1``).
    """
    # A tuple of axis indices, not a count of axes.
    constrained_axes = tuple(np.arange(-op.ndim_supp, 0))
    return ZeroSumTransform(constrained_axes)


@_logprob.register(ZeroSumNormalRV)
def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
    """Log-probability of a ``ZeroSumNormalRV`` at a single value.

    The element-wise log-density of ``normal_dist`` is scaled by the ratio of
    degrees of freedom to the total number of elements, then summed over the
    last ``op.ndim_supp`` axes. The result is wrapped in ``check_parameters``
    with one condition per constrained axis: the mean of ``value`` along that
    axis must be close to zero (``atol=1e-9``).
    """
    (value,) = values
    n_axes = op.ndim_supp
    value_shape = value.shape

    # Subtract one from each of the trailing `n_axes` entries of the shape:
    # every constrained axis loses one degree of freedom.
    reduced_shape = at.inc_subtensor(value_shape[-n_axes:], -1)
    n_total = at.prod(value_shape)
    n_free = at.prod(reduced_shape)

    zero_mean_checks = [
        at.all(at.isclose(at.mean(value, axis=-axis - 1), 0, atol=1e-9))
        for axis in range(n_axes)
    ]

    scaled_logp = pm.logp(normal_dist, value) * n_free / n_total
    support_axes = tuple(np.arange(-n_axes, 0))
    logp = at.sum(scaled_logp, axis=support_axes)

    return check_parameters(logp, *zero_mean_checks, msg="mean(value, axis=n_zerosum_axes) = 0")