From efaf16cbb3f39e97618402c44c73a3999b358953 Mon Sep 17 00:00:00 2001 From: Jeremy Myers Date: Sat, 16 Sep 2023 08:08:39 -0700 Subject: [PATCH 1/2] =?UTF-8?q?Default=20behavior=20of=20divmod()=20is=20d?= =?UTF-8?q?ifferent=20from=20MATLAB=20mod(),=20on=20which=E2=80=A6=20(#237?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Default behavior of divmod() is different from MATLAB mod(), on which the original logic was based. The prior logic threw a ZeroDivisionError if printitn == 0. Instead, this fix avoids this error by testing that printitn > 0. Tests added for various values of printitn. * Update tests/test_cp_als.py Co-authored-by: Nick <24689722+ntjohnson1@users.noreply.github.com> * Update tests/test_cp_als.py Co-authored-by: Nick <24689722+ntjohnson1@users.noreply.github.com> * Applying Nick's suggesetions to remove mark, change maxiters to save CPU cycles. Also removing output since nothing is actually checked. * Removes mark * Closes #235 --------- Co-authored-by: Jeremy Myers Co-authored-by: Nick <24689722+ntjohnson1@users.noreply.github.com> --- pyttb/cp_als.py | 2 +- tests/test_cp_als.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/pyttb/cp_als.py b/pyttb/cp_als.py index b1254235..cd8f2cc7 100644 --- a/pyttb/cp_als.py +++ b/pyttb/cp_als.py @@ -238,7 +238,7 @@ def cp_als( # noqa: PLR0912,PLR0913,PLR0915 else: flag = 1 - if (divmod(iteration, printitn)[1] == 0) or (printitn > 0 and flag == 0): + if (printitn > 0) and ((divmod(iteration, printitn)[1] == 0) or (flag == 0)): print(f" Iter {iteration}: f = {fit:e} f-delta = {fitchange:7.1e}") # Check for convergence diff --git a/tests/test_cp_als.py b/tests/test_cp_als.py index a591790c..f351922a 100644 --- a/tests/test_cp_als.py +++ b/tests/test_cp_als.py @@ -206,3 +206,23 @@ def test_cp_als_sptensor_zeros(capsys): capsys.readouterr() assert pytest.approx(output3["fit"], 1) == 0 assert output3["normresidual"] == 0 + + +def test_cp_als_tensor_printitn(capsys, sample_tensor): + _, T = sample_tensor + + # default printitn + ttb.cp_als(T, 2, printitn=1, maxiters=2) + capsys.readouterr() + + # zero printitn + ttb.cp_als(T, 2, printitn=0, maxiters=2) + capsys.readouterr() + + # negative printitn + ttb.cp_als(T, 2, printitn=-1, maxiters=2) + capsys.readouterr() + + # float printitn + ttb.cp_als(T, 2, printitn=1.5, maxiters=2) + capsys.readouterr() From 2f24969d2e1817aa51851c8b764784c40cc62b5a Mon Sep 17 00:00:00 2001 From: Danny Dunlavy Date: Sat, 16 Sep 2023 09:27:33 -0600 Subject: [PATCH 2/2] black: fixing black formatting --- tests/test_cp_als.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_cp_als.py b/tests/test_cp_als.py index f351922a..672b9871 100644 --- a/tests/test_cp_als.py +++ b/tests/test_cp_als.py @@ -214,11 +214,11 @@ def test_cp_als_tensor_printitn(capsys, sample_tensor): # default printitn ttb.cp_als(T, 2, printitn=1, maxiters=2) capsys.readouterr() - + # zero printitn ttb.cp_als(T, 2, printitn=0, maxiters=2) capsys.readouterr() - + # negative printitn ttb.cp_als(T, 2, printitn=-1, maxiters=2) capsys.readouterr()