Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
- types-filelock
- types-setuptools
- arviz
- aesara==2.7.5
- aesara==2.7.7
- aeppl==0.0.32
always_run: true
require_serial: true
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ channels:
dependencies:
# Base dependencies
- aeppl=0.0.32
- aesara=2.7.5
- aesara=2.7.7
- arviz>=0.12.0
- blas
- cachetools>=4.2.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ channels:
dependencies:
# Base dependencies
- aeppl=0.0.32
- aesara=2.7.5
- aesara=2.7.7
- arviz>=0.12.0
- blas
- cachetools>=4.2.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ channels:
dependencies:
# Base dependencies (see install guide for Windows)
- aeppl=0.0.32
- aesara=2.7.5
- aesara=2.7.7
- arviz>=0.12.0
- blas
- cachetools>=4.2.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ channels:
dependencies:
# Base dependencies (see install guide for Windows)
- aeppl=0.0.32
- aesara=2.7.5
- aesara=2.7.7
- arviz>=0.12.0
- blas
- cachetools>=4.2.1
Expand Down
13 changes: 10 additions & 3 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def __init__(
batch_size=128,
dtype=None,
broadcastable=None,
shape=None,
name="Minibatch",
random_seed=42,
update_shared_f=None,
Expand All @@ -324,9 +325,15 @@ def __init__(
self.update_shared_f = update_shared_f
self.random_slc = self.make_random_slices(self.shared.shape, batch_size, random_seed)
minibatch = self.shared[self.random_slc]
if broadcastable is None:
broadcastable = (False,) * minibatch.ndim
minibatch = at.patternbroadcast(minibatch, broadcastable)
if broadcastable is not None:
warnings.warn(
"Minibatch `broadcastable` argument is deprecated. Use `shape` instead",
FutureWarning,
)
assert shape is None
shape = [1 if b else None for b in broadcastable]
if shape is not None:
minibatch = at.specify_shape(minibatch, shape)
self.minibatch = minibatch
super().__init__(self.minibatch.type, None, None, name=name)
Apply(aesara.compile.view_op, inputs=[self.minibatch], outputs=[self])
Expand Down
6 changes: 3 additions & 3 deletions pymc/ode/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ def __call__(self, y0, theta, return_sens=False, **kwargs):
)

# convert inputs to tensors (and check their types)
y0 = at.cast(at.unbroadcast(at.as_tensor_variable(y0), 0), floatX)
theta = at.cast(at.unbroadcast(at.as_tensor_variable(theta), 0), floatX)
y0 = at.cast(at.as_tensor_variable(y0), floatX)
theta = at.cast(at.as_tensor_variable(theta), floatX)
inputs = [y0, theta]
for i, (input_val, itype) in enumerate(zip(inputs, self._itypes)):
if not input_val.type.in_same_class(itype):
if not itype.is_super(input_val.type):
raise ValueError(
f"Input {i} of type {input_val.type} does not have the expected type of {itype}"
)
Expand Down
4 changes: 2 additions & 2 deletions pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ def symbolic_sample_over_posterior(self, node):
"""
node = self.to_flat_input(node)
random = self.symbolic_random.astype(self.symbolic_initial.dtype)
random = at.patternbroadcast(random, self.symbolic_initial.broadcastable)
random = at.specify_shape(random, self.symbolic_initial.type.shape)

def sample(post, node):
return aesara.clone_replace(node, {self.input: post})
Expand Down Expand Up @@ -1065,7 +1065,7 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None)
dict with replacements for initial
"""
initial = self._new_initial(s, d, more_replacements)
initial = at.patternbroadcast(initial, self.symbolic_initial.broadcastable)
initial = at.specify_shape(initial, self.symbolic_initial.type.shape)
if more_replacements:
initial = aesara.clone_replace(initial, more_replacements)
return {self.symbolic_initial: initial}
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See that file for comments about the need/usage of each dependency.

aeppl==0.0.32
aesara==2.7.5
aesara==2.7.7
arviz>=0.12.0
cachetools>=4.2.1
cloudpickle
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
aeppl==0.0.32
aesara==2.7.5
aesara==2.7.7
arviz>=0.12.0
cachetools>=4.2.1
cloudpickle
Expand Down