diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py
index 9cbf456b7b..b227c13293 100644
--- a/pymc/logprob/tensor.py
+++ b/pymc/logprob/tensor.py
@@ -188,10 +188,12 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs):
     # If the stacked variables depend on each other, we have to replace them by the respective values
     logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_split_values)
 
-    base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim
+    # Make axis positive and adjust for multivariate logp fewer dimensions to the right
+    axis = pt.switch(axis >= 0, axis, value.ndim + axis)
+    axis = pt.minimum(axis, logps[0].ndim - 1)
     join_logprob = pt.concatenate(
         [pt.atleast_1d(logp) for logp in logps],
-        axis=axis - base_vars_ndim_supp,
+        axis=axis,
     )
 
     return join_logprob
diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py
index e61e0d1700..b37d42ee62 100644
--- a/tests/logprob/test_tensor.py
+++ b/tests/logprob/test_tensor.py
@@ -269,34 +269,23 @@ def test_measurable_join_univariate(size1, size2, axis, concatenate):
 
 
 @pytest.mark.parametrize(
-    "size1, supp_size1, size2, supp_size2, axis, concatenate",
+    "size1, supp_size1, size2, supp_size2, axis, concatenate, logp_axis",
     [
-        (None, 2, None, 2, 0, True),
-        (None, 2, None, 2, -1, True),
-        ((5,), 2, (3,), 2, 0, True),
-        ((5,), 2, (3,), 2, -2, True),
-        ((2,), 5, (2,), 3, 1, True),
-        pytest.param(
-            (2,),
-            5,
-            (2,),
-            5,
-            0,
-            False,
-            marks=pytest.mark.xfail(reason="cannot measure dimshuffled multivariate RVs"),
-        ),
-        pytest.param(
-            (2,),
-            5,
-            (2,),
-            5,
-            1,
-            False,
-            marks=pytest.mark.xfail(reason="cannot measure dimshuffled multivariate RVs"),
-        ),
+        (None, 2, None, 2, 0, True, 0),
+        (None, 2, None, 2, -1, True, 0),
+        ((5,), 2, (3,), 2, 0, True, 0),
+        ((5,), 2, (3,), 2, -2, True, 0),
+        ((2,), 5, (2,), 3, 1, True, 0),
+        ((5, 6), 10, (5, 1), 10, 1, True, 1),
+        ((5, 6), 10, (5, 1), 10, -2, True, 1),
+        ((2,), 5, (2,), 5, 0, False, 0),
+        ((2,), 5, (2,), 5, 1, False, 1),
+        ((5, 6), 10, (5, 6), 10, 2, False, 2),
     ],
 )
-def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis, concatenate):
+def test_measurable_join_multivariate(
+    size1, supp_size1, size2, supp_size2, axis, concatenate, logp_axis
+):
     base1_rv = pt.random.multivariate_normal(
         np.zeros(supp_size1), np.eye(supp_size1), size=size1, name="base1"
     )
@@ -310,19 +299,18 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
     base1_vv = base1_rv.clone()
     base2_vv = base2_rv.clone()
     y_vv = y_rv.clone()
+
+    y_logp = logp(y_rv, y_vv)
+    assert_no_rvs(y_logp)
+
     base_logps = [
         pt.atleast_1d(logp)
         for logp in conditional_logp({base1_rv: base1_vv, base2_rv: base2_vv}).values()
     ]
-
     if concatenate:
-        axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim)
-        base_logps = pt.concatenate(base_logps, axis=axis_norm - 1)
+        expected_logp = pt.concatenate(base_logps, axis=logp_axis)
     else:
-        axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim + 1)
-        base_logps = pt.stack(base_logps, axis=axis_norm - 1)
-    y_logp = y_logp = logp(y_rv, y_vv)
-    assert_no_rvs(y_logp)
+        expected_logp = pt.stack(base_logps, axis=logp_axis)
 
     base1_testval = base1_rv.eval()
     base2_testval = base2_rv.eval()
@@ -331,7 +319,7 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
     else:
         y_testval = np.stack((base1_testval, base2_testval), axis=axis)
     np.testing.assert_allclose(
-        base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
+        expected_logp.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
         y_logp.eval({y_vv: y_testval}),
     )
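
For reference, a minimal non-symbolic sketch of the axis adjustment introduced in `logprob_join` above. The function name `adjusted_join_axis` and the parameters `value_ndim` / `logp_ndim` are hypothetical stand-ins for the symbolic `axis`, `value.ndim`, and `logps[0].ndim`; the patch itself performs the same arithmetic with `pt.switch` and `pt.minimum` so it stays inside the PyTensor graph.

```python
def adjusted_join_axis(axis: int, value_ndim: int, logp_ndim: int) -> int:
    """Sketch of the join-axis adjustment, using plain ints instead of symbolic variables."""
    # Normalize a negative axis against the joined value's number of dimensions
    if axis < 0:
        axis = value_ndim + axis
    # Multivariate logps drop their support dimensions on the right, so clamp
    # the join axis to the last remaining logp dimension
    return min(axis, logp_ndim - 1)


# Mirrors the parametrization ((5, 6), 10, (5, 1), 10, -2, True, 1): the joined
# value has ndim 3, each MvNormal logp has ndim 2, so the logps join on axis 1.
assert adjusted_join_axis(-2, value_ndim=3, logp_ndim=2) == 1
```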