-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Closed
Description
Description of your problem
It seems the Dirichlet distribution does not work in the current beta (4.0.0b4), although it is expected to work. Sampling even the minimal model below fails:
import numpy as np
import pymc as pm
with pm.Model():
a = pm.Dirichlet('a', np.ones(3))
pm.sample()
Complete error traceback
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:964, in Function.__call__(self, *args, **kwargs)
962 try:
963 outputs = (
--> 964 self.fn()
965 if output_subset is None
966 else self.fn(output_subset=output_subset)
967 )
968 except Exception:
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
518 @is_thunk_type
519 def rval(
520 p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
521 ):
--> 522 r = p(n, [x[0] for x in i], o)
523 for o in node.outputs:
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py:48, in TransformedVariable.perform(self, node, inputs, outputs)
47 def perform(self, node, inputs, outputs):
---> 48 raise NotImplementedError(
49 "These `Op`s should be removed from graphs used for computation."
50 )
NotImplementedError: These `Op`s should be removed from graphs used for computation.
During handling of the above exception, another exception occurred:
NotImplementedError Traceback (most recent call last)
Input In [9], in <cell line: 4>()
4 with pm.Model() as model:
5 a = pm.Dirichlet('a', np.ones(3))
----> 6 pm.sample()
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/sampling.py:487, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
485 # One final check that shapes and logps at the starting points are okay.
486 for ip in initial_points:
--> 487 model.check_start_vals(ip)
488 _check_start_shape(model, ip)
490 sample_args = {
491 "draws": draws,
492 "step": step,
(...)
503 "discard_tuned_samples": discard_tuned_samples,
504 }
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1680, in Model.check_start_vals(self, start)
1674 valid_keys = ", ".join(self.named_vars.keys())
1675 raise KeyError(
1676 "Some start parameters do not appear in the model!\n"
1677 f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
1678 )
-> 1680 initial_eval = self.point_logps(point=elem)
1682 if not all(np.isfinite(v) for v in initial_eval.values()):
1683 raise SamplingError(
1684 "Initial evaluation of model at starting point failed!\n"
1685 f"Starting values:\n{elem}\n\n"
1686 f"Initial evaluation results:\n{initial_eval}"
1687 )
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1721, in Model.point_logps(self, point, round_vals)
1715 factors = self.basic_RVs + self.potentials
1716 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
1717 return {
1718 factor.name: np.round(np.asarray(factor_logp), round_vals)
1719 for factor, factor_logp in zip(
1720 factors,
-> 1721 self.compile_fn(factor_logps_fn)(point),
1722 )
1723 }
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1820, in PointFunc.__call__(self, state)
1819 def __call__(self, state):
-> 1820 return self.f(**state)
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:977, in Function.__call__(self, *args, **kwargs)
975 if hasattr(self.fn, "thunks"):
976 thunk = self.fn.thunks[self.fn.position_of_error]
--> 977 raise_with_op(
978 self.maker.fgraph,
979 node=self.fn.nodes[self.fn.position_of_error],
980 thunk=thunk,
981 storage_map=getattr(self.fn, "storage_map", None),
982 )
983 else:
984 # old-style linkers raise their own exceptions
985 raise
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/link/utils.py:538, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
533 warnings.warn(
534 f"{exc_type} error does not allow us to add an extra error message"
535 )
536 # Some exception need extra parameter in inputs. So forget the
537 # extra long error message in that case.
--> 538 raise exc_value.with_traceback(exc_trace)
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:964, in Function.__call__(self, *args, **kwargs)
961 t0_fn = time.time()
962 try:
963 outputs = (
--> 964 self.fn()
965 if output_subset is None
966 else self.fn(output_subset=output_subset)
967 )
968 except Exception:
969 restore_defaults()
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
518 @is_thunk_type
519 def rval(
520 p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
521 ):
--> 522 r = p(n, [x[0] for x in i], o)
523 for o in node.outputs:
524 compute_map[o][0] = True
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py:48, in TransformedVariable.perform(self, node, inputs, outputs)
47 def perform(self, node, inputs, outputs):
---> 48 raise NotImplementedError(
49 "These `Op`s should be removed from graphs used for computation."
50 )
NotImplementedError: These `Op`s should be removed from graphs used for computation.
Apply node that caused the error: TransformedVariable(Softmax{axis=0}.0, a_simplex__)
Toposort index: 20
Inputs types: [TensorType(float64, (None,)), TensorType(float64, (None,))]
Inputs shapes: [(3,), (2,)]
Inputs strides: [(8,), (8,)]
Inputs values: [array([0.33333333, 0.33333333, 0.33333333]), array([0., 0.])]
Outputs clients: [[Elemwise{eq,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0}), Elemwise{gt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 1}), Elemwise{lt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0})]]
Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 203, in apply
return self.default_transform_opt.optimize(fgraph)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 103, in optimize
ret = self.apply(fgraph, *args, **kwargs)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1960, in apply
nb += self.process_node(fgraph, node)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1850, in process_node
replacements = lopt.transform(fgraph, node)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1055, in transform
return self.fn(fgraph, node)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 148, in transform_values
new_value_var = transformed_variable(
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py", line 294, in __call__
node = self.make_node(*inputs, **kwargs)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 45, in make_node
return Apply(self, [tran_value, value], [tran_value.type()])
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
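The error above seems to indicate that Softmax is applied to the transformed value of the Dirichlet variable. However, the transform currently used is `aeppl.transforms.Simplex`, which does not explicitly use the Softmax function (the referenced source snippet was not included in this report).
Note also that the node that actually fails is `TransformedVariable(Softmax{axis=0}.0, a_simplex__)`. Its `perform` method raises unconditionally, because, per the error message, these `Op`s "should be removed from graphs used for computation". Softmax is only that node's first input. The real problem therefore appears to be that this `TransformedVariable` node was left in the graph that `Model.point_logps` compiles during `check_start_vals`.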
Versions and main components
- PyMC/PyMC3 Version: 4.0.0b4
- Aesara/Theano Version: 2.5.1
- aePPL Version: 0.0.27
- Python Version: 3.9.10
- Operating system: Ubuntu 18.04.5 LTS
- How did you install PyMC/PyMC3: conda