Skip to content

Commit 8360f86

Browse files
authored
Merge branch 'pymc-devs:main' into aesaraf_join_shared_input_doc
2 parents a50be1f + e57d1d7 commit 8360f86

File tree

5 files changed

+69
-198
lines changed

5 files changed

+69
-198
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ repos:
1919
- id: isort
2020
name: isort
2121
- repo: https://github.com/asottile/pyupgrade
22-
rev: v3.1.0
22+
rev: v3.2.0
2323
hooks:
2424
- id: pyupgrade
2525
args: [--py37-plus]

pymc/printing.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def str_for_dist(
5858

5959
if "latex" in formatting:
6060
if print_name is not None:
61-
print_name = r"\text{" + _latex_escape(dist.name) + "}"
61+
print_name = r"\text{" + _latex_escape(dist.name.strip("$")) + "}"
6262

6363
op_name = (
6464
dist.owner.op._print_name[1]
@@ -67,9 +67,11 @@ def str_for_dist(
6767
)
6868
if include_params:
6969
if print_name:
70-
return r"${} \sim {}({})$".format(print_name, op_name, ",~".join(dist_args))
70+
return r"${} \sim {}({})$".format(
71+
print_name, op_name, ",~".join([d.strip("$") for d in dist_args])
72+
)
7173
else:
72-
return r"${}({})$".format(op_name, ",~".join(dist_args))
74+
return r"${}({})$".format(op_name, ",~".join([d.strip("$") for d in dist_args]))
7375

7476
else:
7577
if print_name:
@@ -138,7 +140,7 @@ def str_for_potential_or_deterministic(
138140
LaTeX or plain, optionally with distribution parameter values included."""
139141
print_name = var.name if var.name is not None else "<unnamed>"
140142
if "latex" in formatting:
141-
print_name = r"\text{" + _latex_escape(print_name) + "}"
143+
print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}"
142144
if include_params:
143145
return rf"${print_name} \sim \operatorname{{{dist_name}}}({_str_for_expression(var, formatting=formatting)})$"
144146
else:
@@ -182,7 +184,7 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
182184
else str_for_dist(var, formatting=formatting, include_params=True)
183185
)
184186
if "latex" in formatting:
185-
return r"\text{" + _latex_escape(_str) + "}"
187+
return _latex_text_format(_latex_escape(_str.strip("$")))
186188
else:
187189
return _str
188190

@@ -215,9 +217,20 @@ def _expand(x):
215217
names = [x.name for x in parents]
216218

217219
if "latex" in formatting:
218-
return r"f(" + ",~".join([r"\text{" + _latex_escape(n) + "}" for n in names]) + ")"
220+
return (
221+
r"f("
222+
+ ",~".join([_latex_text_format(_latex_escape(n.strip("$"))) for n in names])
223+
+ ")"
224+
)
225+
else:
226+
return r"f(" + ", ".join([n.strip("$") for n in names]) + ")"
227+
228+
229+
def _latex_text_format(text: str) -> str:
230+
if r"\operatorname{" in text:
231+
return text
219232
else:
220-
return r"f(" + ", ".join(names) + ")"
233+
return r"\text{" + text + "}"
221234

222235

223236
def _latex_escape(text: str) -> str:

pymc/sampling.py

Lines changed: 4 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,127 +2065,10 @@ def sample_posterior_predictive_w(
20652065
weighted models (default), or a dictionary with variable names as keys, and samples as
20662066
numpy arrays.
20672067
"""
2068-
raise NotImplementedError(f"sample_posterior_predictive_w has not yet been ported to PyMC 4.0.")
2069-
2070-
if isinstance(traces[0], InferenceData):
2071-
n_samples = [
2072-
trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces
2073-
]
2074-
traces = [dataset_to_point_list(trace.posterior) for trace in traces]
2075-
elif isinstance(traces[0], xarray.Dataset):
2076-
n_samples = [trace.sizes["chain"] * trace.sizes["draw"] for trace in traces]
2077-
traces = [dataset_to_point_list(trace) for trace in traces]
2078-
else:
2079-
n_samples = [len(i) * i.nchains for i in traces]
2080-
2081-
if models is None:
2082-
models = [modelcontext(models)] * len(traces)
2083-
2084-
if random_seed is not None:
2085-
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
2086-
2087-
for model in models:
2088-
if model.potentials:
2089-
warnings.warn(
2090-
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
2091-
"This is likely to lead to invalid or biased predictive samples.",
2092-
UserWarning,
2093-
stacklevel=2,
2094-
)
2095-
break
2096-
2097-
if weights is None:
2098-
weights = [1] * len(traces)
2099-
2100-
if len(traces) != len(weights):
2101-
raise ValueError("The number of traces and weights should be the same")
2102-
2103-
if len(models) != len(weights):
2104-
raise ValueError("The number of models and weights should be the same")
2105-
2106-
length_morv = len(models[0].observed_RVs)
2107-
if any(len(i.observed_RVs) != length_morv for i in models):
2108-
raise ValueError("The number of observed RVs should be the same for all models")
2109-
2110-
weights = np.asarray(weights)
2111-
p = weights / np.sum(weights)
2112-
2113-
min_tr = min(n_samples)
2114-
2115-
n = (min_tr * p).astype("int")
2116-
# ensure n sum up to min_tr
2117-
idx = np.argmax(n)
2118-
n[idx] = n[idx] + min_tr - np.sum(n)
2119-
trace = []
2120-
for i, j in enumerate(n):
2121-
tr = traces[i]
2122-
len_trace = len(tr)
2123-
try:
2124-
nchain = tr.nchains
2125-
except AttributeError:
2126-
nchain = 1
2127-
2128-
indices = np.random.randint(0, nchain * len_trace, j)
2129-
if nchain > 1:
2130-
chain_idx, point_idx = np.divmod(indices, len_trace)
2131-
for cidx, pidx in zip(chain_idx, point_idx):
2132-
trace.append(tr._straces[cidx].point(pidx))
2133-
else:
2134-
for idx in indices:
2135-
trace.append(tr[idx])
2136-
2137-
obs = [x for m in models for x in m.observed_RVs]
2138-
variables = np.repeat(obs, n)
2139-
2140-
lengths = list({np.atleast_1d(observed).shape for observed in obs})
2141-
2142-
size: List[Optional[Tuple[int, ...]]] = []
2143-
if len(lengths) == 1:
2144-
size = [None] * len(variables)
2145-
elif len(lengths) > 2:
2146-
raise ValueError("Observed variables could not be broadcast together")
2147-
else:
2148-
x = np.zeros(shape=lengths[0])
2149-
y = np.zeros(shape=lengths[1])
2150-
b = np.broadcast(x, y)
2151-
for var in variables:
2152-
# XXX: This needs to be refactored
2153-
shape = None # np.shape(np.atleast_1d(var.distribution.default()))
2154-
if shape != b.shape:
2155-
size.append(b.shape)
2156-
else:
2157-
size.append(None)
2158-
len_trace = len(trace)
2159-
2160-
if samples is None:
2161-
samples = len_trace
2162-
2163-
indices = np.random.randint(0, len_trace, samples)
2164-
2165-
if progressbar:
2166-
indices = progress_bar(indices, total=samples, display=progressbar)
2167-
2168-
try:
2169-
ppcl: Dict[str, list] = defaultdict(list)
2170-
for idx in indices:
2171-
param = trace[idx]
2172-
var = variables[idx]
2173-
# TODO sample_posterior_predictive_w is currently only work for model with
2174-
# one observed.
2175-
# XXX: This needs to be refactored
2176-
# ppc[var.name].append(draw_values([var], point=param, size=size[idx])[0])
2177-
raise NotImplementedError()
2178-
2179-
except KeyboardInterrupt:
2180-
pass
2181-
else:
2182-
ppcd = {k: np.asarray(v) for k, v in ppcl.items()}
2183-
if not return_inferencedata:
2184-
return ppcd
2185-
ikwargs: Dict[str, Any] = dict(model=models)
2186-
if idata_kwargs:
2187-
ikwargs.update(idata_kwargs)
2188-
return pm.to_inference_data(posterior_predictive=ppcd, **ikwargs)
2068+
raise FutureWarning(
2069+
"The function `sample_posterior_predictive_w` has been removed in PyMC 4.3.0. "
2070+
"Switch to `arviz.stats.weight_predictions`"
2071+
)
21892072

21902073

21912074
def sample_prior_predictive(

pymc/tests/test_printing.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
from pymc import Bernoulli, Censored, Mixture
3+
from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT
44
from pymc.aesaraf import floatX
55
from pymc.distributions import (
66
Dirichlet,
@@ -130,12 +130,12 @@ def setup_class(self):
130130
r"$\text{beta} \sim \operatorname{N}(0,~10)$",
131131
r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
132132
r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
133-
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\text{\$\operatorname{DiracDelta}(0)\$},~\text{\$\operatorname{Pois}(5)\$})$",
133+
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5))$",
134134
r"$\text{w} \sim \operatorname{Dir}(\text{<constant>})$",
135135
(
136136
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w},"
137-
r"~\text{\$\operatorname{MarginalMixture}(f(),~\text{\\$\operatorname{DiracDelta}(0)\\$},~\text{\\$\operatorname{Pois}(5)\\$})\$},"
138-
r"~\text{\$\operatorname{Censored}(\text{\\$\operatorname{Bern}(0.5)\\$},~-1,~1)\$})$"
137+
r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5)),"
138+
r"~\operatorname{Censored}(\operatorname{Bern}(0.5),~-1,~1))$"
139139
),
140140
r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
141141
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
@@ -178,3 +178,43 @@ def test_str_repr(self):
178178
assert segment in model_text
179179
else:
180180
assert text in model_text
181+
182+
183+
def test_model_latex_repr_three_levels_model():
184+
with Model() as censored_model:
185+
mu = Normal("mu", 0.0, 5.0)
186+
sigma = HalfCauchy("sigma", 2.5)
187+
normal_dist = Normal.dist(mu=mu, sigma=sigma)
188+
censored_normal = Censored(
189+
"censored_normal", normal_dist, lower=-2.0, upper=2.0, observed=[1, 0, 0.5]
190+
)
191+
192+
latex_repr = censored_model.str_repr(formatting="latex")
193+
expected = [
194+
"$$",
195+
"\\begin{array}{rcl}",
196+
"\\text{mu} &\\sim & \\operatorname{N}(0,~5)\\\\\\text{sigma} &\\sim & "
197+
"\\operatorname{C^{+}}(0,~2.5)\\\\\\text{censored_normal} &\\sim & "
198+
"\\operatorname{Censored}(\\operatorname{N}(\\text{mu},~\\text{sigma}),~-2,~2)",
199+
"\\end{array}",
200+
"$$",
201+
]
202+
assert [line.strip() for line in latex_repr.split("\n")] == expected
203+
204+
205+
def test_model_latex_repr_mixture_model():
206+
with Model() as mix_model:
207+
w = Dirichlet("w", [1, 1])
208+
mix = Mixture("mix", w=w, comp_dists=[Normal.dist(0.0, 5.0), StudentT.dist(7.0)])
209+
210+
latex_repr = mix_model.str_repr(formatting="latex")
211+
expected = [
212+
"$$",
213+
"\\begin{array}{rcl}",
214+
"\\text{w} &\\sim & "
215+
"\\operatorname{Dir}(\\text{<constant>})\\\\\\text{mix} &\\sim & "
216+
"\\operatorname{MarginalMixture}(\\text{w},~\\operatorname{N}(0,~5),~\\operatorname{StudentT}(7,~0,~1))",
217+
"\\end{array}",
218+
"$$",
219+
]
220+
assert [line.strip() for line in latex_repr.split("\n")] == expected

pymc/tests/test_sampling.py

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,71 +1177,6 @@ def test_logging_sampled_basic_rvs_posterior_mutable(self, mock_sample_results,
11771177
caplog.clear()
11781178

11791179

1180-
@pytest.mark.xfail(
1181-
reason="sample_posterior_predictive_w not refactored for v4", raises=NotImplementedError
1182-
)
1183-
class TestSamplePPCW(SeededTest):
1184-
def test_sample_posterior_predictive_w(self):
1185-
data0 = np.random.normal(0, 1, size=50)
1186-
warning_msg = "The number of samples is too small to check convergence reliably"
1187-
1188-
with pm.Model() as model_0:
1189-
mu = pm.Normal("mu", mu=0, sigma=1)
1190-
y = pm.Normal("y", mu=mu, sigma=1, observed=data0)
1191-
with pytest.warns(UserWarning, match=warning_msg):
1192-
trace_0 = pm.sample(10, tune=0, chains=2, return_inferencedata=False)
1193-
idata_0 = pm.to_inference_data(trace_0, log_likelihood=False)
1194-
1195-
with pm.Model() as model_1:
1196-
mu = pm.Normal("mu", mu=0, sigma=1, size=len(data0))
1197-
y = pm.Normal("y", mu=mu, sigma=1, observed=data0)
1198-
with pytest.warns(UserWarning, match=warning_msg):
1199-
trace_1 = pm.sample(10, tune=0, chains=2, return_inferencedata=False)
1200-
idata_1 = pm.to_inference_data(trace_1, log_likelihood=False)
1201-
1202-
with pm.Model() as model_2:
1203-
# Model with no observed RVs.
1204-
mu = pm.Normal("mu", mu=0, sigma=1)
1205-
with pytest.warns(UserWarning, match=warning_msg):
1206-
trace_2 = pm.sample(10, tune=0, return_inferencedata=False)
1207-
1208-
traces = [trace_0, trace_1]
1209-
idatas = [idata_0, idata_1]
1210-
models = [model_0, model_1]
1211-
1212-
ppc = pm.sample_posterior_predictive_w(traces, 100, models)
1213-
assert ppc["y"].shape == (100, 50)
1214-
1215-
ppc = pm.sample_posterior_predictive_w(idatas, 100, models)
1216-
assert ppc["y"].shape == (100, 50)
1217-
1218-
with model_0:
1219-
ppc = pm.sample_posterior_predictive_w([idata_0.posterior], None)
1220-
assert ppc["y"].shape == (20, 50)
1221-
1222-
with pytest.raises(ValueError, match="The number of traces and weights should be the same"):
1223-
pm.sample_posterior_predictive_w([idata_0.posterior], 100, models, weights=[0.5, 0.5])
1224-
1225-
with pytest.raises(ValueError, match="The number of models and weights should be the same"):
1226-
pm.sample_posterior_predictive_w([idata_0.posterior], 100, models)
1227-
1228-
with pytest.raises(
1229-
ValueError, match="The number of observed RVs should be the same for all models"
1230-
):
1231-
pm.sample_posterior_predictive_w([trace_0, trace_2], 100, [model_0, model_2])
1232-
1233-
def test_potentials_warning(self):
1234-
warning_msg = "The effect of Potentials on other parameters is ignored during"
1235-
with pm.Model() as m:
1236-
a = pm.Normal("a", 0, 1)
1237-
p = pm.Potential("p", a + 1)
1238-
obs = pm.Normal("obs", a, 1, observed=5)
1239-
1240-
trace = az_from_dict({"a": np.random.rand(10)})
1241-
with pytest.warns(UserWarning, match=warning_msg):
1242-
pm.sample_posterior_predictive_w(samples=5, traces=[trace, trace], models=[m, m])
1243-
1244-
12451180
def check_exec_nuts_init(method):
12461181
with pm.Model() as model:
12471182
pm.Normal("a", mu=0, sigma=1, size=2)

0 commit comments

Comments
 (0)