From 7185852484f9f28b10c90c6f713b722557495719 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 8 May 2023 13:59:33 +0200 Subject: [PATCH 1/2] Return tune information from Slice sampler This avoids a crash when mixing slice sampling with another sampler that has `tune` stats --- pymc/step_methods/slicer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 8d844187f..55df8931b 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -51,6 +51,7 @@ class Slice(ArrayStep): name = "slice" default_blocked = False stats_dtypes_shapes = { + "tune": (bool, []), "nstep_out": (int, []), "nstep_in": (int, []), } @@ -140,6 +141,7 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: self.n_tunes += 1 stats = { + "tune": self.tune, "nstep_out": nstep_out, "nstep_in": nstep_in, } From eb4becfb1a033656f14b20b9baab0aede4a7c447 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 8 May 2023 13:39:58 +0200 Subject: [PATCH 2/2] Speedup Slice sampler --- pymc/step_methods/slicer.py | 45 ++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 55df8931b..83e241f1e 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -21,7 +21,8 @@ from pymc.blocking import RaveledVars, StatsType from pymc.model import modelcontext -from pymc.step_methods.arraystep import ArrayStep +from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements +from pymc.step_methods.arraystep import ArrayStepShared from pymc.step_methods.compound import Competence from pymc.util import get_value_vars_from_user_vars from pymc.vartypes import continuous_types @@ -31,7 +32,7 @@ LOOP_ERR_MSG = "max slicer iters %d exceeded" -class Slice(ArrayStep): +class Slice(ArrayStepShared): """ Univariate slice sampler step method. @@ -57,24 +58,33 @@ class Slice(ArrayStep): } def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, **kwargs): - self.model = modelcontext(model) - self.w = w + model = modelcontext(model) + self.w = np.asarray(w).copy() self.tune = tune self.n_tunes = 0.0 self.iter_limit = iter_limit if vars is None: - vars = self.model.continuous_value_vars + vars = model.continuous_value_vars else: - vars = get_value_vars_from_user_vars(vars, self.model) + vars = get_value_vars_from_user_vars(vars, model) - super().__init__(vars, [self.model.compile_logp()], **kwargs) + point = model.initial_point() + shared = make_shared_replacements(point, vars, model) + [logp], raveled_inp = join_nonshared_inputs( + point=point, outputs=[model.logp()], inputs=vars, shared_inputs=shared + ) + self.logp = compile_pymc([raveled_inp], logp) + self.logp.trust_input = True - def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: + super().__init__(vars, shared) + + def astep(self, apoint: RaveledVars) -> Tuple[RaveledVars, StatsType]: # The arguments are determined by the list passed via `super().__init__(..., fs, ...)` - logp = args[0] q0_val = apoint.data - self.w = np.resize(self.w, len(q0_val)) # this is a repmat + + if q0_val.shape != self.w.shape: + self.w = np.resize(self.w, len(q0_val)) # this is a repmat nstep_out = nstep_in = 0 @@ -82,15 +92,10 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: ql = np.copy(q0_val) # l for left boundary qr = np.copy(q0_val) # r for right boundary - # The points are not copied, so it's fine to update them inplace in the - # loop below - q_ra = RaveledVars(q, apoint.point_map_info) - ql_ra = RaveledVars(ql, apoint.point_map_info) - qr_ra = RaveledVars(qr, apoint.point_map_info) - + logp = self.logp for i, wi in enumerate(self.w): # uniformly sample from 0 to p(q), but in log space - y = logp(q_ra) - nr.standard_exponential() + y = logp(q) - nr.standard_exponential() # Create initial interval ql[i] = q[i] - nr.uniform() * wi # q[i] + r * w @@ -98,7 +103,7 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: # Stepping out procedure cnt = 0 - while y <= logp(ql_ra): # changed lt to leq for locally uniform posteriors + while y <= logp(ql): # changed lt to leq for locally uniform posteriors ql[i] -= wi cnt += 1 if cnt > self.iter_limit: @@ -106,7 +111,7 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: nstep_out += cnt cnt = 0 - while y <= logp(qr_ra): + while y <= logp(qr): qr[i] += wi cnt += 1 if cnt > self.iter_limit: @@ -115,7 +120,7 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: cnt = 0 q[i] = nr.uniform(ql[i], qr[i]) - while y > logp(q_ra): # Changed leq to lt, to accommodate for locally flat posteriors + while y > logp(q): # Changed leq to lt, to accommodate for locally flat posteriors # Sample uniformly from slice if q[i] > q0_val[i]: qr[i] = q[i]