diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9f72ba4544..ff90fc3b3c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: - types-filelock - types-setuptools - arviz - - aesara==2.7.5 + - aesara==2.7.7 - aeppl==0.0.32 always_run: true require_serial: true diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 7ee5eda730..4bd145aace 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -6,7 +6,7 @@ channels: dependencies: # Base dependencies - aeppl=0.0.32 -- aesara=2.7.5 +- aesara=2.7.7 - arviz>=0.12.0 - blas - cachetools>=4.2.1 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index c400e2c3c1..ccac4652f9 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -6,7 +6,7 @@ channels: dependencies: # Base dependencies - aeppl=0.0.32 -- aesara=2.7.5 +- aesara=2.7.7 - arviz>=0.12.0 - blas - cachetools>=4.2.1 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 6477a25caa..198321e479 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -6,7 +6,7 @@ channels: dependencies: # Base dependencies (see install guide for Windows) - aeppl=0.0.32 -- aesara=2.7.5 +- aesara=2.7.7 - arviz>=0.12.0 - blas - cachetools>=4.2.1 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 6cc829a3ad..a373ec6073 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -6,7 +6,7 @@ channels: dependencies: # Base dependencies (see install guide for Windows) - aeppl=0.0.32 -- aesara=2.7.5 +- aesara=2.7.7 - arviz>=0.12.0 - blas - cachetools>=4.2.1 diff --git a/pymc/data.py b/pymc/data.py index bda9d777c8..a04d5c30ae 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -310,6 +310,7 @@ def __init__( batch_size=128, dtype=None, broadcastable=None, + shape=None, name="Minibatch", random_seed=42, update_shared_f=None, @@ -324,9 +325,15 @@ def __init__( self.update_shared_f = update_shared_f self.random_slc = self.make_random_slices(self.shared.shape, batch_size, random_seed) minibatch = self.shared[self.random_slc] - if broadcastable is None: - broadcastable = (False,) * minibatch.ndim - minibatch = at.patternbroadcast(minibatch, broadcastable) + if broadcastable is not None: + warnings.warn( + "Minibatch `broadcastable` argument is deprecated. Use `shape` instead", + FutureWarning, + ) + assert shape is None + shape = [1 if b else None for b in broadcastable] + if shape is not None: + minibatch = at.specify_shape(minibatch, shape) self.minibatch = minibatch super().__init__(self.minibatch.type, None, None, name=name) Apply(aesara.compile.view_op, inputs=[self.minibatch], outputs=[self]) diff --git a/pymc/ode/ode.py b/pymc/ode/ode.py index 6dfd927234..f400cb7c00 100644 --- a/pymc/ode/ode.py +++ b/pymc/ode/ode.py @@ -158,11 +158,11 @@ def __call__(self, y0, theta, return_sens=False, **kwargs): ) # convert inputs to tensors (and check their types) - y0 = at.cast(at.unbroadcast(at.as_tensor_variable(y0), 0), floatX) - theta = at.cast(at.unbroadcast(at.as_tensor_variable(theta), 0), floatX) + y0 = at.cast(at.as_tensor_variable(y0), floatX) + theta = at.cast(at.as_tensor_variable(theta), floatX) inputs = [y0, theta] for i, (input_val, itype) in enumerate(zip(inputs, self._itypes)): - if not input_val.type.in_same_class(itype): + if not itype.is_super(input_val.type): raise ValueError( f"Input {i} of type {input_val.type} does not have the expected type of {itype}" ) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 10af05d8fa..0e6cc7f65b 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1030,7 +1030,7 @@ def symbolic_sample_over_posterior(self, node): """ node = self.to_flat_input(node) random = self.symbolic_random.astype(self.symbolic_initial.dtype) - random = at.patternbroadcast(random, self.symbolic_initial.broadcastable) + random = at.specify_shape(random, self.symbolic_initial.type.shape) def sample(post, node): return aesara.clone_replace(node, {self.input: post}) @@ -1065,7 +1065,7 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None) dict with replacements for initial """ initial = self._new_initial(s, d, more_replacements) - initial = at.patternbroadcast(initial, self.symbolic_initial.broadcastable) + initial = at.specify_shape(initial, self.symbolic_initial.type.shape) if more_replacements: initial = aesara.clone_replace(initial, more_replacements) return {self.symbolic_initial: initial} diff --git a/requirements-dev.txt b/requirements-dev.txt index cd0c647a91..b488309618 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,7 @@ # See that file for comments about the need/usage of each dependency. aeppl==0.0.32 -aesara==2.7.5 +aesara==2.7.7 arviz>=0.12.0 cachetools>=4.2.1 cloudpickle diff --git a/requirements.txt b/requirements.txt index 6d52a9c242..479ca4449e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ aeppl==0.0.32 -aesara==2.7.5 +aesara==2.7.7 arviz>=0.12.0 cachetools>=4.2.1 cloudpickle