From 5bc1cf23c4f5cb0da0d2e06c3e59ff0fe28cfb1f Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 2 Feb 2023 23:01:35 -0300 Subject: [PATCH] make random sample match variable shape --- pymc_bart/bart.py | 2 +- tests/test_bart.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 357dbe8..dad53d6 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -50,7 +50,7 @@ def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None, else: return np.full(cls.Y.shape[0], cls.Y.mean()) else: - return _sample_posterior(cls.all_trees, cls.X, rng=rng).squeeze() + return _sample_posterior(cls.all_trees, cls.X, rng=rng).squeeze().T bart = BARTRV() diff --git a/tests/test_bart.py b/tests/test_bart.py index 571da06..d9d06d8 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -58,6 +58,19 @@ def test_shared_variable(): assert ppc2.posterior_predictive["y"].shape == (2, 100, 3) +def test_shape(): + X = np.random.normal(0, 1, size=(250, 3)) + Y = np.random.normal(0, 1, size=250) + + with pm.Model() as model: + w = pmb.BART("w", X, Y, m=2, shape=(2, 250)) + y = pm.Normal("y", w[0], pm.math.abs(w[1]), observed=Y) + idata = pm.sample(random_seed=3415) + + assert model.initial_point()["w"].shape == (2, 250) + assert idata.posterior.coords["w_dim_0"].data.size == 2 + assert idata.posterior.coords["w_dim_1"].data.size == 250 + class TestUtils: X_norm = np.random.normal(0, 1, size=(50, 2)) X_binom = np.random.binomial(1, 0.5, size=(50, 1))