Skip to content

Commit 2471d51

Browse files
committed
Support updated typing in numpy 2.0
1 parent 77e122e commit 2471d51

File tree

11 files changed

+48
-27
lines changed

11 files changed

+48
-27
lines changed

pyttb/cp_als.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def cp_als( # noqa: PLR0912,PLR0913,PLR0915
185185

186186
# Reduce dimorder to only those modes we will optimize
187187
dimorder_in = dimorder # save for output
188-
dimorder = [d for d in dimorder if d in optdims]
188+
dimorder = [int(d) for d in dimorder if d in optdims]
189189

190190
# Store the last MTTKRP result to accelerate fitness computation
191191
U_mttkrp = np.zeros((input_tensor.shape[dimorder[-1]], rank))

pyttb/cp_apr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -917,8 +917,8 @@ def tt_cp_apr_pqnr( # noqa: PLR0912,PLR0913,PLR0915
917917
delg = np.zeros((rank, lbfgsMem))
918918
rho = np.zeros((lbfgsMem,))
919919
lbfgsPos = 0
920-
m_rowOLD = np.empty(())
921-
gradOLD = np.empty(())
920+
m_rowOLD = np.empty((), dtype=m_row.dtype)
921+
gradOLD = np.empty((), dtype=m_row.dtype)
922922

923923
# Iteratively solve the row subproblem with projected quasi-Newton steps
924924
for i in range(maxinneriters):
@@ -983,8 +983,8 @@ def tt_cp_apr_pqnr( # noqa: PLR0912,PLR0913,PLR0915
983983
isRowNOTconverged[jj] = 1
984984

985985
# Update the L-BFGS approximation.
986-
tmp_delm = m_row - m_rowOLD
987-
tmp_delg = gradM - gradOLD
986+
tmp_delm: np.ndarray = m_row - m_rowOLD
987+
tmp_delg: np.ndarray = gradM - gradOLD
988988
tmp_delm_dot = tmp_delm.dot(tmp_delg.transpose())
989989
if not np.any(tmp_delm_dot == 0):
990990
tmp_rho = 1 / tmp_delm_dot

pyttb/gcp/fg_est.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def estimate_helper(
162162
ndim = subs.shape[1]
163163

164164
# Create exploded U's from the model factor matrices
165-
Uexp = [np.empty(())] * ndim
165+
Uexp = [np.empty((), dtype=factors[0].dtype)] * ndim
166166
for k in range(ndim):
167167
Uexp[k] = factors[k][subs[:, k], :]
168168

pyttb/gcp/samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def nonzeros(
272272

273273
# Select nonzeros
274274
if samples == nnz:
275-
nidx = np.arange(0, nnz)
275+
nidx: np.ndarray = np.arange(0, nnz, dtype=int)
276276
elif with_replacement or samples < nnz:
277277
nidx = np.random.choice(nnz, size=samples, replace=with_replacement)
278278
else:

pyttb/hosvd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def hosvd( # noqa: PLR0912,PLR0913,PLR0915
129129

130130
# Shrink!
131131
if sequential:
132-
Y = Y.ttm(factor_matrices[k].transpose(), k)
132+
Y = Y.ttm(factor_matrices[k].transpose(), int(k))
133133
# Extract final core
134134
if sequential:
135135
G = Y

pyttb/ktensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,9 @@ def mask(self, W: Union[ttb.tensor, ttb.sptensor]) -> np.ndarray:
11921192
vals = vals + tmpvals
11931193
return vals
11941194

1195-
def mttkrp(self, U: Union[ktensor, Sequence[np.ndarray]], n: int) -> np.ndarray:
1195+
def mttkrp(
1196+
self, U: Union[ktensor, Sequence[np.ndarray]], n: Union[int, np.integer]
1197+
) -> np.ndarray:
11961198
"""
11971199
Matricized tensor times Khatri-Rao product for :class:`pyttb.ktensor`.
11981200

pyttb/pyttb_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ def get_index_variant(indices: IndexType) -> IndexVariant:
816816

817817

818818
def get_mttkrp_factors(
819-
U: Union[ttb.ktensor, Sequence[np.ndarray]], n: int, ndims: int
819+
U: Union[ttb.ktensor, Sequence[np.ndarray]], n: Union[int, np.integer], ndims: int
820820
) -> Sequence[np.ndarray]:
821821
"""Apply standard checks and type conversions for mttkrp factors"""
822822
if isinstance(U, ttb.ktensor):

pyttb/sptensor.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ def allsubs(self) -> np.ndarray:
415415

416416
# Generate each column of the subscripts in turn
417417
for n in range(0, self.ndims):
418-
i = o.copy()
418+
i: list[np.ndarray] = o.copy()
419419
i[n] = np.expand_dims(np.arange(0, self.shape[n]), axis=1)
420420
s[:, n] = np.squeeze(ttb.khatrirao(*i))
421421

@@ -807,15 +807,19 @@ def to_sptenmat(
807807
csize = np.array(self.shape)[cdims]
808808

809809
if rsize.size == 0:
810-
ridx = np.zeros((self.nnz, 1))
810+
ridx: np.ndarray[tuple[int, ...], np.dtype[Any]] = np.zeros(
811+
(self.nnz, 1), dtype=int
812+
)
811813
elif self.subs.size == 0:
812814
ridx = np.array([], dtype=int)
813815
else:
814816
ridx = tt_sub2ind(rsize, self.subs[:, rdims])
815817
ridx = ridx.reshape((ridx.size, 1)).astype(int)
816818

817819
if csize.size == 0:
818-
cidx = np.zeros((self.nnz, 1))
820+
cidx: np.ndarray[tuple[int, ...], np.dtype[Any]] = np.zeros(
821+
(self.nnz, 1), dtype=int
822+
)
819823
elif self.subs.size == 0:
820824
cidx = np.array([], dtype=int)
821825
else:
@@ -1243,7 +1247,9 @@ def mask(self, W: sptensor) -> np.ndarray:
12431247
vals[matching_indices] = self.vals[matching_indices]
12441248
return vals
12451249

1246-
def mttkrp(self, U: Union[ttb.ktensor, Sequence[np.ndarray]], n: int) -> np.ndarray:
1250+
def mttkrp(
1251+
self, U: Union[ttb.ktensor, Sequence[np.ndarray]], n: Union[int, np.integer]
1252+
) -> np.ndarray:
12471253
"""
12481254
Matricized tensor times Khatri-Rao product using the
12491255
:class:`pyttb.sptensor`. This is an efficient form of the matrix
@@ -1312,7 +1318,7 @@ def mttkrp(self, U: Union[ttb.ktensor, Sequence[np.ndarray]], n: int) -> np.ndar
13121318
else:
13131319
Z.append(np.array([]))
13141320
# Perform ttv multiplication
1315-
ttv = self.ttv(Z, exclude_dims=n)
1321+
ttv = self.ttv(Z, exclude_dims=int(n))
13161322
# TODO is is possible to hit the float condition here?
13171323
if isinstance(ttv, float): # pragma: no cover
13181324
V[:, r] = ttv

pyttb/sumtensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,9 @@ def innerprod(
318318
result += part.innerprod(other)
319319
return result
320320

321-
def mttkrp(self, U: Union[ttb.ktensor, List[np.ndarray]], n: int) -> np.ndarray:
321+
def mttkrp(
322+
self, U: Union[ttb.ktensor, List[np.ndarray]], n: Union[int, np.integer]
323+
) -> np.ndarray:
322324
"""
323325
Matricized tensor times Khatri-Rao product. The matrices used in the
324326
Khatri-Rao product are passed as a :class:`pyttb.ktensor` (where the

pyttb/tensor.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import logging
1010
from collections.abc import Iterable
11-
from functools import partial
1211
from itertools import combinations_with_replacement, permutations
1312
from math import factorial, prod
1413
from typing import (
@@ -20,6 +19,7 @@
2019
Sequence,
2120
Tuple,
2221
Union,
22+
cast,
2323
overload,
2424
)
2525

@@ -734,9 +734,8 @@ def issymmetric( # noqa: PLR0912
734734
sz = np.array(self.shape)
735735

736736
if grps is None:
737-
grps = np.arange(0, n)
738-
739-
if len(grps.shape) == 1:
737+
grps = np.arange(0, n)[None, :]
738+
elif len(grps.shape) == 1:
740739
grps = np.array([grps])
741740

742741
# Substantially different routines are called depending on whether the user
@@ -903,7 +902,9 @@ def mask(self, W: tensor) -> np.ndarray:
903902
# Extract those non-zero values
904903
return self.data[tuple(wsubs.transpose())]
905904

906-
def mttkrp(self, U: Union[ttb.ktensor, Sequence[np.ndarray]], n: int) -> np.ndarray:
905+
def mttkrp(
906+
self, U: Union[ttb.ktensor, Sequence[np.ndarray]], n: Union[int, np.integer]
907+
) -> np.ndarray:
907908
"""
908909
Matricized tensor times Khatri-Rao product. The matrices used in the
909910
Khatri-Rao product are passed as a :class:`pyttb.ktensor` (where the
@@ -1272,7 +1273,9 @@ def squeeze(self) -> Union[tensor, float]:
12721273
else:
12731274
idx = np.where(shapeArray > 1)
12741275
if idx[0].size == 0:
1275-
return self.data.item()
1276+
# Why is item annotated as str?
1277+
single_item: float = cast(float, self.data.item())
1278+
return single_item
12761279
return ttb.tensor(np.squeeze(self.data))
12771280

12781281
def symmetrize( # noqa: PLR0912,PLR0915
@@ -1518,7 +1521,7 @@ def ttm(
15181521
newshape = np.array(
15191522
[p, *list(shape[range(0, n)]), *list(shape[range(n + 1, self.ndims)])]
15201523
)
1521-
Y_data = np.reshape(newdata, newshape, order=self.order)
1524+
Y_data: np.ndarray = np.reshape(newdata, newshape, order=self.order)
15221525
Y_data = np.transpose(Y_data, np.argsort(order))
15231526
return ttb.tensor(Y_data, copy=False)
15241527

@@ -1774,7 +1777,7 @@ def ttsv(
17741777

17751778
# extract scalar if needed
17761779
if len(y) == 1:
1777-
y = y.item()
1780+
return cast(float, y.item())
17781781

17791782
return y
17801783
assert False, "Invalid value for version; should be None, 1, or 2"
@@ -2600,7 +2603,10 @@ def tenones(shape: Shape, order: Union[Literal["F"], Literal["C"]] = "F") -> ten
26002603
[1. 1. 1.]
26012604
[1. 1. 1.]]
26022605
"""
2603-
ones = partial(np.ones, order=order)
2606+
2607+
def ones(shape: Tuple[int, ...]) -> np.ndarray:
2608+
return np.ones(shape, order=order)
2609+
26042610
return tensor.from_function(ones, shape)
26052611

26062612

@@ -2634,7 +2640,10 @@ def tenzeros(shape: Shape, order: Union[Literal["F"], Literal["C"]] = "F") -> te
26342640
[0. 0. 0.]
26352641
[0. 0. 0.]]
26362642
"""
2637-
zeros = partial(np.zeros, order=order)
2643+
2644+
def zeros(shape: Tuple[int, ...]) -> np.ndarray:
2645+
return np.zeros(shape, order=order)
2646+
26382647
return tensor.from_function(zeros, shape)
26392648

26402649

0 commit comments

Comments
 (0)