@michaelosthege michaelosthege commented Jan 14, 2023

The goal was to decouple the sampling functions from MultiTrace and SamplerReport.

Some calls to SamplerReport._log_summary() were unnecessary: MultiTrace._add_warnings() was never called between instantiation and _log_summary(), so the traces never contained any warnings.

Running convergence checks and logging the warnings can also be done without needing MultiTrace or SamplerReport instances/methods.

Checklist

Minor changes

  • The "The number of samples is too small to check convergence reliably." warning is now an INFO level log message instead of a Warning.
  • The SamplerReport._log_summary() and SamplerReport._run_convergence_checks() methods were removed.

Maintenance

  • More type hints in SMC code
  • SMC and MCMC sampling functions no longer rely on instantiating a MultiTrace or SamplerReport to compute/log warnings.

@michaelosthege michaelosthege added the trace-backend Traces and ArviZ stuff label Jan 14, 2023
@michaelosthege michaelosthege self-assigned this Jan 14, 2023
codecov bot commented Jan 14, 2023

Codecov Report

Merging #6453 (ab128bc) into main (6ab0c03) will increase coverage by 0.01%.
The diff coverage is 100.00%.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #6453      +/-   ##
==========================================
+ Coverage   94.78%   94.79%   +0.01%     
==========================================
  Files         148      148              
  Lines       27678    27678              
==========================================
+ Hits        26234    26238       +4     
+ Misses       1444     1440       -4     
| Impacted Files | Coverage Δ | |
|---|---|---|
| pymc/backends/base.py | 84.82% <100.00%> (+0.06%) | ⬆️ |
| pymc/backends/report.py | 78.84% <100.00%> (-1.16%) | ⬇️ |
| pymc/sampling/mcmc.py | 93.06% <100.00%> (-0.07%) | ⬇️ |
| pymc/smc/kernels.py | 97.44% <100.00%> (+0.02%) | ⬆️ |
| pymc/smc/sampling.py | 86.61% <100.00%> (+0.19%) | ⬆️ |
| pymc/stats/convergence.py | 95.61% <100.00%> (+2.88%) | ⬆️ |
| pymc/step_methods/compound.py | 97.45% <100.00%> (ø) | |
| pymc/tests/smc/test_smc.py | 100.00% <100.00%> (ø) | |

@michaelosthege michaelosthege force-pushed the decouple-convergence-checking branch from c6dbdbb to f419ed3 Compare January 14, 2023 18:14
@michaelosthege michaelosthege marked this pull request as ready for review January 14, 2023 18:14
@michaelosthege michaelosthege added the SMC Sequential Monte Carlo label Jan 14, 2023
@michaelosthege michaelosthege force-pushed the decouple-convergence-checking branch from f419ed3 to 6c2f7f2 Compare January 14, 2023 18:42
The goal was to uncouple sampling functions
from `MultiTrace` and `SamplerReport`.

Some calls to `SamplerReport._log_summary()` were unnecessary because
`MultiTrace._add_warnings()` was never called between instantiation
and `_log_summary()`, so the traces never contained warnings.

Running convergence checks and logging the warnings can also be done
without needing `MultiTrace` or `SamplerReport` instances/methods.
@michaelosthege michaelosthege force-pushed the decouple-convergence-checking branch from 6c2f7f2 to 49f5263 Compare January 14, 2023 18:46
* Specify covariant input types in `StatsBijection`.
* Annotate `_choose_chains` to be independent of `BaseTrace` type.
@OriolAbril
Member

I don't think I am qualified to review this

@michaelosthege
Member Author

> I don't think I am qualified to review this

I should have added comments to the diff earlier.

GitHub suggested you because you edited the SMC code? Who else is familiar with it?

Comment on lines +524 to +527
S = TypeVar("S", bound=Sized)


def _choose_chains(traces: Sequence[S], tune: int) -> Tuple[List[S], int]:
Member Author

This annotates it as returning a list of the same type of items as given in the input, but with the constraint that these items must be Sized.
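
To illustrate what this kind of annotation buys, here is a minimal sketch. `keep_longest` is a hypothetical helper, not the actual `_choose_chains` implementation; only the `TypeVar("S", bound=Sized)` pattern and the shape of the signature are taken from the diff.

```python
from typing import List, Sequence, Sized, Tuple, TypeVar

# Same pattern as in the diff: S can be any type that supports len().
S = TypeVar("S", bound=Sized)


def keep_longest(traces: Sequence[S], min_len: int) -> Tuple[List[S], int]:
    """Hypothetical helper: keep items with at least `min_len` entries.

    Returns the kept items and the length of the shortest kept item.
    """
    kept = [t for t in traces if len(t) >= min_len]
    shortest = min((len(t) for t in kept), default=0)
    return kept, shortest


# A type checker infers List[List[int]] here ...
lists, n = keep_longest([[1, 2, 3], [4], [5, 6]], min_len=2)
# ... and List[str] here, because the item type is carried through S.
strings, m = keep_longest(["abc", "d"], min_len=2)
```

Because the function body only needs `len()`, it does not have to know about `BaseTrace` at all, and callers still get back a list of whatever type they passed in.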

f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
f"took {t_sampling:.0f} seconds."
)
mtrace.report._log_summary()
Member Author

Between line 574 (mtrace = MultiTrace(traces)[:length]), where the MultiTrace was created, and this call, no warnings were added to mtrace.
Therefore there are no warnings to log, and the _log_summary() call can safely be removed.

Comment on lines -616 to -619
warnings.warn(
"The number of samples is too small to check convergence reliably.",
stacklevel=2,
)
Member Author

This is now checked by run_convergence_checks, just as it already checked for a minimum number of chains.

Comment on lines -929 to -930
multitrace = MultiTrace(traces)
multitrace._report._log_summary()
Member Author

Here too: the multitrace cannot have any warnings for _log_summary() to print, because none were added here or in its __init__.

Comment on lines +241 to +245
if idata is None:
idata = to_inference_data(trace, log_likelihood=False)
warns = run_convergence_checks(idata, model)
trace.report._add_warnings(warns)
log_warnings(warns)
Member Author

This replaces the _compute_convergence_checks function and makes trace.report a dead end that can easily be removed in the future.

Recall from the other changes:

  • "number of samples is too small" warning now done by run_convergence_checks
  • report._add_warnings was previously called inside report._run_convergence_checks
  • trace.report._log_summary() internally called log_warnings()
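
The resulting flow (run the checks, store the warnings, log them) can be sketched without any trace or report object. Everything below is a hypothetical stand-in: `CheckWarning`, `run_checks`, `log_all`, and the `min_draws=100` default are assumptions made for illustration, not the PyMC functions or thresholds.

```python
import logging
from dataclasses import dataclass
from typing import List, Sequence

_log = logging.getLogger("convergence_sketch")


@dataclass
class CheckWarning:
    """Hypothetical stand-in for a sampler warning record."""

    message: str
    level: str = "warning"


def run_checks(n_chains: int, n_draws: int, min_draws: int = 100) -> List[CheckWarning]:
    """Return warnings as plain data; no trace or report instance needed."""
    if n_draws < min_draws:
        # Emitted as an INFO log message instead of warnings.warn().
        _log.info("The number of samples is too small to check convergence reliably.")
        return []
    warns: List[CheckWarning] = []
    if n_chains < 2:
        warns.append(CheckWarning("Only one chain was sampled.", level="info"))
    return warns


def log_all(warns: Sequence[CheckWarning]) -> None:
    """Log each warning at the level it carries."""
    for w in warns:
        _log.log(getattr(logging, w.level.upper()), w.message)


# The three steps from the diff, with a plain list in the role of the report.
collected: List[CheckWarning] = []
warns = run_checks(n_chains=1, n_draws=500)
collected.extend(warns)
log_all(warns)
```

Since the checks return plain data, the caller decides whether to store the warnings, log them, or both.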

"""Map between a `list` of stats to `dict` of stats."""

- def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None:
+ def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None:
Member Author

Typing rule of thumb: Generic input types, exact output types.
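
A small sketch of that rule of thumb. `merge_dtypes` is a hypothetical helper invented for this example, not part of `StatsBijection`.

```python
from types import MappingProxyType
from typing import Dict, Mapping, Sequence


def merge_dtypes(stats_dtypes: Sequence[Mapping[str, type]]) -> Dict[str, type]:
    """Hypothetical helper: flatten per-sampler dtype maps into one dict.

    Input: any sequence of read-only mappings is accepted.
    Output: a concrete dict, so callers know exactly what they get.
    """
    merged: Dict[str, type] = {}
    for dtypes in stats_dtypes:
        merged.update(dtypes)
    return merged


# Accepts a plain dict and a read-only mapping proxy alike.
merged = merge_dtypes([{"accept": float}, MappingProxyType({"tune": bool})])
```

With `Sequence[Mapping[str, type]]` as the parameter type, a caller holding a `List[Dict[str, type]]` still type-checks, and so does one holding a tuple of other mapping types.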

@OriolAbril
Member

> GitHub suggested you because you edited the SMC code? Who else is familiar with it?

I have only modified some docstrings 😅. @aloctavodia is the best choice I think

Member

@aloctavodia aloctavodia left a comment

LGTM!

@aloctavodia aloctavodia merged commit 5802f12 into pymc-devs:main Jan 20, 2023
@michaelosthege michaelosthege deleted the decouple-convergence-checking branch January 20, 2023 14:01