Skip to content

Commit b57a19b

Browse files
ntjohnson1dmdunla
andauthored
Sparse logical (#269)
* TENSOR: Logical operations should yield indictator tensors of original dtype: * Also use our utility functions for sample data * SPTENSOR: Logical operations should yield indictator tensors of original dtype * SPTENSOR: Fix exact quality logical comparison bug --------- Co-authored-by: Danny Dunlavy <[email protected]>
1 parent 1fb71c8 commit b57a19b

File tree

4 files changed

+79
-55
lines changed

4 files changed

+79
-55
lines changed

pyttb/sptensor.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ def logical_and(self, B: Union[float, sptensor, ttb.tensor]) -> sptensor:
703703
if B == 0:
704704
C = sptensor(shape=self.shape)
705705
else:
706-
newvals = self.vals == B
706+
newvals = np.ones_like(self.vals)
707707
C = sptensor(self.subs, newvals, self.shape)
708708
return C
709709
# Case 2: Argument is a tensor of some sort
@@ -718,6 +718,7 @@ def logical_and(self, B: Union[float, sptensor, ttb.tensor]) -> sptensor:
718718
self.shape,
719719
lambda x: len(x) == 2,
720720
)
721+
C.vals = C.vals.astype(self.vals.dtype)
721722

722723
return C
723724

@@ -741,7 +742,7 @@ def logical_not(self) -> sptensor:
741742
allsubs = self.allsubs()
742743
subsIdx = tt_setdiff_rows(allsubs, self.subs)
743744
subs = allsubs[subsIdx]
744-
trueVector = np.ones(shape=(subs.shape[0], 1), dtype=bool)
745+
trueVector = np.ones(shape=(subs.shape[0], 1), dtype=self.vals.dtype)
745746
return sptensor(subs, trueVector, self.shape)
746747

747748
@overload
@@ -771,12 +772,14 @@ def logical_or(
771772
assert False, "Logical Or requires tensors of the same size"
772773

773774
if isinstance(B, ttb.sptensor):
774-
return sptensor.from_aggregator(
775+
C = sptensor.from_aggregator(
775776
np.vstack((self.subs, B.subs)),
776777
np.ones((self.subs.shape[0] + B.subs.shape[0], 1)),
777778
self.shape,
778779
lambda x: len(x) >= 1,
779780
)
781+
C.vals = C.vals.astype(self.vals.dtype)
782+
return C
780783

781784
assert False, "Sptensor Logical Or argument must be scalar or sptensor"
782785

@@ -814,9 +817,11 @@ def logical_xor(
814817
assert False, "Logical XOR requires tensors of the same size"
815818

816819
subs = np.vstack((self.subs, other.subs))
817-
return ttb.sptensor.from_aggregator(
820+
result = ttb.sptensor.from_aggregator(
818821
subs, np.ones((len(subs), 1)), self.shape, lambda x: len(x) == 1
819822
)
823+
result.vals = result.vals.astype(self.vals.dtype)
824+
return result
820825

821826
assert False, "The argument must be an sptensor, tensor or scalar"
822827

pyttb/tensor.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -643,13 +643,13 @@ def logical_and(self, other: Union[float, tensor]) -> tensor:
643643
644644
Examples
645645
--------
646-
>>> T = ttb.tensor(np.ones((2,2), dtype=bool))
646+
>>> T = ttb.tenones((2,2))
647647
>>> T.logical_and(T).collapse() # All true
648-
4
648+
4.0
649649
"""
650650

651651
def logical_and(x, y):
652-
return np.logical_and(x, y)
652+
return np.logical_and(x, y).astype(dtype=x.dtype)
653653

654654
return tt_tenfun(logical_and, self, other)
655655

@@ -659,11 +659,12 @@ def logical_not(self) -> tensor:
659659
660660
Examples
661661
--------
662-
>>> T = ttb.tensor(np.ones((2,2), dtype=bool))
662+
>>> T = ttb.tenones((2,2))
663663
>>> T.logical_not().collapse() # All false
664-
0
664+
0.0
665665
"""
666-
return ttb.tensor(np.logical_not(self.data), copy=False)
666+
# Np logical not dtype argument seems to not work here
667+
return ttb.tensor(np.logical_not(self.data).astype(self.data.dtype), copy=False)
667668

668669
def logical_or(self, other: Union[float, tensor]) -> tensor:
669670
"""
@@ -676,13 +677,13 @@ def logical_or(self, other: Union[float, tensor]) -> tensor:
676677
677678
Examples
678679
--------
679-
>>> T = ttb.tensor(np.ones((2,2), dtype=bool))
680+
>>> T = ttb.tenones((2,2))
680681
>>> T.logical_or(T.logical_not()).collapse() # All true
681-
4
682+
4.0
682683
"""
683684

684685
def tensor_or(x, y):
685-
return np.logical_or(x, y)
686+
return np.logical_or(x, y).astype(x.dtype)
686687

687688
return tt_tenfun(tensor_or, self, other)
688689

@@ -697,13 +698,13 @@ def logical_xor(self, other: Union[float, tensor]) -> tensor:
697698
698699
Examples
699700
--------
700-
>>> T = ttb.tensor(np.ones((2,2), dtype=bool))
701+
>>> T = ttb.tenones((2,2))
701702
>>> T.logical_xor(T.logical_not()).collapse() # All true
702-
4
703+
4.0
703704
"""
704705

705706
def tensor_xor(x, y):
706-
return np.logical_xor(x, y)
707+
return np.logical_xor(x, y).astype(dtype=x.dtype)
707708

708709
return tt_tenfun(tensor_xor, self, other)
709710

@@ -723,7 +724,7 @@ def mask(self, W: tensor) -> np.ndarray:
723724
Examples
724725
--------
725726
>>> T = ttb.tensor(np.array([[1, 2], [3, 4]]))
726-
>>> W = ttb.tensor(np.ones((2,2)))
727+
>>> W = ttb.tenones((2,2))
727728
>>> T.mask(W)
728729
array([1, 3, 2, 4])
729730
"""
@@ -758,7 +759,7 @@ def mttkrp( # noqa: PLR0912
758759
759760
Examples
760761
--------
761-
>>> T = ttb.tensor(np.ones((2,2,2)))
762+
>>> T = ttb.tenones((2,2,2))
762763
>>> U = [np.ones((2,2))] * 3
763764
>>> T.mttkrp(U, 2)
764765
array([[4., 4.],
@@ -841,7 +842,7 @@ def mttkrps(self, U: Union[ttb.ktensor, List[np.ndarray]]) -> List[np.ndarray]:
841842
842843
Examples
843844
--------
844-
>>> T = ttb.tensor(np.ones((2,2,2)))
845+
>>> T = ttb.tenones((2,2,2))
845846
>>> U = [np.ones((2,2))] * 3
846847
>>> T.mttkrps(U)
847848
[array([[4., 4.],
@@ -876,7 +877,7 @@ def ndims(self) -> int:
876877
877878
Examples
878879
--------
879-
>>> T = ttb.tensor(np.ones((2,2)))
880+
>>> T = ttb.tenones((2,2))
880881
>>> T.ndims
881882
2
882883
"""
@@ -891,7 +892,7 @@ def nnz(self) -> int:
891892
892893
Examples
893894
--------
894-
>>> T = ttb.tensor(np.ones((2,2,2)))
895+
>>> T = ttb.tenones((2,2,2))
895896
>>> T.nnz
896897
8
897898
"""
@@ -904,7 +905,7 @@ def norm(self) -> float:
904905
905906
Examples
906907
--------
907-
>>> T = ttb.tensor(np.ones((2,2,2,2)))
908+
>>> T = ttb.tenones((2,2,2,2))
908909
>>> T.norm()
909910
4.0
910911
"""
@@ -1025,7 +1026,7 @@ def reshape(self, shape: Tuple[int, ...]) -> tensor:
10251026
10261027
Examples
10271028
--------
1028-
>>> T1 = ttb.tensor(np.ones((2,2)))
1029+
>>> T1 = ttb.tenones((2,2))
10291030
>>> T1.shape
10301031
(2, 2)
10311032
>>> T2 = T1.reshape((4,1))
@@ -1152,7 +1153,7 @@ def symmetrize( # noqa: PLR0912,PLR0915
11521153
11531154
Examples
11541155
--------
1155-
>>> T = ttb.tensor(np.ones((2,2,2)))
1156+
>>> T = ttb.tenones((2,2,2))
11561157
>>> T.symmetrize(np.array([0,2]))
11571158
tensor of shape (2, 2, 2)
11581159
data[0, :, :] =
@@ -1317,7 +1318,7 @@ def ttm(
13171318
13181319
Examples
13191320
--------
1320-
>>> T = ttb.tensor(np.ones((2,2,2,2)))
1321+
>>> T = ttb.tenones((2,2,2,2))
13211322
>>> A = 2*np.ones((2,1))
13221323
>>> T.ttm([A,A], dims=[0,1], transpose=True)
13231324
tensor of shape (1, 1, 2, 2)
@@ -1665,7 +1666,7 @@ def __setitem__(self, key, value):
16651666
16661667
Examples
16671668
--------
1668-
>>> T = tensor(np.ones((3,4,2)))
1669+
>>> T = tenones((3,4,2))
16691670
>>> # replaces subtensor
16701671
>>> T[0:2,0:2,0] = np.ones((2,2))
16711672
>>> # replaces two elements
@@ -1810,7 +1811,7 @@ def __getitem__(self, item): # noqa: PLR0912
18101811
18111812
Examples
18121813
--------
1813-
>>> T = tensor(np.ones((3,4,2,1)))
1814+
>>> T = tenones((3,4,2,1))
18141815
>>> T[0,0,0,0] # produces a scalar
18151816
1.0
18161817
>>> # produces a tensor of order 1 and size 1

tests/test_sptensor.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,14 @@ def test_sptensor_and_scalar(sample_sptensor):
174174
assert b.subs.size == 0
175175
assert b.vals.size == 0
176176
assert b.shape == data["shape"]
177+
assert b.vals.dtype == sptensorInstance.vals.dtype
177178

179+
# Sparsity pattern check not exact value equality
178180
b = sptensorInstance.logical_and(0.5)
179181
assert np.array_equal(b.subs, data["subs"])
180-
assert np.array_equal(b.vals, np.array([[True], [False], [False], [False]]))
182+
assert np.array_equal(b.vals, np.array([[True], [True], [True], [True]]))
181183
assert b.shape == data["shape"]
184+
assert b.vals.dtype == sptensorInstance.vals.dtype
182185

183186

184187
def test_sptensor_and_sptensor(sample_sptensor):
@@ -188,6 +191,7 @@ def test_sptensor_and_sptensor(sample_sptensor):
188191
assert np.array_equal(b.subs, data["subs"])
189192
assert np.array_equal(b.vals, np.array([[True], [True], [True], [True]]))
190193
assert b.shape == data["shape"]
194+
assert b.vals.dtype == sptensorInstance.vals.dtype
191195

192196
with pytest.raises(AssertionError) as excinfo:
193197
sptensorInstance.logical_and(
@@ -207,6 +211,7 @@ def test_sptensor_and_tensor(sample_sptensor):
207211
b = sptensorInstance.logical_and(sptensorInstance.to_tensor())
208212
assert np.array_equal(b.subs, data["subs"])
209213
assert np.array_equal(b.vals, np.ones(data["vals"].shape))
214+
assert b.vals.dtype == sptensorInstance.vals.dtype
210215

211216

212217
def test_sptensor_full(sample_sptensor):
@@ -685,6 +690,7 @@ def test_sptensor_logical_not(sample_sptensor):
685690
assert all(notSptensorInstance.vals == 1)
686691
assert np.array_equal(notSptensorInstance.subs, np.array(result))
687692
assert notSptensorInstance.shape == data["shape"]
693+
assert notSptensorInstance.vals.dtype == sptensorInstance.vals.dtype
688694

689695

690696
def test_sptensor_logical_or(sample_sptensor):
@@ -695,20 +701,24 @@ def test_sptensor_logical_or(sample_sptensor):
695701
assert sptensorOr.shape == data["shape"]
696702
assert np.array_equal(sptensorOr.subs, data["subs"])
697703
assert np.array_equal(sptensorOr.vals, np.ones((data["vals"].shape[0], 1)))
704+
assert sptensorOr.vals.dtype == sptensorInstance.vals.dtype
698705

699706
# Sptensor logical or with tensor
700707
sptensorOr = sptensorInstance.logical_or(sptensorInstance.to_tensor())
701708
nonZeroMatrix = np.zeros(data["shape"])
702709
nonZeroMatrix[tuple(data["subs"].transpose())] = 1
703710
assert np.array_equal(sptensorOr.data, nonZeroMatrix)
711+
assert sptensorOr.data.dtype == sptensorInstance.vals.dtype
704712

705713
# Sptensor logical or with scalar, 0
706714
sptensorOr = sptensorInstance.logical_or(0)
707715
assert np.array_equal(sptensorOr.data, nonZeroMatrix)
716+
assert sptensorOr.data.dtype == sptensorInstance.vals.dtype
708717

709718
# Sptensor logical or with scalar, not 0
710719
sptensorOr = sptensorInstance.logical_or(1)
711720
assert np.array_equal(sptensorOr.data, np.ones(data["shape"]))
721+
assert sptensorOr.data.dtype == sptensorInstance.vals.dtype
712722

713723
# Sptensor logical or with wrong shape sptensor
714724
with pytest.raises(AssertionError) as excinfo:
@@ -1165,19 +1175,23 @@ def test_sptensor_logical_xor(sample_sptensor):
11651175
# Sptensor logical xor with scalar, 0
11661176
sptensorXor = sptensorInstance.logical_xor(0)
11671177
assert np.array_equal(sptensorXor.data, nonZeroMatrix)
1178+
assert sptensorXor.data.dtype == sptensorInstance.vals.dtype
11681179

11691180
# Sptensor logical xor with scalar, not 0
11701181
sptensorXor = sptensorInstance.logical_xor(1)
11711182
assert np.array_equal(sptensorXor.data, sptensorInstance.logical_not().full().data)
1183+
assert sptensorXor.data.dtype == sptensorInstance.vals.dtype
11721184

11731185
# Sptensor logical xor with another sptensor
11741186
sptensorXor = sptensorInstance.logical_xor(sptensorInstance)
11751187
assert sptensorXor.shape == data["shape"]
11761188
assert sptensorXor.vals.size == 0
1189+
assert sptensorXor.vals.dtype == sptensorInstance.vals.dtype
11771190

11781191
# Sptensor logical xor with tensor
11791192
sptensorXor = sptensorInstance.logical_xor(sptensorInstance.to_tensor())
1180-
assert np.array_equal(sptensorXor.data, np.zeros(data["shape"], dtype=bool))
1193+
assert np.array_equal(sptensorXor.data, np.zeros(data["shape"]))
1194+
assert sptensorXor.data.dtype == sptensorInstance.vals.dtype
11811195

11821196
# Sptensor logical xor with wrong shape sptensor
11831197
with pytest.raises(AssertionError) as excinfo:

0 commit comments

Comments
 (0)