Skip to content

Commit 7f1e031

Browse files
committed
Fix tests and add doc for new statistics
1 parent f713ea7 commit 7f1e031

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

pymc/step_methods/hmc/hmc.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ class HamiltonianMC(BaseHMC):
5151
"process_time_diff": np.float64,
5252
"perf_counter_diff": np.float64,
5353
"perf_counter_start": np.float64,
54+
"largest_eigval": np.float64,
55+
"smallest_eigval": np.float64,
5456
}
5557
]
5658

@@ -162,7 +164,16 @@ def _hamiltonian_step(self, start, p0, step_size):
162164
"model_logp": state.model_logp,
163165
}
164166
# Retrieve State q and p data from respective RaveledVars
165-
end = State(end.q.data, end.p.data, end.v, end.q_grad, end.energy, end.model_logp)
167+
end = State(
168+
end.q.data,
169+
end.p.data,
170+
end.v,
171+
end.q_grad,
172+
end.energy,
173+
end.model_logp,
174+
end.index_in_trajectory,
175+
)
176+
stats.update(self.potential.stats())
166177
return HMCStepData(end, accept_stat, div_info, stats)
167178

168179
@staticmethod

pymc/step_methods/hmc/nuts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ class NUTS(BaseHMC):
7878
by the python standard library `time.perf_counter` (wall time).
7979
- `perf_counter_start`: The value of `time.perf_counter` at the beginning
8080
of the computation of the draw.
81+
- `index_in_trajectory`: This is usually only interesting for debugging
82+
purposes. This indicates the position of the posterior draw in the
83+
trajectory. Eg a -4 would indicate that the draw was the result of the
84+
fourth leapfrog step in negative direction.
85+
- `largest_eigval` and `smallest_eigval`: Experimental statistics for
86+
some mass matrix adaptation algorithms. This is nan if it is not used.
8187
8288
References
8389
----------

pymc/tests/test_step.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,9 @@ def test_sampler_stats(self):
595595
"perf_counter_diff",
596596
"perf_counter_start",
597597
"process_time_diff",
598+
"index_in_trajectory",
599+
"largest_eigval",
600+
"smallest_eigval",
598601
}
599602
assert trace.stat_names == expected_stat_names
600603
for varname in trace.stat_names:

0 commit comments

Comments
 (0)