From 512cb01c57fad1aeea2212f9d88d7593a42a1578 Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Sun, 28 Apr 2024 11:06:41 -0300 Subject: [PATCH 1/3] Add test to compare draws when var_names is used in pm.sample --- tests/sampling/test_mcmc.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 31b48250f..0ecdc8210 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -703,6 +703,42 @@ def test_sample_var_names(): assert "b" not in idata.posterior +def test_sample_var_names_draws(): + # Generate data + seed = 1234 + rng = np.random.default_rng(seed) + + group = rng.choice(list("ABCD"), size=100) + x = rng.normal(size=100) + y = rng.normal(size=100) + + group_values, group_idx = np.unique(group, return_inverse=True) + + coords = {"group": group_values} + + # Create model + with pm.Model(coords=coords) as model: + b_group = pm.Normal("b_group", dims="group") + b_x = pm.Normal("b_x") + mu = pm.Deterministic("mu", b_group[group_idx] + b_x * x) + sigma = pm.HalfNormal("sigma") + pm.Normal("y", mu=mu, sigma=sigma, observed=y) + + # Sample with and without var_names, but always with the same seed + with model: + idata_1 = pm.sample(tune=100, draws=100, random_seed=seed) + idata_2 = pm.sample( + tune=100, draws=100, var_names=["b_group", "b_x", "sigma"], random_seed=seed + ) + + assert "mu" in idata_1.posterior + assert "mu" not in idata_2.posterior + + assert np.all(idata_1.posterior["b_group"] == idata_2.posterior["b_group"]).item() + assert np.all(idata_1.posterior["b_x"] == idata_2.posterior["b_x"]).item() + assert np.all(idata_1.posterior["sigma"] == idata_2.posterior["sigma"]).item() + + class TestAssignStepMethods: def test_bernoulli(self): """Test bernoulli distribution is assigned binary gibbs metropolis method""" From 294aa6480454e8b6691586e1c24340f935a1435e Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Sun, 28 Apr 2024 11:13:53 -0300 Subject: [PATCH 2/3] Delete whitespace in test --- tests/sampling/test_mcmc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 0ecdc8210..1127cb0df 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -712,7 +712,7 @@ def test_sample_var_names_draws(): x = rng.normal(size=100) y = rng.normal(size=100) - group_values, group_idx = np.unique(group, return_inverse=True) + group_values, group_idx = np.unique(group, return_inverse=True) coords = {"group": group_values} @@ -721,7 +721,7 @@ def test_sample_var_names_draws(): b_group = pm.Normal("b_group", dims="group") b_x = pm.Normal("b_x") mu = pm.Deterministic("mu", b_group[group_idx] + b_x * x) - sigma = pm.HalfNormal("sigma") + sigma = pm.HalfNormal("sigma") pm.Normal("y", mu=mu, sigma=sigma, observed=y) # Sample with and without var_names, but always with the same seed From 538c3a598abeb1552d8e97f0ec9afd9af4efcaec Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Sun, 28 Apr 2024 11:38:11 -0300 Subject: [PATCH 3/3] remove redundant test --- tests/sampling/test_mcmc.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 1127cb0df..647e8ada7 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -695,15 +695,6 @@ def test_no_init_nuts_compound(caplog): def test_sample_var_names(): - with pm.Model() as model: - a = pm.Normal("a") - b = pm.Deterministic("b", a**2) - idata = pm.sample(10, tune=10, var_names=["a"]) - assert "a" in idata.posterior - assert "b" not in idata.posterior - - -def test_sample_var_names_draws(): # Generate data seed = 1234 rng = np.random.default_rng(seed)