Skip to content

Commit c8bffcb

Browse files
authored
Rand generators (#100)
* Non-functional change: * Fix numpy deprecation warning, logic should be equivalent * Tenrand initial implementation * Sptenrand initial implementation
1 parent 05336d1 commit c8bffcb

File tree

5 files changed

+123
-5
lines changed

5 files changed

+123
-5
lines changed

pyttb/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
from pyttb.ktensor import ktensor
1616
from pyttb.pyttb_utils import *
1717
from pyttb.sptenmat import sptenmat
18-
from pyttb.sptensor import sptendiag, sptensor
18+
from pyttb.sptensor import sptendiag, sptenrand, sptensor
1919
from pyttb.sptensor3 import sptensor3
2020
from pyttb.sumtensor import sumtensor
2121
from pyttb.symktensor import symktensor
2222
from pyttb.symtensor import symtensor
2323
from pyttb.tenmat import tenmat
24-
from pyttb.tensor import tendiag, tenones, tensor, tenzeros
24+
from pyttb.tensor import tendiag, tenones, tenrand, tensor, tenzeros
2525
from pyttb.ttensor import ttensor
2626
from pyttb.tucker_als import tucker_als
2727

pyttb/sptensor.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def from_tensor_type(
196196
@classmethod
197197
def from_function(
198198
cls,
199-
function_handle: Callable[[Tuple[float, float]], np.ndarray],
199+
function_handle: Callable[[Tuple[int, ...]], np.ndarray],
200200
shape: Tuple[int, ...],
201201
nonzeros: float,
202202
) -> sptensor:
@@ -1049,7 +1049,7 @@ def scale(self, factor: np.ndarray, dims: Union[float, np.ndarray]) -> sptensor:
10491049

10501050
if isinstance(factor, ttb.tensor):
10511051
shapeArray = np.array(self.shape)
1052-
if np.any(factor.shape != shapeArray[dims]):
1052+
if not np.array_equal(factor.shape, shapeArray[dims]):
10531053
assert False, "Size mismatch in scale"
10541054
return ttb.sptensor.from_data(
10551055
self.subs,
@@ -1058,7 +1058,7 @@ def scale(self, factor: np.ndarray, dims: Union[float, np.ndarray]) -> sptensor:
10581058
)
10591059
if isinstance(factor, ttb.sptensor):
10601060
shapeArray = np.array(self.shape)
1061-
if np.any(factor.shape != shapeArray[dims]):
1061+
if not np.array_equal(factor.shape, shapeArray[dims]):
10621062
assert False, "Size mismatch in scale"
10631063
return ttb.sptensor.from_data(
10641064
self.subs, self.vals * factor.extract(self.subs[:, dims]), self.shape
@@ -2585,6 +2585,55 @@ def ttm(
25852585
return ttb.tensor.from_tensor_type(Ynt)
25862586

25872587

2588+
def sptenrand(
2589+
shape: Tuple[int, ...],
2590+
density: Optional[float] = None,
2591+
nonzeros: Optional[float] = None,
2592+
) -> sptensor:
2593+
"""
2594+
Create sptensor with entries drawn from a uniform distribution on the unit interval
2595+
2596+
Parameters
2597+
----------
2598+
shape: Shape of resulting tensor
2599+
density: Density of resulting sparse tensor
2600+
nonzeros: Number of nonzero entries in resulting sparse tensor
2601+
2602+
Returns
2603+
-------
2604+
Constructed tensor
2605+
2606+
Example
2607+
-------
2608+
>>> X = ttb.sptenrand((2,2), nonzeros=1)
2609+
>>> Y = ttb.sptenrand((2,2), density=0.25)
2610+
"""
2611+
if density is None and nonzeros is None:
2612+
raise ValueError("Must set either density or nonzeros")
2613+
2614+
if density is not None and nonzeros is not None:
2615+
raise ValueError("Must set either density or nonzeros but not both")
2616+
2617+
if density is not None and not 0 < density <= 1:
2618+
raise ValueError(f"Density must be a fraction (0, 1] but received {density}")
2619+
2620+
if isinstance(density, float):
2621+
valid_nonzeros = float(np.prod(shape) * density)
2622+
elif isinstance(nonzeros, (int, float)):
2623+
valid_nonzeros = nonzeros
2624+
else: # pragma: no cover
2625+
raise ValueError(
2626+
f"Incorrect types for density:{density} and nonzeros:{nonzeros}"
2627+
)
2628+
2629+
# Typing doesn't play nice with partial
2630+
# mypy issue: 1484
2631+
def unit_uniform(pass_through_shape: Tuple[int, ...]) -> np.ndarray:
2632+
return np.random.uniform(low=0, high=1, size=pass_through_shape)
2633+
2634+
return ttb.sptensor.from_function(unit_uniform, shape, valid_nonzeros)
2635+
2636+
25882637
def sptendiag(
25892638
elements: np.ndarray, shape: Optional[Tuple[int, ...]] = None
25902639
) -> sptensor:

pyttb/tensor.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1873,6 +1873,31 @@ def tenzeros(shape: Tuple[int, ...]) -> tensor:
18731873
return tensor.from_function(np.zeros, shape)
18741874

18751875

1876+
def tenrand(shape: Tuple[int, ...]) -> tensor:
1877+
"""
1878+
Creates a tensor with entries drawn from a uniform distribution on the unit interval
1879+
1880+
Parameters
1881+
----------
1882+
shape: Shape of resulting tensor
1883+
1884+
Returns
1885+
-------
1886+
Constructed tensor
1887+
1888+
Example
1889+
-------
1890+
>>> X = ttb.tenrand((2,2))
1891+
"""
1892+
1893+
# Typing doesn't play nice with partial
1894+
# mypy issue: 1484
1895+
def unit_uniform(pass_through_shape: Tuple[int, ...]) -> np.ndarray:
1896+
return np.random.uniform(low=0, high=1, size=pass_through_shape)
1897+
1898+
return tensor.from_function(unit_uniform, shape)
1899+
1900+
18761901
def tendiag(elements: np.ndarray, shape: Optional[Tuple[int, ...]] = None) -> tensor:
18771902
"""
18781903
Creates a tensor with elements along super diagonal

tests/test_sptensor.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,3 +1789,38 @@ def test_sptendiag():
17891789
for i in range(N):
17901790
diag_index = (i,) * N
17911791
assert X[diag_index] == i
1792+
1793+
1794+
def test_sptenrand():
1795+
arbitrary_shape = (3, 3, 3)
1796+
rand_tensor = ttb.sptenrand(arbitrary_shape, nonzeros=1)
1797+
in_unit_interval = np.all(0 <= rand_tensor.vals <= 1)
1798+
assert (
1799+
in_unit_interval
1800+
and rand_tensor.shape == arbitrary_shape
1801+
and rand_tensor.nnz == 1
1802+
)
1803+
1804+
rand_tensor = ttb.sptenrand(arbitrary_shape, density=1 / np.prod(arbitrary_shape))
1805+
in_unit_interval = np.all(0 <= rand_tensor.vals <= 1)
1806+
assert (
1807+
in_unit_interval
1808+
and rand_tensor.shape == arbitrary_shape
1809+
and rand_tensor.nnz == 1
1810+
)
1811+
1812+
# Negative tests
1813+
# Bad density
1814+
with pytest.raises(ValueError):
1815+
ttb.sptenrand(arbitrary_shape, density=-1)
1816+
ttb.sptenrand(arbitrary_shape, density=2)
1817+
1818+
# Missing args
1819+
# Bad density
1820+
with pytest.raises(ValueError):
1821+
ttb.sptenrand(arbitrary_shape)
1822+
1823+
# Redundant/contradicting args
1824+
# Bad density
1825+
with pytest.raises(ValueError):
1826+
ttb.sptenrand(arbitrary_shape, density=0.5, nonzeros=2)

tests/test_tensor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,6 +1643,15 @@ def test_tenzeros():
16431643
assert np.equal(zeros_tensor, data_tensor), "Tenzeros should match all zeros tensor"
16441644

16451645

1646+
def test_tenrand():
1647+
arbitrary_shape = (3, 3, 3)
1648+
rand_tensor = ttb.tenrand(arbitrary_shape)
1649+
in_unit_interval = np.all((rand_tensor >= 0).data) and np.all(
1650+
(rand_tensor <= 1).data
1651+
)
1652+
assert in_unit_interval and rand_tensor.shape == arbitrary_shape
1653+
1654+
16461655
def test_tendiag():
16471656
N = 4
16481657
elements = np.arange(0, N)

0 commit comments

Comments
 (0)