Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@
DensityDist,
Discrete,
Distribution,
NoDistribution,
SymbolicRandomVariable,
)
from pymc.distributions.mixture import Mixture, NormalMixture
Expand Down Expand Up @@ -159,7 +158,6 @@
"SymbolicRandomVariable",
"Continuous",
"Discrete",
"NoDistribution",
"MvNormal",
"MatrixNormal",
"KroneckerNormal",
Expand Down
277 changes: 150 additions & 127 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from functools import singledispatch
from typing import Callable, Optional, Sequence, Tuple, Union

import aesara
import numpy as np

from aeppl.abstract import MeasurableVariable, _get_measurable_outputs
Expand Down Expand Up @@ -62,7 +61,6 @@
"Distribution",
"Continuous",
"Discrete",
"NoDistribution",
"SymbolicRandomVariable",
]

Expand Down Expand Up @@ -470,13 +468,6 @@ class Continuous(Distribution):
"""Base class for continuous distributions"""


class NoDistribution(Distribution):
"""Base class for artifical distributions

RandomVariables that share this type are allowed in logprob graphs
"""


class DensityDistRV(RandomVariable):
"""
Base class for DensityDistRV
Expand All @@ -494,18 +485,130 @@ def rng_fn(cls, rng, *args):
return cls._random_fn(*args, rng=rng, size=size)


class DensityDist(NoDistribution):
class DensityDist(Distribution):
"""A distribution that can be used to wrap black-box log density functions.

Creates a Distribution and registers the supplied log density function to be used
for inference. It is also possible to supply a `random` method in order to be able
to sample from the prior or posterior predictive distributions.


Parameters
----------
name : str
dist_params : Tuple
A sequence of the distribution's parameter. These will be converted into
Aesara tensors internally. These parameters could be other ``TensorVariable``
instances created from , optionally created via ``RandomVariable`` ``Op``s.
class_name : str
Name for the RandomVariable class which will wrap the DensityDist methods.
When not specified, it will be given the name of the variable.

.. warning:: New DensityDists created with the same class_name will override the
methods dispatched onto the previous classes. If using DensityDists with
different methods across separate models, be sure to use distinct
class_names.

logp : Optional[Callable]
A callable that calculates the log density of some given observed ``value``
conditioned on certain distribution parameter values. It must have the
following signature: ``logp(value, *dist_params)``, where ``value`` is
an Aesara tensor that represents the observed value, and ``dist_params``
are the tensors that hold the values of the distribution parameters.
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
error will be raised when trying to compute the distribution's logp.
logcdf : Optional[Callable]
A callable that calculates the log cummulative probability of some given observed
``value`` conditioned on certain distribution parameter values. It must have the
following signature: ``logcdf(value, *dist_params)``, where ``value`` is
an Aesara tensor that represents the observed value, and ``dist_params``
are the tensors that hold the values of the distribution parameters.
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
error will be raised when trying to compute the distribution's logcdf.
random : Optional[Callable]
A callable that can be used to generate random draws from the distribution.
It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
The distribution parameters are passed as positional arguments in the
same order as they are supplied when the ``DensityDist`` is constructed.
The keyword arguments are ``rnd``, which will provide the random variable's
associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent
the desired size of the random draw. If ``None``, a ``NotImplemented``
error will be raised when trying to draw random samples from the distribution's
prior or posterior predictive.
moment : Optional[Callable]
A callable that can be used to compute the moments of the distribution.
It must have the following signature: ``moment(rv, size, *rv_inputs)``.
The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed
as the first argument ``rv``. ``size`` is the random variable's size implied
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
``rv_inputs`` is the sequence of the distribution parameters, in the same order
as they were supplied when the DensityDist was created. If ``None``, a default
``moment`` function will be assigned that will always return 0, or an array
of zeros.
ndim_supp : int
The number of dimensions in the support of the distribution. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``.
ndims_params : Optional[Sequence[int]]
The list of number of dimensions in the support of each of the distribution's
parameters. If ``None``, it is assumed that all parameters are scalars, hence
the number of dimensions of their support will be 0.
dtype : str
The dtype of the distribution. All draws and observations passed into the distribution
will be casted onto this dtype.
kwargs :
Extra keyword arguments are passed to the parent's class ``__new__`` method.

Examples
--------
.. code-block:: python

def logp(value, mu):
return -(value - mu)**2

with pm.Model():
mu = pm.Normal('mu',0,1)
pm.DensityDist(
'density_dist',
mu,
logp=logp,
observed=np.random.randn(100),
)
idata = pm.sample(100)

.. code-block:: python

def logp(value, mu):
return -(value - mu)**2

def random(mu, rng=None, size=None):
return rng.normal(loc=mu, scale=1, size=size)

with pm.Model():
mu = pm.Normal('mu', 0 , 1)
dens = pm.DensityDist(
'density_dist',
mu,
logp=logp,
random=random,
observed=np.random.randn(100, 3),
size=(100, 3),
)
prior = pm.sample_prior_predictive(10).prior_predictive['density_dist']
assert prior.shape == (1, 10, 100, 3)

"""

def __new__(
rv_type = DensityDistRV

def __new__(cls, name, *args, **kwargs):
kwargs.setdefault("class_name", name)
return super().__new__(cls, name, *args, **kwargs)

@classmethod
def dist(
cls,
name: str,
*dist_params,
class_name: str,
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
random: Optional[Callable] = None,
Expand All @@ -515,102 +618,6 @@ def __new__(
dtype: str = "floatX",
**kwargs,
):
"""
Parameters
----------
name : str
dist_params : Tuple
A sequence of the distribution's parameter. These will be converted into
Aesara tensors internally. These parameters could be other ``TensorVariable``
instances created from , optionally created via ``RandomVariable`` ``Op``s.
logp : Optional[Callable]
A callable that calculates the log density of some given observed ``value``
conditioned on certain distribution parameter values. It must have the
following signature: ``logp(value, *dist_params)``, where ``value`` is
an Aesara tensor that represents the observed value, and ``dist_params``
are the tensors that hold the values of the distribution parameters.
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
error will be raised when trying to compute the distribution's logp.
logcdf : Optional[Callable]
A callable that calculates the log cummulative probability of some given observed
``value`` conditioned on certain distribution parameter values. It must have the
following signature: ``logcdf(value, *dist_params)``, where ``value`` is
an Aesara tensor that represents the observed value, and ``dist_params``
are the tensors that hold the values of the distribution parameters.
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
error will be raised when trying to compute the distribution's logcdf.
random : Optional[Callable]
A callable that can be used to generate random draws from the distribution.
It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
The distribution parameters are passed as positional arguments in the
same order as they are supplied when the ``DensityDist`` is constructed.
The keyword arguments are ``rnd``, which will provide the random variable's
associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent
the desired size of the random draw. If ``None``, a ``NotImplemented``
error will be raised when trying to draw random samples from the distribution's
prior or posterior predictive.
moment : Optional[Callable]
A callable that can be used to compute the moments of the distribution.
It must have the following signature: ``moment(rv, size, *rv_inputs)``.
The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed
as the first argument ``rv``. ``size`` is the random variable's size implied
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
``rv_inputs`` is the sequence of the distribution parameters, in the same order
as they were supplied when the DensityDist was created. If ``None``, a default
``moment`` function will be assigned that will always return 0, or an array
of zeros.
ndim_supp : int
The number of dimensions in the support of the distribution. Defaults to assuming
a scalar distribution, i.e. ``ndim_supp = 0``.
ndims_params : Optional[Sequence[int]]
The list of number of dimensions in the support of each of the distribution's
parameters. If ``None``, it is assumed that all parameters are scalars, hence
the number of dimensions of their support will be 0.
dtype : str
The dtype of the distribution. All draws and observations passed into the distribution
will be casted onto this dtype.
kwargs :
Extra keyword arguments are passed to the parent's class ``__new__`` method.

Examples
--------
.. code-block:: python

def logp(value, mu):
return -(value - mu)**2

with pm.Model():
mu = pm.Normal('mu',0,1)
pm.DensityDist(
'density_dist',
mu,
logp=logp,
observed=np.random.randn(100),
)
idata = pm.sample(100)

.. code-block:: python

def logp(value, mu):
return -(value - mu)**2

def random(mu, rng=None, size=None):
return rng.normal(loc=mu, scale=1, size=size)

with pm.Model():
mu = pm.Normal('mu', 0 , 1)
dens = pm.DensityDist(
'density_dist',
mu,
logp=logp,
random=random,
observed=np.random.randn(100, 3),
size=(100, 3),
)
prior = pm.sample_prior_predictive(10).prior_predictive['density_dist']
assert prior.shape == (1, 10, 100, 3)

"""

if dist_params is None:
dist_params = []
Expand All @@ -622,34 +629,61 @@ def random(mu, rng=None, size=None):
"to the API documentation for more information on how to use the "
"new DensityDist API."
)
dist_params = [as_tensor_variable(param) for param in dist_params]
dist_params = [as_tensor_variable(param) for param in dist_params]

# Assume scalar ndims_params
if ndims_params is None:
ndims_params = [0] * len(dist_params)

if logp is None:
logp = default_not_implemented(name, "logp")
logp = default_not_implemented(class_name, "logp")

if logcdf is None:
logcdf = default_not_implemented(name, "logcdf")
logcdf = default_not_implemented(class_name, "logcdf")

if moment is None:
moment = functools.partial(
default_moment,
rv_name=name,
rv_name=class_name,
has_fallback=random is not None,
ndim_supp=ndim_supp,
)

if random is None:
random = default_not_implemented(name, "random")
random = default_not_implemented(class_name, "random")

return super().dist(
dist_params,
class_name=class_name,
logp=logp,
logcdf=logcdf,
random=random,
moment=moment,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
dtype=dtype,
**kwargs,
)

@classmethod
def rv_op(
cls,
*dist_params,
class_name: str,
logp: Optional[Callable],
logcdf: Optional[Callable],
random: Optional[Callable],
moment: Optional[Callable],
ndim_supp: int,
ndims_params: Optional[Sequence[int]],
dtype: str,
**kwargs,
):
rv_op = type(
f"DensityDist_{name}",
f"DensityDist_{class_name}",
(DensityDistRV,),
dict(
name=f"DensityDist_{name}",
name=f"DensityDist_{class_name}",
inplace=False,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
Expand Down Expand Up @@ -677,18 +711,7 @@ def density_dist_logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
def density_dist_get_moment(op, rv, rng, size, dtype, *dist_params):
return moment(rv, size, *dist_params)

cls.rv_op = rv_op
return super().__new__(cls, name, *dist_params, **kwargs)

@classmethod
def dist(cls, *args, **kwargs):
output = super().dist(args, **kwargs)
if cls.rv_op.dtype == "floatX":
dtype = aesara.config.floatX
else:
dtype = cls.rv_op.dtype
ndim_supp = cls.rv_op.ndim_supp
return output
return rv_op(*dist_params, **kwargs)


def default_not_implemented(rv_name, method_name):
Expand Down
Loading