Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 27 additions & 20 deletions pymc/step_methods/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

from pymc.blocking import RaveledVars, StatsType
from pymc.model import modelcontext
from pymc.step_methods.arraystep import ArrayStep
from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements
from pymc.step_methods.arraystep import ArrayStepShared
from pymc.step_methods.compound import Competence
from pymc.util import get_value_vars_from_user_vars
from pymc.vartypes import continuous_types
Expand All @@ -31,7 +32,7 @@
LOOP_ERR_MSG = "max slicer iters %d exceeded"


class Slice(ArrayStep):
class Slice(ArrayStepShared):
"""
Univariate slice sampler step method.

Expand All @@ -51,61 +52,66 @@ class Slice(ArrayStep):
name = "slice"
default_blocked = False
stats_dtypes_shapes = {
"tune": (bool, []),
"nstep_out": (int, []),
"nstep_in": (int, []),
}

def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, **kwargs):
self.model = modelcontext(model)
self.w = w
model = modelcontext(model)
self.w = np.asarray(w).copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does asarray not make a copy, or is that array?

Copy link
Member Author

@ricardoV94 ricardoV94 May 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

asarray returns the same object if it's already a numpy array

self.tune = tune
self.n_tunes = 0.0
self.iter_limit = iter_limit

if vars is None:
vars = self.model.continuous_value_vars
vars = model.continuous_value_vars
else:
vars = get_value_vars_from_user_vars(vars, self.model)
vars = get_value_vars_from_user_vars(vars, model)

super().__init__(vars, [self.model.compile_logp()], **kwargs)
point = model.initial_point()
shared = make_shared_replacements(point, vars, model)
[logp], raveled_inp = join_nonshared_inputs(
point=point, outputs=[model.logp()], inputs=vars, shared_inputs=shared
)
self.logp = compile_pymc([raveled_inp], logp)
self.logp.trust_input = True

def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
super().__init__(vars, shared)

def astep(self, apoint: RaveledVars) -> Tuple[RaveledVars, StatsType]:
# The arguments are determined by the list passed via `super().__init__(..., fs, ...)`
logp = args[0]
q0_val = apoint.data
self.w = np.resize(self.w, len(q0_val)) # this is a repmat

if q0_val.shape != self.w.shape:
self.w = np.resize(self.w, len(q0_val)) # this is a repmat

nstep_out = nstep_in = 0

q = np.copy(q0_val)
ql = np.copy(q0_val) # l for left boundary
qr = np.copy(q0_val) # r for right boundary

# The points are not copied, so it's fine to update them inplace in the
# loop below
q_ra = RaveledVars(q, apoint.point_map_info)
ql_ra = RaveledVars(ql, apoint.point_map_info)
qr_ra = RaveledVars(qr, apoint.point_map_info)

logp = self.logp
for i, wi in enumerate(self.w):
# uniformly sample from 0 to p(q), but in log space
y = logp(q_ra) - nr.standard_exponential()
y = logp(q) - nr.standard_exponential()

# Create initial interval
ql[i] = q[i] - nr.uniform() * wi # q[i] + r * w
qr[i] = ql[i] + wi # Equivalent to q[i] + (1-r) * w

# Stepping out procedure
cnt = 0
while y <= logp(ql_ra): # changed lt to leq for locally uniform posteriors
while y <= logp(ql): # changed lt to leq for locally uniform posteriors
ql[i] -= wi
cnt += 1
if cnt > self.iter_limit:
raise RuntimeError(LOOP_ERR_MSG % self.iter_limit)
nstep_out += cnt

cnt = 0
while y <= logp(qr_ra):
while y <= logp(qr):
qr[i] += wi
cnt += 1
if cnt > self.iter_limit:
Expand All @@ -114,7 +120,7 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:

cnt = 0
q[i] = nr.uniform(ql[i], qr[i])
while y > logp(q_ra): # Changed leq to lt, to accommodate for locally flat posteriors
while y > logp(q): # Changed leq to lt, to accommodate for locally flat posteriors
# Sample uniformly from slice
if q[i] > q0_val[i]:
qr[i] = q[i]
Expand All @@ -140,6 +146,7 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
self.n_tunes += 1

stats = {
"tune": self.tune,
"nstep_out": nstep_out,
"nstep_in": nstep_in,
}
Expand Down