In [1]:
import time
import numpy as np

In [2]:
def old_mttkrp(Ul, Y, Ur, szl, szr, szn, R):
    Y = np.reshape(Y, (-1, szr))
    Y = Y @ Ul
    Y = np.reshape(Y, (szl, szn, R))
    Ur = Ur.reshape((szl, 1, R))
    V = np.zeros((szn, R))
    for r in range(R):
        V[:, [r]] = Y[:, :, r].T @ Ur[:, :, r]
    return V

In [3]:
def new_mttkrp(Ul, Y, Ur, szl, szr, szn, R):
    Y = np.reshape(Y, (-1, szr))
    Y = Y.reshape(-1, szr) @ Ul
    Y = np.reshape(Y, (szl, szn, R))
    V = np.einsum('ijk, ik -> jk', Y, Ur)
    return V

In [4]:
def get_tens(n, R, shape):
    szl = np.prod(shape[0:n])
    szr = np.prod(shape[n + 1 :])
    szn = shape[n]
    
    Ul = np.random.rand(szr, R)
    Ur = np.random.rand(szl, R)
    Y = np.random.rand(*shape)
    return Ul, Ur, Y, szl, szr, szn

In [5]:
def time_mttkrp(n, r, shape=[20,30,40,3]):
    Ul, Ur, Y, szl, szr, szn = get_tens(n,r,shape)
    print("new mttkrp correct", np.all(old_mttkrp(Ul, Y, Ur, szl, szr, szn, r) == new_mttkrp(Ul, Y, Ur,szl, szr, szn, r)))
    old_time = []
    new_time = []
    for _ in range(10):
        old_start = time.time()
        old_mttkrp(Ul, Y, Ur, szl, szr, szn, r)
        old_end = new_start = time.time()
        new_mttkrp(Ul, Y, Ur, szl, szr, szn, r)
        new_end = time.time()
        old_time.append(old_end - old_start)
        new_time.append(new_end - new_start)
    return old_time, new_time

In [6]:
print("shape", [20,30,40,50])
print("---------------------")
for n in [1,2]:
    print(f"mode-{1} mttkrp")
    for r in [2, 10, 50, 100, 200]:
        print(f"rank-{r}")
        old_time, new_time = time_mttkrp(n,r)
        print("old", np.array(old_time)[1:].mean())
        print("new", np.array(new_time)[1:].mean())
        print("speedup", np.array(old_time)[1:].mean() / np.array(new_time)[1:].mean())
        print("-----------------------------------------------------------------------")

shape [20, 30, 40, 50]
---------------------
mode-1 mttkrp
rank-2
new mttkrp correct True
old 3.144476148817274e-05
new 2.8769175211588543e-05
speedup 1.0930018416206262
-----------------------------------------------------------------------
rank-10
new mttkrp correct True
old 7.865164015028212e-05
new 5.963113572862413e-05
speedup 1.3189693469569082
-----------------------------------------------------------------------
rank-50
new mttkrp correct True
old 0.00038048956129286025
new 9.00162590874566e-05
speedup 4.2268981753972925
-----------------------------------------------------------------------
rank-100
new mttkrp correct True
old 0.0007970068189832899
new 0.0008294317457411024
speedup 0.9609070584477802
-----------------------------------------------------------------------
rank-200
new mttkrp correct True
old 0.0012705855899386937
new 0.002744701173570421
speedup 0.46292310513565427
-----------------------------------------------------------------------
mode-1 mttkrp
rank-2
new