Skip to content
4 changes: 2 additions & 2 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,8 @@ def apply_function_over_dataset(
out_dict = _DefaultTrace(n_pts)
indices = range(n_pts)

with Progress(console=Console(theme=progressbar_theme)) as progress:
task = progress.add_task("Computing ...", total=n_pts, visible=progressbar)
with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress:
task = progress.add_task("Computinng ...", total=n_pts, visible=progressbar)
for idx in indices:
out = fn(posterior_pts[idx])
fn.f.trust_input = True # If we arrive here the dtypes are valid
Expand Down
4 changes: 3 additions & 1 deletion pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,9 @@ def sample_posterior_predictive(
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)
try:
with Progress(console=Console(theme=progressbar_theme)) as progress:
with Progress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
task = progress.add_task("Sampling ...", total=samples, visible=progressbar)
for idx in np.arange(samples):
if nchain > 1:
Expand Down
4 changes: 2 additions & 2 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,8 +1041,8 @@ def _sample(
for it, diverging in enumerate(sampling_gen):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
progress.update(task, advance=1)
progress.update(task, advance=1, completed=True)
progress.update(task, refresh=True, advance=1)
progress.update(task, refresh=True, advance=1, completed=True)
except KeyboardInterrupt:
pass

Expand Down
6 changes: 5 additions & 1 deletion pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np

from rich.console import Console
from rich.progress import BarColumn, Progress, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme

from pymc.blocking import DictToArrayBijection
Expand Down Expand Up @@ -428,7 +428,10 @@ def __init__(
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)
self._show_progress = progressbar
self._divergences = 0
Expand Down Expand Up @@ -465,6 +468,7 @@ def __iter__(self):
self._divergences += 1
progress.update(
task,
refresh=True,
completed=self._completed_draws,
total=self._total_draws,
description=self._desc.format(self),
Expand Down
6 changes: 4 additions & 2 deletions pymc/sampling/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import cloudpickle
import numpy as np

from rich.progress import BarColumn, Progress, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn

from pymc.backends.base import BaseTrace
from pymc.initial_point import PointType
Expand Down Expand Up @@ -104,7 +104,7 @@ def _sample_population(
task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar)

for _ in sampling:
progress.update(task, advance=1)
progress.update(task, advance=1, refresh=True)

return

Expand Down Expand Up @@ -180,6 +180,8 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True):
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
) as self._progress:
for c, stepper in enumerate(steppers):
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
Expand Down
14 changes: 12 additions & 2 deletions pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
import numpy as np

from arviz import InferenceData
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)

import pymc

Expand Down Expand Up @@ -366,6 +372,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
with Progress(
TextColumn("{task.description}"),
SpinnerColumn(),
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
TextColumn("{task.fields[status]}"),
) as progress:
Expand Down Expand Up @@ -403,6 +411,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
stage = update_data["stage"]
beta = update_data["beta"]
# update the progress bar for this task:
progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id)
progress.update(
status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id, refresh=True
)

return tuple(cloudpickle.loads(r.result()) for r in futures)
3 changes: 2 additions & 1 deletion pymc/tuning/starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def find_MAP(
if isinstance(e, StopIteration):
pm._log.info(e)
finally:
cost_func.progress.update(cost_func.task, completed=cost_func.n_eval)
cost_func.progress.update(cost_func.task, completed=cost_func.n_eval, refresh=True)
print(file=sys.stdout)

mx0 = RaveledVars(mx0, x0.point_map_info)
Expand Down Expand Up @@ -223,6 +223,7 @@ def __init__(
*Progress.get_default_columns(),
TextColumn("{task.fields[loss]}"),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)
self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="")

Expand Down
4 changes: 3 additions & 1 deletion pymc/variational/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def fit(
def _iterate_without_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks):
i = 0
try:
with Progress(console=Console(theme=progressbar_theme)) as progress:
with Progress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
task = progress.add_task("Fitting", total=n, visible=progressbar)
for i in range(n):
step_func()
Expand Down