Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymc/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def astep(self, q0):

stats.update(hmc_step.stats)
stats.update(self.step_adapt.stats())
stats.update(self.potential.stats())

return hmc_step.end.q, [stats]

Expand Down
13 changes: 12 additions & 1 deletion pymc/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class HamiltonianMC(BaseHMC):
"process_time_diff": np.float64,
"perf_counter_diff": np.float64,
"perf_counter_start": np.float64,
"largest_eigval": np.float64,
"smallest_eigval": np.float64,
}
]

Expand Down Expand Up @@ -162,7 +164,16 @@ def _hamiltonian_step(self, start, p0, step_size):
"model_logp": state.model_logp,
}
# Retrieve State q and p data from respective RaveledVars
end = State(end.q.data, end.p.data, end.v, end.q_grad, end.energy, end.model_logp)
end = State(
end.q.data,
end.p.data,
end.v,
end.q_grad,
end.energy,
end.model_logp,
end.index_in_trajectory,
)
stats.update(self.potential.stats())
return HMCStepData(end, accept_stat, div_info, stats)

@staticmethod
Expand Down
14 changes: 11 additions & 3 deletions pymc/step_methods/hmc/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from pymc.blocking import RaveledVars

State = namedtuple("State", "q, p, v, q_grad, energy, model_logp")
State = namedtuple("State", "q, p, v, q_grad, energy, model_logp, index_in_trajectory")


class IntegrationError(RuntimeError):
Expand Down Expand Up @@ -49,7 +49,7 @@ def compute_state(self, q, p):
v = self._potential.velocity(p.data)
kinetic = self._potential.energy(p.data, velocity=v)
energy = kinetic - logp
return State(q, p, v, dlogp, energy, logp)
return State(q, p, v, dlogp, energy, logp, 0)

def step(self, epsilon, state):
"""Leapfrog integrator step.
Expand Down Expand Up @@ -114,4 +114,12 @@ def _step(self, epsilon, state):
kinetic = pot.velocity_energy(p_new.data, v_new)
energy = kinetic - logp

return State(q_new, p_new, v_new, q_new_grad, energy, logp)
return State(
q_new,
p_new,
v_new,
q_new_grad,
energy,
logp,
state.index_in_trajectory + int(np.sign(epsilon)),
)
59 changes: 27 additions & 32 deletions pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from pymc.aesaraf import floatX
from pymc.backends.report import SamplerWarning, WarningType
from pymc.math import logbern, logdiffexp_numpy
from pymc.math import logbern
from pymc.step_methods.arraystep import Competence
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError
Expand Down Expand Up @@ -78,6 +78,12 @@ class NUTS(BaseHMC):
by the python standard library `time.perf_counter` (wall time).
- `perf_counter_start`: The value of `time.perf_counter` at the beginning
of the computation of the draw.
- `index_in_trajectory`: This is usually only interesting for debugging
  purposes. It indicates the position of the posterior draw in the
  trajectory. E.g. a value of -4 would indicate that the draw was the
  result of the fourth leapfrog step in the negative direction.
- `largest_eigval` and `smallest_eigval`: Experimental statistics for
some mass matrix adaptation algorithms. This is nan if it is not used.

References
----------
Expand Down Expand Up @@ -105,6 +111,9 @@ class NUTS(BaseHMC):
"process_time_diff": np.float64,
"perf_counter_diff": np.float64,
"perf_counter_start": np.float64,
"largest_eigval": np.float64,
"smallest_eigval": np.float64,
"index_in_trajectory": np.int64,
}
]

Expand Down Expand Up @@ -219,12 +228,12 @@ def warnings(self):


# A proposal for the next position
Proposal = namedtuple("Proposal", "q, q_grad, energy, log_p_accept_weighted, logp")
Proposal = namedtuple("Proposal", "q, q_grad, energy, logp, index_in_trajectory")

# A subtree of the binary tree built by nuts.
Subtree = namedtuple(
"Subtree",
"left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals",
"left, right, p_sum, proposal, log_size",
)


Expand Down Expand Up @@ -252,10 +261,10 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
self.start_energy = np.array(start.energy)

self.left = self.right = start
self.proposal = Proposal(start.q.data, start.q_grad, start.energy, 1.0, start.model_logp)
self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0)
self.depth = 0
self.log_size = 0
self.log_weighted_accept_sum = -np.inf
self.log_accept_sum = -np.inf
self.mean_tree_accept = 0.0
self.n_proposals = 0
self.p_sum = start.p.data.copy()
Expand All @@ -279,7 +288,7 @@ def extend(self, direction):
)
leftmost_begin, leftmost_end = self.left, self.right
rightmost_begin, rightmost_end = tree.left, tree.right
leftmost_p_sum = self.p_sum
leftmost_p_sum = self.p_sum.copy()
rightmost_p_sum = tree.p_sum
self.right = tree.right
else:
Expand All @@ -289,11 +298,10 @@ def extend(self, direction):
leftmost_begin, leftmost_end = tree.right, tree.left
rightmost_begin, rightmost_end = self.left, self.right
leftmost_p_sum = tree.p_sum
rightmost_p_sum = self.p_sum
rightmost_p_sum = self.p_sum.copy()
self.left = tree.right

self.depth += 1
self.n_proposals += tree.n_proposals

if diverging or turning:
return diverging, turning
Expand All @@ -303,9 +311,6 @@ def extend(self, direction):
self.proposal = tree.proposal

self.log_size = np.logaddexp(self.log_size, tree.log_size)
self.log_weighted_accept_sum = np.logaddexp(
self.log_weighted_accept_sum, tree.log_weighted_accept_sum
)
self.p_sum[:] += tree.p_sum

# Additional turning check only when tree depth > 0 to avoid redundant work
Expand Down Expand Up @@ -336,30 +341,30 @@ def _single_step(self, left, epsilon):
if np.isnan(energy_change):
energy_change = np.inf

self.log_accept_sum = np.logaddexp(self.log_accept_sum, min(0, -energy_change))

if np.abs(energy_change) > np.abs(self.max_energy_change):
self.max_energy_change = energy_change
if np.abs(energy_change) < self.Emax:
if energy_change < self.Emax:
# Acceptance statistic
# e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
# Saturated Metropolis accept probability with Boltzmann weight
# if h - H0 < 0
log_p_accept_weighted = -energy_change + min(0.0, -energy_change)
log_size = -energy_change
proposal = Proposal(
right.q.data,
right.q_grad,
right.energy,
log_p_accept_weighted,
right.model_logp,
right.index_in_trajectory,
)
tree = Subtree(
right, right, right.p.data, proposal, log_size, log_p_accept_weighted, 1
)
tree = Subtree(right, right, right.p.data, proposal, log_size)
return tree, None, False
else:
error_msg = f"Energy change in leapfrog step is too large: {energy_change}."
error = None
tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
finally:
self.n_proposals += 1
tree = Subtree(None, None, None, None, -np.inf)
divergance_info = DivergenceInfo(error_msg, error, left, right)
return tree, divergance_info, False

Expand Down Expand Up @@ -387,31 +392,20 @@ def _build_subtree(self, left, depth, epsilon):
turning = turning | turning1 | turning2

log_size = np.logaddexp(tree1.log_size, tree2.log_size)
log_weighted_accept_sum = np.logaddexp(
tree1.log_weighted_accept_sum, tree2.log_weighted_accept_sum
)
if logbern(tree2.log_size - log_size):
proposal = tree2.proposal
else:
proposal = tree1.proposal
else:
p_sum = tree1.p_sum
log_size = tree1.log_size
log_weighted_accept_sum = tree1.log_weighted_accept_sum
proposal = tree1.proposal

n_proposals = tree1.n_proposals + tree2.n_proposals

tree = Subtree(left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals)
tree = Subtree(left, right, p_sum, proposal, log_size)
return tree, diverging, turning

def stats(self):
# Update accept stat if any subtrees were accepted
if self.log_size > 0:
# Remove contribution from initial state which is always a perfect
# accept
log_sum_weight = logdiffexp_numpy(self.log_size, 0.0)
self.mean_tree_accept = np.exp(self.log_weighted_accept_sum - log_sum_weight)
self.mean_tree_accept = np.exp(self.log_accept_sum) / self.n_proposals
return {
"depth": self.depth,
"mean_tree_accept": self.mean_tree_accept,
Expand All @@ -420,4 +414,5 @@ def stats(self):
"tree_size": self.n_proposals,
"max_energy_error": self.max_energy_change,
"model_logp": self.proposal.logp,
"index_in_trajectory": self.proposal.index_in_trajectory,
}
4 changes: 4 additions & 0 deletions pymc/step_methods/hmc/quadpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def raise_ok(self, map_info=None):
def reset(self):
    """Reset internal state; a no-op for the base potential (subclasses that keep adaptation state may override)."""
    pass

def stats(self):
    """Return sampler statistics for this potential.

    The base potential does not track mass-matrix eigenvalues, so both
    entries are NaN, which downstream code treats as "not used".
    Experimental adaptation algorithms may override this with real values.
    """
    return dict(largest_eigval=np.nan, smallest_eigval=np.nan)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this for the future?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That got in here by accident...
I can take it out, it is useful though for the new code in covadapt (which I think is pretty nice actually, and I'll try to get into pymc at some point).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to leave it in, I'm ok with it (so long as it has a comment!) I wonder what you think of algorithm 2 in https://proceedings.mlr.press/v151/hoffman22a/hoffman22a.pdf as a means of estimating scales?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thomas sent me a link to that paper just this morning. I also wonder what I'll think about algorithm 2, looks fascinating though. ;-)
If you like to compare ideas with what I did in covadapt, I just wrote a sketch of an intro in the readme: https://github.com/aseyboldt/covadapt
If you don't understand what I'm talking about over there, that's my fault. I'll try to improve it soon. :-)



def isquadpotential(value):
"""Check whether an object might be a QuadPotential object."""
Expand Down Expand Up @@ -254,6 +257,7 @@ def random(self):

def _update_from_weightvar(self, weightvar):
weightvar.current_variance(out=self._var)
self._var = np.clip(self._var, 1e-12, 1e12)
np.sqrt(self._var, out=self._stds)
np.divide(1, self._stds, out=self._inv_stds)
self._var_aesara.set_value(self._var)
Expand Down
3 changes: 3 additions & 0 deletions pymc/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,9 @@ def test_sampler_stats(self):
"perf_counter_diff",
"perf_counter_start",
"process_time_diff",
"index_in_trajectory",
"largest_eigval",
"smallest_eigval",
}
assert trace.stat_names == expected_stat_names
for varname in trace.stat_names:
Expand Down