Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def str_for_dist(

if "latex" in formatting:
if print_name is not None:
print_name = r"\text{" + _latex_escape(dist.name) + "}"
print_name = r"\text{" + _latex_escape(dist.name.strip("$")) + "}"

op_name = (
dist.owner.op._print_name[1]
Expand All @@ -67,9 +67,11 @@ def str_for_dist(
)
if include_params:
if print_name:
return r"${} \sim {}({})$".format(print_name, op_name, ",~".join(dist_args))
return r"${} \sim {}({})$".format(
print_name, op_name, ",~".join([d.strip("$") for d in dist_args])
)
else:
return r"${}({})$".format(op_name, ",~".join(dist_args))
return r"${}({})$".format(op_name, ",~".join([d.strip("$") for d in dist_args]))

else:
if print_name:
Expand Down Expand Up @@ -138,7 +140,7 @@ def str_for_potential_or_deterministic(
LaTeX or plain, optionally with distribution parameter values included."""
print_name = var.name if var.name is not None else "<unnamed>"
if "latex" in formatting:
print_name = r"\text{" + _latex_escape(print_name) + "}"
print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}"
if include_params:
return rf"${print_name} \sim \operatorname{{{dist_name}}}({_str_for_expression(var, formatting=formatting)})$"
else:
Expand Down Expand Up @@ -182,7 +184,7 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
else str_for_dist(var, formatting=formatting, include_params=True)
)
if "latex" in formatting:
return r"\text{" + _latex_escape(_str) + "}"
return _latex_text_format(_latex_escape(_str.strip("$")))
else:
return _str

Expand Down Expand Up @@ -215,9 +217,20 @@ def _expand(x):
names = [x.name for x in parents]

if "latex" in formatting:
return r"f(" + ",~".join([r"\text{" + _latex_escape(n) + "}" for n in names]) + ")"
return (
r"f("
+ ",~".join([_latex_text_format(_latex_escape(n.strip("$"))) for n in names])
+ ")"
)
else:
return r"f(" + ", ".join([n.strip("$") for n in names]) + ")"


def _latex_text_format(text: str) -> str:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay for now, but I have a sense that this can be simplified since you're always using it on an output of _latex_escape

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add the call to _latex_escape(text.strip("$") inside _latex_text_format and call this also in the other 2 occurrences of _latex_escape. Then, we can also remove the naked \text{ in there.

if r"\operatorname{" in text:
return text
else:
return r"f(" + ", ".join(names) + ")"
return r"\text{" + text + "}"


def _latex_escape(text: str) -> str:
Expand Down
48 changes: 44 additions & 4 deletions pymc/tests/test_printing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from pymc import Bernoulli, Censored, Mixture
from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT
from pymc.aesaraf import floatX
from pymc.distributions import (
Dirichlet,
Expand Down Expand Up @@ -130,12 +130,12 @@ def setup_class(self):
r"$\text{beta} \sim \operatorname{N}(0,~10)$",
r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\text{\$\operatorname{DiracDelta}(0)\$},~\text{\$\operatorname{Pois}(5)\$})$",
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5))$",
r"$\text{w} \sim \operatorname{Dir}(\text{<constant>})$",
(
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w},"
r"~\text{\$\operatorname{MarginalMixture}(f(),~\text{\\$\operatorname{DiracDelta}(0)\\$},~\text{\\$\operatorname{Pois}(5)\\$})\$},"
r"~\text{\$\operatorname{Censored}(\text{\\$\operatorname{Bern}(0.5)\\$},~-1,~1)\$})$"
r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5)),"
r"~\operatorname{Censored}(\operatorname{Bern}(0.5),~-1,~1))$"
),
r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
Expand Down Expand Up @@ -178,3 +178,43 @@ def test_str_repr(self):
assert segment in model_text
else:
assert text in model_text


def test_model_latex_repr_three_levels_model():
with Model() as censored_model:
mu = Normal("mu", 0.0, 5.0)
sigma = HalfCauchy("sigma", 2.5)
normal_dist = Normal.dist(mu=mu, sigma=sigma)
censored_normal = Censored(
"censored_normal", normal_dist, lower=-2.0, upper=2.0, observed=[1, 0, 0.5]
)

latex_repr = censored_model.str_repr(formatting="latex")
expected = [
"$$",
"\\begin{array}{rcl}",
"\\text{mu} &\\sim & \\operatorname{N}(0,~5)\\\\\\text{sigma} &\\sim & "
"\\operatorname{C^{+}}(0,~2.5)\\\\\\text{censored_normal} &\\sim & "
"\\operatorname{Censored}(\\operatorname{N}(\\text{mu},~\\text{sigma}),~-2,~2)",
"\\end{array}",
"$$",
]
assert [line.strip() for line in latex_repr.split("\n")] == expected


def test_model_latex_repr_mixture_model():
with Model() as mix_model:
w = Dirichlet("w", [1, 1])
mix = Mixture("mix", w=w, comp_dists=[Normal.dist(0.0, 5.0), StudentT.dist(7.0)])

latex_repr = mix_model.str_repr(formatting="latex")
expected = [
"$$",
"\\begin{array}{rcl}",
"\\text{w} &\\sim & "
"\\operatorname{Dir}(\\text{<constant>})\\\\\\text{mix} &\\sim & "
"\\operatorname{MarginalMixture}(\\text{w},~\\operatorname{N}(0,~5),~\\operatorname{StudentT}(7,~0,~1))",
"\\end{array}",
"$$",
]
assert [line.strip() for line in latex_repr.split("\n")] == expected