Skip to content

Commit 95fc403

Browse files
ntjohnson1dmdunla
andauthored
Teneye: Add preliminary implementation and tests (#222)
* Teneye: Add preliminary implementation and tests * Black: Re-run black * Added to local test suite in other PR --------- Co-authored-by: Danny Dunlavy <[email protected]>
1 parent 12b8829 commit 95fc403

File tree

3 files changed

+61
-2
lines changed

3 files changed

+61
-2
lines changed

pyttb/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pyttb.symktensor import symktensor
2525
from pyttb.symtensor import symtensor
2626
from pyttb.tenmat import tenmat
27-
from pyttb.tensor import tendiag, tenones, tenrand, tensor, tenzeros
27+
from pyttb.tensor import tendiag, teneye, tenones, tenrand, tensor, tenzeros
2828
from pyttb.ttensor import ttensor
2929
from pyttb.tucker_als import tucker_als
3030

@@ -57,6 +57,7 @@ def ignore_warnings(ignore=True):
5757
sumtensor.__name__,
5858
symktensor.__name__,
5959
symtensor.__name__,
60+
teneye.__name__,
6061
tenmat.__name__,
6162
tendiag.__name__,
6263
tenones.__name__,

pyttb/tensor.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import logging
99
from collections.abc import Iterable
10-
from itertools import permutations
10+
from itertools import combinations_with_replacement, permutations
1111
from math import factorial
1212
from typing import Any, Callable, List, Optional, Tuple, Union
1313

@@ -2553,6 +2553,53 @@ def tendiag(elements: np.ndarray, shape: Optional[Tuple[int, ...]] = None) -> te
25532553
return X
25542554

25552555

2556+
def teneye(order: int, size: int) -> tensor:
2557+
"""Create identity tensor of specified shape.
2558+
2559+
T is an "identity tensor if T.ttsv(x, skip_dim=0) = x for all x such that
2560+
norm(x) == 1.
2561+
2562+
An identity tensor only exists if order is even.
2563+
This method is resource intensive
2564+
for even moderate orders or sizes (>=6).
2565+
2566+
Parameters
2567+
----------
2568+
order: Number of dimensions of tensor.
2569+
size: Number of elements in any dimension of the tensor.
2570+
2571+
Examples
2572+
--------
2573+
>>> ttb.teneye(2, 3)
2574+
tensor of shape (3, 3)
2575+
data[:, :] =
2576+
[[1. 0. 0.]
2577+
[0. 1. 0.]
2578+
[0. 0. 1.]]
2579+
>>> x = np.ones((5,))
2580+
>>> x /= np.linalg.norm(x)
2581+
>>> T = ttb.teneye(4, 5)
2582+
>>> np.allclose(T.ttsv(x, 0), x)
2583+
True
2584+
2585+
Returns
2586+
-------
2587+
Identity tensor.
2588+
"""
2589+
if order % 2 != 0:
2590+
raise ValueError(f"Order must be even but received {order}")
2591+
idx_iterator = combinations_with_replacement(range(size), order)
2592+
A = tenzeros((size,) * order)
2593+
s = np.zeros((factorial(order), order // 2))
2594+
for _i, indices in enumerate(idx_iterator):
2595+
p = np.array(list(permutations(indices)))
2596+
for j in range(order // 2):
2597+
s[:, j] = p[:, 2 * j - 1] == p[:, 2 * j]
2598+
v = np.sum(np.sum(s, axis=1) == order // 2)
2599+
A[tuple(zip(*p))] = v / factorial(order)
2600+
return A
2601+
2602+
25562603
def mttv_left(W_in: np.ndarray, U1: np.ndarray) -> np.ndarray:
25572604
"""
25582605
Contract leading mode in partial MTTKRP W_in using factor matrix U1.

tests/test_tensor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,6 +1708,17 @@ def test_tendiag():
17081708
assert X[diag_index] == i
17091709

17101710

1711+
def test_teneye():
1712+
with pytest.raises(ValueError):
1713+
ttb.teneye(1, 0)
1714+
size = 5
1715+
order = 4
1716+
T = ttb.teneye(order, size)
1717+
x = np.random.random((size,))
1718+
x = x / np.linalg.norm(x)
1719+
np.testing.assert_almost_equal(T.ttsv(x, 0), x)
1720+
1721+
17111722
def test_mttv_left():
17121723
m1 = 2
17131724
mi = [range(1, 4)]

0 commit comments

Comments
 (0)