From ecba5bb262baa079d13ac4e474d15730cee046f7 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 14 Feb 2023 16:35:48 +0100
Subject: [PATCH] Replace rvs_to_total_sizes mapping by MinibatchRandomVariables

---
 .github/workflows/tests.yml            |   1 +
 pymc/logprob/joint_logprob.py          |  87 +-------------
 pymc/model.py                          |  31 +++--
 pymc/util.py                           |   1 -
 pymc/variational/minibatch_rv.py       | 113 ++++++++++++++++++
 pymc/variational/opvi.py               |  13 +-
 tests/distributions/test_transform.py  |   4 -
 tests/distributions/util.py            |   1 -
 tests/logprob/test_joint_logprob.py    |  54 +--------
 tests/logprob/test_utils.py            |   3 -
 tests/logprob/utils.py                 |   2 -
 tests/test_data.py                     | 140 +---------------------
 tests/test_model.py                    |  34 +++---
 tests/variational/test_minibatch_rv.py | 157 +++++++++++++++++++++++++
 tests/variational/test_opvi.py         |   2 +-
 15 files changed, 320 insertions(+), 323 deletions(-)
 create mode 100644 pymc/variational/minibatch_rv.py
 create mode 100644 tests/variational/test_minibatch_rv.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 10ae3c7866..332a7eacd6 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -240,6 +240,7 @@ jobs:
           - |
             tests/sampling/test_parallel.py
             tests/test_data.py
+            tests/variational/test_minibatch_rv.py
             tests/test_model.py
 
           - |
diff --git a/pymc/logprob/joint_logprob.py b/pymc/logprob/joint_logprob.py
index f8481de7b6..d5548aa930 100644
--- a/pymc/logprob/joint_logprob.py
+++ b/pymc/logprob/joint_logprob.py
@@ -39,7 +39,6 @@
 from collections import deque
 from typing import Dict, List, Optional, Sequence, Union
 
-import numpy as np
 import pytensor
 import pytensor.tensor as pt
 
@@ -55,7+54,6 @@
 from pymc.logprob.rewriting import construct_ir_fgraph
 from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
 from pymc.logprob.utils import rvs_to_value_vars
-from pymc.pytensorf import floatX
 
 
 def logp(rv: TensorVariable, value) -> TensorVariable:
@@ -248,77 +246,6 @@ def factorized_joint_logprob(
     return logprob_vars
 
 
-TOTAL_SIZE = Union[int, Sequence[int], None]
-
-
-def _get_scaling(total_size: TOTAL_SIZE, shape, ndim: int) -> TensorVariable:
-    """
-    Gets scaling constant for logp.
- - Parameters - ---------- - total_size: Optional[int|List[int]] - size of a fully observed data without minibatching, - `None` means data is fully observed - shape: shape - shape of an observed data - ndim: int - ndim hint - - Returns - ------- - scalar - """ - if total_size is None: - coef = 1.0 - elif isinstance(total_size, int): - if ndim >= 1: - denom = shape[0] - else: - denom = 1 - coef = floatX(total_size) / floatX(denom) - elif isinstance(total_size, (list, tuple)): - if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)): - raise TypeError( - "Unrecognized `total_size` type, expected " - "int or list of ints, got %r" % total_size - ) - if Ellipsis in total_size: - sep = total_size.index(Ellipsis) - begin = total_size[:sep] - end = total_size[sep + 1 :] - if Ellipsis in end: - raise ValueError( - "Double Ellipsis in `total_size` is restricted, got %r" % total_size - ) - else: - begin = total_size - end = [] - if (len(begin) + len(end)) > ndim: - raise ValueError( - "Length of `total_size` is too big, " - "number of scalings is bigger that ndim, got %r" % total_size - ) - elif (len(begin) + len(end)) == 0: - coef = 1.0 - if len(end) > 0: - shp_end = shape[-len(end) :] - else: - shp_end = np.asarray([]) - shp_begin = shape[: len(begin)] - begin_coef = [ - floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None - ] - end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None] - coefs = begin_coef + end_coef - coef = pt.prod(coefs) - else: - raise TypeError( - "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size - ) - return pt.as_tensor(coef, dtype=pytensor.config.floatX) - - def _check_no_rvs(logp_terms: Sequence[TensorVariable]): # Raise if there are unexpected RandomVariables in the logp graph # Only SimulatorRVs MinibatchIndexRVs are allowed @@ -348,7 +275,6 @@ def joint_logp( rvs_to_values: Dict[TensorVariable, TensorVariable], rvs_to_transforms: Dict[TensorVariable, RVTransform], jacobian: bool = True, - rvs_to_total_sizes: Dict[TensorVariable, TOTAL_SIZE], **kwargs, ) -> List[TensorVariable]: """Thin wrapper around pymc.logprob.factorized_joint_logprob, extended with Model @@ -371,18 +297,13 @@ def joint_logp( **kwargs, ) - # The function returns the logp for every single value term we provided to it. This - # includes the extra values we plugged in above, so we filter those we actually - # wanted in the same order they were given in. + # The function returns the logp for every single value term we provided to it. + # This includes the extra values we plugged in above, so we filter those we + # actually wanted in the same order they were given in. 
logp_terms = {} for rv in rvs: value_var = rvs_to_values[rv] - logp_term = temp_logp_terms[value_var] - total_size = rvs_to_total_sizes.get(rv, None) - if total_size is not None: - scaling = _get_scaling(total_size, value_var.shape, value_var.ndim) - logp_term *= scaling - logp_terms[value_var] = logp_term + logp_terms[value_var] = temp_logp_terms[value_var] _check_no_rvs(list(logp_terms.values())) return list(logp_terms.values()) diff --git a/pymc/model.py b/pymc/model.py index f10c3bf22d..3f18609871 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -564,7 +564,6 @@ def __init__( self.values_to_rvs = treedict(parent=self.parent.values_to_rvs) self.rvs_to_values = treedict(parent=self.parent.rvs_to_values) self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms) - self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes) self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values) self.free_RVs = treelist(parent=self.parent.free_RVs) self.observed_RVs = treelist(parent=self.parent.observed_RVs) @@ -578,7 +577,6 @@ def __init__( self.values_to_rvs = treedict() self.rvs_to_values = treedict() self.rvs_to_transforms = treedict() - self.rvs_to_total_sizes = treedict() self.rvs_to_initial_values = treedict() self.free_RVs = treelist() self.observed_RVs = treelist() @@ -762,7 +760,6 @@ def logp( rvs=rvs, rvs_to_values=self.rvs_to_values, rvs_to_transforms=self.rvs_to_transforms, - rvs_to_total_sizes=self.rvs_to_total_sizes, jacobian=jacobian, ) assert isinstance(rv_logps, list) @@ -1314,8 +1311,6 @@ def register_rv( name = self.name_for(name) rv_var.name = name _add_future_warning_tag(rv_var) - rv_var.tag.total_size = total_size - self.rvs_to_total_sizes[rv_var] = total_size # Associate previously unknown dimension names with # the length of the corresponding RV dimension. @@ -1327,6 +1322,8 @@ def register_rv( self.add_coord(dname, values=None, length=rv_var.shape[d]) if observed is None: + if total_size is not None: + raise ValueError("total_size can only be passed to observed RVs") self.free_RVs.append(rv_var) self.create_value_var(rv_var, transform) self.add_named_variable(rv_var, dims) @@ -1351,12 +1348,17 @@ def register_rv( # `rv_var` is potentially changed by `make_obs_var`, # for example into a new graph for imputation of missing data. - rv_var = self.make_obs_var(rv_var, observed, dims, transform) + rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size) return rv_var def make_obs_var( - self, rv_var: TensorVariable, data: np.ndarray, dims, transform: Optional[Any] + self, + rv_var: TensorVariable, + data: np.ndarray, + dims, + transform: Union[Any, None], + total_size: Union[int, None], ) -> TensorVariable: """Create a `TensorVariable` for an observed random variable. @@ -1392,11 +1394,6 @@ def make_obs_var( mask = getattr(data, "mask", None) if mask is not None: - if mask.all(): - # If there are no observed values, this variable isn't really - # observed. - return rv_var - impute_message = ( f"Data in {rv_var} contains missing values and" " will be automatically imputed from the" @@ -1404,6 +1401,9 @@ def make_obs_var( ) warnings.warn(impute_message, ImputationWarning) + if total_size is not None: + raise ValueError("total_size is not compatible with imputed variables") + if not isinstance(rv_var.owner.op, RandomVariable): raise NotImplementedError( "Automatic inputation is only supported for univariate RandomVariables." 
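
The model.py hunks above and below reroute total_size handling: it is rejected on free RVs and on imputed (masked) observations, and an observed RV given a total_size is wrapped by create_minibatch_rv. A minimal sketch of the resulting behaviour, adapted from the tests this patch adds in tests/variational/test_minibatch_rv.py; the model and variable names here are illustrative only:

    import numpy as np
    import pytensor

    import pymc as pm

    # One observed cell declared to stand in for a population of 2:
    # the observed logp is rescaled by total_size / batch_size = 2.
    with pm.Model() as model1:
        pm.Normal("n", observed=[[1.0]], total_size=1)
    with pm.Model() as model2:
        pm.Normal("n", observed=[[1.0]], total_size=2)

    p1 = pytensor.function([], model1.logp())
    p2 = pytensor.function([], model2.logp())
    np.testing.assert_allclose(2 * p1(), p2())

    # total_size is now only accepted on observed RVs.
    with pm.Model():
        try:
            pm.Normal("x", total_size=10)
        except ValueError as err:
            assert "can only be passed to observed RVs" in str(err)
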
@@ -1471,6 +1471,13 @@ def make_obs_var(
             data = sparse.basic.as_sparse(data, name=name)
         else:
             data = at.as_tensor_variable(data, name=name)
+
+        if total_size:
+            from pymc.variational.minibatch_rv import create_minibatch_rv
+
+            rv_var = create_minibatch_rv(rv_var, total_size)
+            rv_var.name = name
+
         rv_var.tag.observations = data
         self.create_value_var(rv_var, transform=None, value_var=data)
         self.add_named_variable(rv_var, dims)
diff --git a/pymc/util.py b/pymc/util.py
index 95daad664f..31d1e49899 100644
--- a/pymc/util.py
+++ b/pymc/util.py
@@ -489,7 +489,6 @@ def __getattribute__(self, name):
         for deprecated_names, alternative in (
             (("value_var", "observations"), "model.rvs_to_values[rv]"),
             (("transform",), "model.rvs_to_transforms[rv]"),
-            (("total_size",), "model.rvs_to_total_sizes[rv]"),
         ):
             if name in deprecated_names:
                 try:
diff --git a/pymc/variational/minibatch_rv.py b/pymc/variational/minibatch_rv.py
new file mode 100644
index 0000000000..c5d2a85aca
--- /dev/null
+++ b/pymc/variational/minibatch_rv.py
@@ -0,0 +1,113 @@
+# Copyright 2023 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Sequence, Union, cast
+
+import pytensor.tensor as pt
+
+from pytensor import Variable, config
+from pytensor.graph import Apply, Op
+from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable
+
+from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
+from pymc.logprob.abstract import logprob as logprob_logprob
+from pymc.logprob.utils import ignore_logprob
+
+
+class MinibatchRandomVariable(Op):
+    """RV whose logprob should be rescaled to match total_size"""
+
+    __props__ = ()
+    view_map = {0: [0]}
+
+    def make_node(self, rv, *total_size):
+        rv = as_tensor_variable(rv)
+        total_size = [
+            as_tensor_variable(t, dtype="int64", ndim=0) if t is not None else NoneConst
+            for t in total_size
+        ]
+        assert len(total_size) == rv.ndim
+        out = rv.type()
+        return Apply(self, [rv, *total_size], [out])
+
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = inputs[0]
+
+
+minibatch_rv = MinibatchRandomVariable()
+
+
+EllipsisType = Any  # EllipsisType is not present in Python 3.8 yet
+
+
+def create_minibatch_rv(
+    rv: TensorVariable,
+    total_size: Union[int, None, Sequence[Union[int, EllipsisType, None]]],
+) -> TensorVariable:
+    """Create variable whose logp is rescaled by total_size."""
+    if isinstance(total_size, int):
+        if rv.ndim <= 1:
+            total_size = [total_size]
+        else:
+            missing_ndims = rv.ndim - 1
+            total_size = [total_size] + [None] * missing_ndims
+    elif isinstance(total_size, (list, tuple)):
+        total_size = list(total_size)
+        if Ellipsis in total_size:
+            # Replace Ellipsis by None
+            if total_size.count(Ellipsis) > 1:
+                raise ValueError("Only one Ellipsis can be present in total_size")
+            sep = total_size.index(Ellipsis)
+            begin = total_size[:sep]
+            end = total_size[sep + 1 :]
+            missing_ndims = max((rv.ndim - len(begin) - len(end), 0))
+            total_size = begin + [None] * missing_ndims + end
+        if len(total_size) > rv.ndim:
+            raise ValueError(f"Length of total_size {total_size} is larger than RV ndim {rv.ndim}")
+    else:
+        raise TypeError(f"Invalid type for total_size: {total_size}")
+
+    rv = ignore_logprob(rv)
+
+    return cast(TensorVariable, minibatch_rv(rv, *total_size))
+
+
+def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> TensorVariable:
+    """Gets scaling constant for logp."""
+
+    # mypy doesn't understand we can convert a shape TensorVariable into a tuple
+    shape = tuple(shape)  # type: ignore
+
+    # Scalar RV
+    if len(shape) == 0:  # type: ignore
+        coef = total_size[0] if not NoneConst.equals(total_size[0]) else 1.0
+    else:
+        coefs = [t / shape[i] for i, t in enumerate(total_size) if not NoneConst.equals(t)]
+        coef = pt.prod(coefs)
+
+    return pt.cast(coef, dtype=config.floatX)
+
+
+MeasurableVariable.register(MinibatchRandomVariable)
+
+
+@_get_measurable_outputs.register(MinibatchRandomVariable)
+def _get_measurable_outputs_minibatch_random_variable(op, node):
+    return [node.outputs[0]]
+
+
+@_logprob.register(MinibatchRandomVariable)
+def minibatch_rv_logprob(op, values, *inputs, **kwargs):
+    [value] = values
+    rv, *total_size = inputs
+    return logprob_logprob(rv, value, **kwargs) * get_scaling(total_size, value.shape)
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index a2e371c1ab..7bc4325f5e 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -66,7 +66,6 @@
 from pymc.backends.ndarray import NDArray
 from pymc.blocking import DictToArrayBijection
 from pymc.initial_point import make_initial_point_fn
-from pymc.logprob.joint_logprob import _get_scaling
 from pymc.model import modelcontext
 from pymc.pytensorf import (
     SeedSequenceSeed,
@@ -82,6 +81,7 @@
     _get_seeds_per_chain,
     locally_cachedmethod,
 )
+from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling
 from pymc.variational.updates import adagrad_window
 from pymc.vartypes import discrete_types
 
@@ -1069,9 +1069,11 @@ def symbolic_normalizing_constant(self):
         t = self.to_flat_input(
             at.max(
                 [
-                    _get_scaling(self.model.rvs_to_total_sizes.get(v, None), v.shape, v.ndim)
+                    get_scaling(v.owner.inputs[1:], v.shape)
                     for v in self.group
+                    if isinstance(v.owner.op, MinibatchRandomVariable)
                 ]
+                + [1.0]  # To avoid empty max
             )
         )
         t = self.symbolic_single_sample(t)
@@ -1237,12 +1239,9 @@ def symbolic_normalizing_constant(self):
         t = at.max(
             self.collect("symbolic_normalizing_constant")
             + [
-                _get_scaling(
-                    self.model.rvs_to_total_sizes.get(obs, None),
-                    obs.shape,
-                    obs.ndim,
-                )
+                get_scaling(obs.owner.inputs[1:], obs.shape)
                 for obs in self.model.observed_RVs
+                if isinstance(obs.owner.op, MinibatchRandomVariable)
             ]
         )
         t = at.switch(self._scale_cost_to_minibatch, t, at.constant(1, dtype=t.dtype))
diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py
index 6b6d71fe6d..9d47a59e10 100644
--- a/tests/distributions/test_transform.py
+++ b/tests/distributions/test_transform.py
@@ -312,7 +312,6 @@ def check_transform_elementwise_logp(self, model):
                 (x,),
                 rvs_to_values={x: x_val_transf},
                 rvs_to_transforms={x: transform},
-                rvs_to_total_sizes={},
                 jacobian=False,
             )[0]
             .sum()
@@ -323,7 +322,6 @@ def check_transform_elementwise_logp(self, model):
                 (x,),
                 rvs_to_values={x: x_val_untransf},
                 rvs_to_transforms={},
-                rvs_to_total_sizes={},
             )[0]
             .sum()
             .eval({x_val_untransf: test_array_untransf})
@@ -362,7 +360,6 @@ def check_vectortransform_elementwise_logp(self, model):
                 (x,),
                 rvs_to_values={x: x_val_transf},
                 rvs_to_transforms={x: transform},
-                
rvs_to_total_sizes={}, jacobian=False, )[0] .sum() @@ -373,7 +370,6 @@ def check_vectortransform_elementwise_logp(self, model): (x,), rvs_to_values={x: x_val_untransf}, rvs_to_transforms={}, - rvs_to_total_sizes={}, )[0] .sum() .eval({x_val_untransf: test_array_untransf}) diff --git a/tests/distributions/util.py b/tests/distributions/util.py index 5586b4572a..9563d1e4f5 100644 --- a/tests/distributions/util.py +++ b/tests/distributions/util.py @@ -600,7 +600,6 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): (model["x"],), rvs_to_values={model["x"]: at.constant(moment)}, rvs_to_transforms={}, - rvs_to_total_sizes={}, )[0] .sum() .eval() diff --git a/tests/logprob/test_joint_logprob.py b/tests/logprob/test_joint_logprob.py index fb2fc5aeab..f0336821b7 100644 --- a/tests/logprob/test_joint_logprob.py +++ b/tests/logprob/test_joint_logprob.py @@ -56,11 +56,7 @@ import pymc as pm from pymc.logprob.abstract import logprob -from pymc.logprob.joint_logprob import ( - _get_scaling, - factorized_joint_logprob, - joint_logp, -) +from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logp from pymc.logprob.utils import rvs_to_value_vars, walk_model from tests.helpers import assert_no_rvs from tests.logprob.utils import joint_logprob @@ -281,52 +277,6 @@ def test_multiple_rvs_to_same_value_raises(): joint_logprob({x_rv1: x, x_rv2: x}) -def test_get_scaling(): - assert _get_scaling(None, (2, 3), 2).eval() == 1 - # ndim >=1 & ndim<1 - assert _get_scaling(45, (2, 3), 1).eval() == 22.5 - assert _get_scaling(45, (2, 3), 0).eval() == 45 - - # list or tuple tests - # total_size contains other than Ellipsis, None and Int - with pytest.raises(TypeError, match="Unrecognized `total_size` type"): - _get_scaling([2, 4, 5, 9, 11.5], (2, 3), 2) - # check with Ellipsis - with pytest.raises(ValueError, match="Double Ellipsis in `total_size` is restricted"): - _get_scaling([1, 2, 5, Ellipsis, Ellipsis], (2, 3), 2) - with pytest.raises( - ValueError, - match="Length of `total_size` is too big, number of scalings is bigger that ndim", - ): - _get_scaling([1, 2, 5, Ellipsis], (2, 3), 2) - - assert _get_scaling([Ellipsis], (2, 3), 2).eval() == 1 - - assert _get_scaling([4, 5, 9, Ellipsis, 32, 12], (2, 3, 2), 5).eval() == 960 - assert _get_scaling([4, 5, 9, Ellipsis], (2, 3, 2), 5).eval() == 15 - # total_size with no Ellipsis (end = [ ]) - with pytest.raises( - ValueError, - match="Length of `total_size` is too big, number of scalings is bigger that ndim", - ): - _get_scaling([1, 2, 5], (2, 3), 2) - - assert _get_scaling([], (2, 3), 2).eval() == 1 - assert _get_scaling((), (2, 3), 2).eval() == 1 - # total_size invalid type - with pytest.raises( - TypeError, - match="Unrecognized `total_size` type, expected int or list of ints, got {1, 2, 5}", - ): - _get_scaling({1, 2, 5}, (2, 3), 2) - - # test with rvar from model graph - with pm.Model() as m2: - rv_var = pm.Uniform("a", 0.0, 1.0) - total_size = [] - assert _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim).eval() == 1.0 - - def test_joint_logp_basic(): """Make sure we can compute a log-likelihood for a hierarchical model with transforms.""" @@ -348,7 +298,6 @@ def test_joint_logp_basic(): (b,), rvs_to_values=m.rvs_to_values, rvs_to_transforms=m.rvs_to_transforms, - rvs_to_total_sizes={}, ) # There shouldn't be any `RandomVariable`s in the resulting graph @@ -394,7 +343,6 @@ def test_joint_logp_incsubtensor(indices, size): (a_idx,), rvs_to_values={a_idx: a_value_var}, rvs_to_transforms={}, - rvs_to_total_sizes={}, ) 
logp_vals = a_idx_logp[0].eval({a_value_var: a_val}) diff --git a/tests/logprob/test_utils.py b/tests/logprob/test_utils.py index 7bb4780d9a..8fe398195c 100644 --- a/tests/logprob/test_utils.py +++ b/tests/logprob/test_utils.py @@ -233,7 +233,6 @@ def custom_logp(value, x): [x], rvs_to_values={x: value}, rvs_to_transforms={}, - rvs_to_total_sizes={}, ) with pm.Model(): @@ -248,7 +247,6 @@ def custom_logp(value, x): [y], rvs_to_values={y: y.type()}, rvs_to_transforms={}, - rvs_to_total_sizes={}, ) # The above warning should go away with ignore_logprob. @@ -261,5 +259,4 @@ def custom_logp(value, x): [y], rvs_to_values={y: y.type()}, rvs_to_transforms={}, - rvs_to_total_sizes={}, ) diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py index 54c238340f..644d2a83db 100644 --- a/tests/logprob/utils.py +++ b/tests/logprob/utils.py @@ -246,7 +246,6 @@ def logp(value, x): [y], rvs_to_values={y: y.type()}, rvs_to_transforms={}, - rvs_to_total_sizes={}, ) # The above warning should go away with ignore_logprob. @@ -259,5 +258,4 @@ def logp(value, x): [y], rvs_to_values={y: y.type()}, rvs_to_transforms={}, - rvs_to_total_sizes={}, ) diff --git a/tests/test_data.py b/tests/test_data.py index 7d5759ffd3..09d175de4b 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -20,7 +20,6 @@ import pytensor import pytensor.tensor as at import pytest -import scipy.stats as st from pytensor import shared from pytensor.tensor.var import TensorVariable @@ -29,7 +28,7 @@ from pymc.data import is_minibatch from pymc.pytensorf import GeneratorOp, floatX -from tests.helpers import SeededTest, select_by_precision +from tests.helpers import SeededTest class TestData(SeededTest): @@ -588,143 +587,6 @@ def gen2(): i += 1 -class TestScaling: - """ - Related to minibatch training - """ - - def test_density_scaling(self): - with pm.Model() as model1: - pm.Normal("n", observed=[[1]], total_size=1) - p1 = pytensor.function([], model1.logp()) - - with pm.Model() as model2: - pm.Normal("n", observed=[[1]], total_size=2) - p2 = pytensor.function([], model2.logp()) - assert p1() * 2 == p2() - - def test_density_scaling_with_generator(self): - # We have different size generators - - def true_dens(): - g = gen1() - for i, point in enumerate(g): - yield st.norm.logpdf(point).sum() * 10 - - t = true_dens() - # We have same size models - with pm.Model() as model1: - pm.Normal("n", observed=gen1(), total_size=100) - p1 = pytensor.function([], model1.logp()) - - with pm.Model() as model2: - gen_var = pm.generator(gen2()) - pm.Normal("n", observed=gen_var, total_size=100) - p2 = pytensor.function([], model2.logp()) - - for i in range(10): - _1, _2, _t = p1(), p2(), next(t) - decimals = select_by_precision(float64=7, float32=1) - np.testing.assert_almost_equal(_1, _t, decimal=decimals) # Value O(-50,000) - np.testing.assert_almost_equal(_1, _2) - # Done - - def test_gradient_with_scaling(self): - with pm.Model() as model1: - genvar = pm.generator(gen1()) - m = pm.Normal("m") - pm.Normal("n", observed=genvar, total_size=1000) - grad1 = model1.compile_fn(model1.dlogp(vars=m), point_fn=False) - with pm.Model() as model2: - m = pm.Normal("m") - shavar = pytensor.shared(np.ones((1000, 100))) - pm.Normal("n", observed=shavar) - grad2 = model2.compile_fn(model2.dlogp(vars=m), point_fn=False) - - for i in range(10): - shavar.set_value(np.ones((100, 100)) * i) - g1 = grad1(1) - g2 = grad2(1) - np.testing.assert_almost_equal(g1, g2) - - def test_multidim_scaling(self): - with pm.Model() as model0: - pm.Normal("n", observed=[[1, 1], [1, 
1]], total_size=[]) - p0 = pytensor.function([], model0.logp()) - - with pm.Model() as model1: - pm.Normal("n", observed=[[1, 1], [1, 1]], total_size=[2, 2]) - p1 = pytensor.function([], model1.logp()) - - with pm.Model() as model2: - pm.Normal("n", observed=[[1], [1]], total_size=[2, 2]) - p2 = pytensor.function([], model2.logp()) - - with pm.Model() as model3: - pm.Normal("n", observed=[[1, 1]], total_size=[2, 2]) - p3 = pytensor.function([], model3.logp()) - - with pm.Model() as model4: - pm.Normal("n", observed=[[1]], total_size=[2, 2]) - p4 = pytensor.function([], model4.logp()) - - with pm.Model() as model5: - pm.Normal("n", observed=[[1]], total_size=[2, Ellipsis, 2]) - p5 = pytensor.function([], model5.logp()) - _p0 = p0() - assert ( - np.allclose(_p0, p1()) - and np.allclose(_p0, p2()) - and np.allclose(_p0, p3()) - and np.allclose(_p0, p4()) - and np.allclose(_p0, p5()) - ) - - def test_common_errors(self): - with pytest.raises(ValueError) as e: - with pm.Model() as m: - pm.Normal("n", observed=[[1]], total_size=[2, Ellipsis, 2, 2]) - m.logp() - assert "Length of" in str(e.value) - with pytest.raises(ValueError) as e: - with pm.Model() as m: - pm.Normal("n", observed=[[1]], total_size=[2, 2, 2]) - m.logp() - assert "Length of" in str(e.value) - with pytest.raises(TypeError) as e: - with pm.Model() as m: - pm.Normal("n", observed=[[1]], total_size="foo") - m.logp() - assert "Unrecognized" in str(e.value) - with pytest.raises(TypeError) as e: - with pm.Model() as m: - pm.Normal("n", observed=[[1]], total_size=["foo"]) - m.logp() - assert "Unrecognized" in str(e.value) - with pytest.raises(ValueError) as e: - with pm.Model() as m: - pm.Normal("n", observed=[[1]], total_size=[Ellipsis, Ellipsis]) - m.logp() - assert "Double Ellipsis" in str(e.value) - - def test_mixed1(self): - with pm.Model(): - data = np.random.rand(10, 20) - mb = pm.Minibatch(data, batch_size=5) - v = pm.Normal("n", observed=mb, total_size=10) - assert pm.logp(v, 1) is not None, "Check index is allowed in graph" - - def test_free_rv(self): - with pm.Model() as model4: - pm.Normal("n", observed=[[1, 1], [1, 1]], total_size=[2, 2]) - p4 = model4.compile_fn(model4.logp(), point_fn=False) - - with pm.Model() as model5: - n = pm.Normal("n", total_size=[2, Ellipsis, 2], size=(2, 2)) - p5 = model5.compile_fn(model5.logp(), point_fn=False) - assert p4() == p5(pm.floatX([[1, 1], [1, 1]])) - - @pytest.mark.usefixtures("strict_float32") class TestMinibatch: data = np.random.rand(30, 10) diff --git a/tests/test_model.py b/tests/test_model.py index ba26e5f8c1..2589e23c4c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -46,6 +46,7 @@ from pymc.logprob.transforms import IntervalTransform from pymc.model import Point, ValueGradFunction, modelcontext from pymc.util import _FutureWarningValidatingScratchpad +from pymc.variational.minibatch_rv import MinibatchRandomVariable from tests.helpers import SeededTest from tests.models import simple_model @@ -503,7 +504,7 @@ def test_model_value_vars(): def test_model_var_maps(): with pm.Model() as model: a = pm.Uniform("a") - x = pm.Normal("x", a, total_size=5) + x = pm.Normal("x", a) assert set(model.rvs_to_values.keys()) == {a, x} a_value = model.rvs_to_values[a] @@ -516,10 +517,6 @@ def test_model_var_maps(): assert isinstance(model.rvs_to_transforms[a], IntervalTransform) assert model.rvs_to_transforms[x] is None - assert set(model.rvs_to_total_sizes.keys()) == {a, x} - assert model.rvs_to_total_sizes[a] is None - assert model.rvs_to_total_sizes[x] == 5 - def 
test_make_obs_var(): """ @@ -543,25 +540,27 @@ def test_make_obs_var(): # The function requires data and RV dimensionality to be compatible with pytest.raises(ShapeError, match="Dimensionality of data and RV don't match."): - fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None) + fake_model.make_obs_var(fake_distribution, np.ones((3, 3, 1)), None, None, None) # Check function behavior using the various inputs # dense, sparse: Ensure that the missing values are appropriately set to None # masked: a deterministic variable is returned - dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None) + dense_output = fake_model.make_obs_var(fake_distribution, dense_input, None, None, None) assert dense_output == fake_distribution assert isinstance(fake_model.rvs_to_values[dense_output], TensorConstant) del fake_model.named_vars[fake_distribution.name] - sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None) + sparse_output = fake_model.make_obs_var(fake_distribution, sparse_input, None, None, None) assert sparse_output == fake_distribution assert sparse.basic._is_sparse_variable(fake_model.rvs_to_values[sparse_output]) del fake_model.named_vars[fake_distribution.name] # Here the RandomVariable is split into observed/imputed and a Deterministic is returned with pytest.warns(ImputationWarning): - masked_output = fake_model.make_obs_var(fake_distribution, masked_array_input, None, None) + masked_output = fake_model.make_obs_var( + fake_distribution, masked_array_input, None, None, None + ) assert masked_output != fake_distribution assert not isinstance(masked_output, RandomVariable) # Ensure it has missing values @@ -569,6 +568,15 @@ def test_make_obs_var(): assert {"testing_inputs", "testing_inputs_observed"} == { v.name for v in fake_model.observed_RVs } + del fake_model.named_vars[fake_distribution.name] + + # Test that setting total_size returns a MinibatchRandomVariable + scaled_outputs = fake_model.make_obs_var( + fake_distribution, dense_input, None, None, total_size=100 + ) + assert scaled_outputs != fake_distribution + assert isinstance(scaled_outputs.owner.op, MinibatchRandomVariable) + del fake_model.named_vars[fake_distribution.name] def test_initial_point(): @@ -1436,7 +1444,6 @@ def test_missing_symmetric(self): [x_obs_rv, x_unobs_rv], rvs_to_values={x_obs_rv: x_obs_vv, x_unobs_rv: x_unobs_vv}, rvs_to_transforms={}, - rvs_to_total_sizes={}, ) logp_inputs = list(graph_inputs(logp)) assert x_obs_vv in logp_inputs @@ -1509,10 +1516,6 @@ def test_tag_future_warning_model(): with pytest.raises(AttributeError): x.tag.observations - with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"): - total_size = x.tag.total_size - assert total_size is None - # Cloning a node will keep the same tag type and contents y = x.owner.clone().default_output() assert y is not x @@ -1530,6 +1533,3 @@ def test_tag_future_warning_model(): assert y_value.eval() == 5 assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad) - with pytest.warns(FutureWarning, match="model.rvs_to_total_sizes"): - total_size = y.tag.total_size - assert total_size is None diff --git a/tests/variational/test_minibatch_rv.py b/tests/variational/test_minibatch_rv.py new file mode 100644 index 0000000000..7f0a1d4dc4 --- /dev/null +++ b/tests/variational/test_minibatch_rv.py @@ -0,0 +1,157 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance 
with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytensor +import pytest + +from scipy import stats as st + +import pymc as pm + +from pymc import Normal, draw +from pymc.variational.minibatch_rv import create_minibatch_rv +from tests.helpers import select_by_precision +from tests.test_data import gen1, gen2 + + +class TestMinibatchRandomVariable: + """ + Related to minibatch training + """ + + def test_density_scaling(self): + with pm.Model() as model1: + pm.Normal("n", observed=[[1]], total_size=1) + p1 = pytensor.function([], model1.logp()) + + with pm.Model() as model2: + pm.Normal("n", observed=[[1]], total_size=2) + p2 = pytensor.function([], model2.logp()) + assert p1() * 2 == p2() + + def test_density_scaling_with_generator(self): + # We have different size generators + + def true_dens(): + g = gen1() + for i, point in enumerate(g): + yield st.norm.logpdf(point).sum() * 10 + + t = true_dens() + # We have same size models + with pm.Model() as model1: + pm.Normal("n", observed=gen1(), total_size=100) + p1 = pytensor.function([], model1.logp()) + + with pm.Model() as model2: + gen_var = pm.generator(gen2()) + pm.Normal("n", observed=gen_var, total_size=100) + p2 = pytensor.function([], model2.logp()) + + for i in range(10): + _1, _2, _t = p1(), p2(), next(t) + decimals = select_by_precision(float64=7, float32=1) + np.testing.assert_almost_equal(_1, _t, decimal=decimals) # Value O(-50,000) + np.testing.assert_almost_equal(_1, _2) + # Done + + def test_gradient_with_scaling(self): + with pm.Model() as model1: + genvar = pm.generator(gen1()) + m = pm.Normal("m") + pm.Normal("n", observed=genvar, total_size=1000) + grad1 = model1.compile_fn(model1.dlogp(vars=m), point_fn=False) + with pm.Model() as model2: + m = pm.Normal("m") + shavar = pytensor.shared(np.ones((1000, 100))) + pm.Normal("n", observed=shavar) + grad2 = model2.compile_fn(model2.dlogp(vars=m), point_fn=False) + + for i in range(10): + shavar.set_value(np.ones((100, 100)) * i) + g1 = grad1(1) + g2 = grad2(1) + np.testing.assert_almost_equal(g1, g2) + + def test_multidim_scaling(self): + with pm.Model() as model0: + pm.Normal("n", observed=[[1, 1], [1, 1]], total_size=[]) + p0 = pytensor.function([], model0.logp()) + + with pm.Model() as model1: + pm.Normal("n", observed=[[1, 1], [1, 1]], total_size=[2, 2]) + p1 = pytensor.function([], model1.logp()) + + with pm.Model() as model2: + pm.Normal("n", observed=[[1], [1]], total_size=[2, 2]) + p2 = pytensor.function([], model2.logp()) + + with pm.Model() as model3: + pm.Normal("n", observed=[[1, 1]], total_size=[2, 2]) + p3 = pytensor.function([], model3.logp()) + + with pm.Model() as model4: + pm.Normal("n", observed=[[1]], total_size=[2, 2]) + p4 = pytensor.function([], model4.logp()) + + with pm.Model() as model5: + pm.Normal("n", observed=[[1]], total_size=[2, Ellipsis, 2]) + p5 = pytensor.function([], model5.logp()) + _p0 = p0() + assert ( + np.allclose(_p0, p1()) + and np.allclose(_p0, p2()) + and np.allclose(_p0, p3()) + and np.allclose(_p0, p4()) + and np.allclose(_p0, p5()) + ) + + def test_common_errors(self): + with pytest.raises(ValueError, match="Length of"): + 
with pm.Model() as m: + pm.Normal("n", observed=[[1]], total_size=[2, Ellipsis, 2, 2]) + m.logp() + with pytest.raises(ValueError, match="Length of"): + with pm.Model() as m: + pm.Normal("n", observed=[[1]], total_size=[2, 2, 2]) + m.logp() + with pytest.raises(TypeError, match="Invalid type for total_size"): + with pm.Model() as m: + pm.Normal("n", observed=[[1]], total_size="foo") + m.logp() + with pytest.raises(NotImplementedError, match="Cannot convert"): + with pm.Model() as m: + pm.Normal("n", observed=[[1]], total_size=["foo"]) + m.logp() + with pytest.raises(ValueError, match="Only one Ellipsis"): + with pm.Model() as m: + pm.Normal("n", observed=[[1]], total_size=[Ellipsis, Ellipsis]) + m.logp() + + with pm.Model() as model4: + with pytest.raises(ValueError, match="only be passed to observed RVs"): + pm.Normal("n", shape=(1, 1), total_size=[2, 2]) + + def test_mixed1(self): + with pm.Model(): + data = np.random.rand(10, 20) + mb = pm.Minibatch(data, batch_size=5) + v = pm.Normal("n", observed=mb, total_size=10) + assert pm.logp(v, 1) is not None, "Check index is allowed in graph" + + def test_random(self): + x = Normal.dist(size=(5,)) + mx = create_minibatch_rv(x, total_size=(10,)) + assert mx is not x + np.testing.assert_array_equal(draw(mx, random_seed=1), draw(x, random_seed=1)) diff --git a/tests/variational/test_opvi.py b/tests/variational/test_opvi.py index 12296fc211..af75a21a8f 100644 --- a/tests/variational/test_opvi.py +++ b/tests/variational/test_opvi.py @@ -48,7 +48,7 @@ def test_discrete_not_allowed(): @pytest.fixture(scope="module") def three_var_model(): with pm.Model() as model: - pm.HalfNormal("one", size=(10, 2), total_size=100) + pm.HalfNormal("one", size=(10, 2)) pm.Normal("two", size=(10,)) pm.Normal("three", size=(10, 1, 2)) return model
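
A minimal sketch of the lower-level API added in pymc/variational/minibatch_rv.py above, assuming create_minibatch_rv, pm.draw, and pm.logp behave as exercised by test_random and the logprob registration in this patch; the concrete sizes are illustrative only:

    import numpy as np

    import pymc as pm
    from pymc.variational.minibatch_rv import create_minibatch_rv

    x = pm.Normal.dist(size=(5,))                   # a batch of 5 values
    mx = create_minibatch_rv(x, total_size=(100,))  # the full dataset is declared to have 100 rows

    # Drawing is a pass-through: MinibatchRandomVariable simply views its input RV.
    np.testing.assert_array_equal(pm.draw(mx, random_seed=1), pm.draw(x, random_seed=1))

    # The logp, however, is rescaled by total_size / batch_size = 100 / 5 = 20.
    value = np.zeros(5)
    scaled = pm.logp(mx, value).sum().eval()
    unscaled = pm.logp(x, value).sum().eval()
    np.testing.assert_allclose(scaled, 20 * unscaled)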
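
total_size may also be given per dimension, with None meaning no scaling for that dimension and a single Ellipsis padding any remaining dimensions (see create_minibatch_rv above). A small sketch along the lines of test_multidim_scaling; the shapes are illustrative only:

    import numpy as np
    import pytensor

    import pymc as pm

    with pm.Model() as full:
        pm.Normal("n", observed=np.ones((2, 2)))  # the complete 2x2 dataset
    with pm.Model() as batched:
        # A single cell, with per-dimension total_size; Ellipsis fills in dimensions
        # for which no scaling entry is given.
        pm.Normal("n", observed=np.ones((1, 1)), total_size=[2, Ellipsis, 2])

    p_full = pytensor.function([], full.logp())
    p_batched = pytensor.function([], batched.logp())
    np.testing.assert_allclose(p_full(), p_batched())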