-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Fix Flaky Euler-Maruyama Tests #6287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
0ba93b3
436d0f9
eadfd75
f2a70fd
627e49d
504f836
d0474ed
a0d68d1
ca7b5bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -835,8 +835,11 @@ class TestEulerMaruyama: | |
| @pytest.mark.parametrize("batched_param", [1, 2]) | ||
| @pytest.mark.parametrize("explicit_shape", (True, False)) | ||
| def test_batched_size(self, explicit_shape, batched_param): | ||
| RANDOM_SEED = 42 | ||
| numpy_rng = np.random.default_rng(RANDOM_SEED) | ||
|
|
||
| steps, batch_size = 100, 5 | ||
| param_val = np.square(np.random.randn(batch_size)) | ||
| param_val = np.square(numpy_rng.standard_normal(batch_size)) | ||
| if explicit_shape: | ||
| kwargs = {"shape": (batch_size, steps)} | ||
| else: | ||
|
|
@@ -853,9 +856,9 @@ def sde_fn(x, k, d, s): | |
| "y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, init_dist=init_dist, **kwargs | ||
| ) | ||
|
|
||
| y_eval = draw(y, draws=2) | ||
| y_eval = draw(y, draws=2, random_seed=RANDOM_SEED) | ||
| assert y_eval[0].shape == (batch_size, steps) | ||
| assert not np.any(np.isclose(y_eval[0], y_eval[1])) | ||
| assert np.any(~np.isclose(y_eval[0], y_eval[1])) | ||
|
||
|
|
||
| if explicit_shape: | ||
| kwargs["shape"] = steps | ||
|
|
@@ -873,7 +876,7 @@ def sde_fn(x, k, d, s): | |
| **kwargs, | ||
| ) | ||
|
|
||
| t0_init = t0.initial_point() | ||
| t0_init = t0.initial_point(random_seed=RANDOM_SEED) | ||
| t1_init = {f"y_{i}": t0_init["y"][i] for i in range(batch_size)} | ||
| np.testing.assert_allclose( | ||
| t0.compile_logp()(t0_init), | ||
|
|
@@ -919,17 +922,20 @@ def test_linear_model(self): | |
| N = 300 | ||
| dt = 1e-1 | ||
|
|
||
| RANDOM_SEED = 42 | ||
| numpy_rng = np.random.default_rng(RANDOM_SEED) | ||
|
|
||
| def _gen_sde_path(sde, pars, dt, n, x0): | ||
| xs = [x0] | ||
| wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size)) | ||
| wt = numpy_rng.normal(size=(n,) if isinstance(x0, float) else (n, x0.size)) | ||
| for i in range(n): | ||
| f, g = sde(xs[-1], *pars) | ||
| xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i]) | ||
| return np.array(xs) | ||
|
|
||
| sde = lambda x, lam: (lam * x, sig2) | ||
| x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0)) | ||
| z = x + np.random.randn(x.size) * sig2 | ||
| z = x + numpy_rng.standard_normal(size=x.size) * sig2 | ||
| # build model | ||
| with Model() as model: | ||
| lamh = Flat("lamh") | ||
|
|
@@ -939,9 +945,9 @@ def _gen_sde_path(sde, pars, dt, n, x0): | |
| Normal("zh", mu=xh, sigma=sig2, observed=z) | ||
| # invert | ||
| with model: | ||
| trace = sample(chains=1) | ||
| trace = sample(chains=1, random_seed=RANDOM_SEED) | ||
|
|
||
| ppc = sample_posterior_predictive(trace, model=model) | ||
| ppc = sample_posterior_predictive(trace, model=model, random_seed=RANDOM_SEED) | ||
|
|
||
| p95 = [2.5, 97.5] | ||
| lo, hi = np.percentile(trace.posterior["lamh"], p95, axis=[0, 1]) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.