|
8 | 8 |
|
9 | 9 | import logging |
10 | 10 | from collections.abc import Iterable |
11 | | -from functools import partial |
12 | 11 | from itertools import combinations_with_replacement, permutations |
13 | 12 | from math import factorial, prod |
14 | 13 | from typing import ( |
|
20 | 19 | Sequence, |
21 | 20 | Tuple, |
22 | 21 | Union, |
| 22 | + cast, |
23 | 23 | overload, |
24 | 24 | ) |
25 | 25 |
|
@@ -734,9 +734,8 @@ def issymmetric( # noqa: PLR0912 |
734 | 734 | sz = np.array(self.shape) |
735 | 735 |
|
736 | 736 | if grps is None: |
737 | | - grps = np.arange(0, n) |
738 | | - |
739 | | - if len(grps.shape) == 1: |
| 737 | + grps = np.arange(0, n)[None, :] |
| 738 | + elif len(grps.shape) == 1: |
740 | 739 | grps = np.array([grps]) |
741 | 740 |
|
742 | 741 | # Substantially different routines are called depending on whether the user |
@@ -903,7 +902,9 @@ def mask(self, W: tensor) -> np.ndarray: |
903 | 902 | # Extract those non-zero values |
904 | 903 | return self.data[tuple(wsubs.transpose())] |
905 | 904 |
|
906 | | - def mttkrp(self, U: Union[ttb.ktensor, Sequence[np.ndarray]], n: int) -> np.ndarray: |
| 905 | + def mttkrp( |
| 906 | + self, U: Union[ttb.ktensor, Sequence[np.ndarray]], n: Union[int, np.integer] |
| 907 | + ) -> np.ndarray: |
907 | 908 | """ |
908 | 909 | Matricized tensor times Khatri-Rao product. The matrices used in the |
909 | 910 | Khatri-Rao product are passed as a :class:`pyttb.ktensor` (where the |
@@ -1272,7 +1273,9 @@ def squeeze(self) -> Union[tensor, float]: |
1272 | 1273 | else: |
1273 | 1274 | idx = np.where(shapeArray > 1) |
1274 | 1275 | if idx[0].size == 0: |
1275 | | - return self.data.item() |
| 1276 | + # Why is item annotated as str? |
| 1277 | + single_item: float = cast(float, self.data.item()) |
| 1278 | + return single_item |
1276 | 1279 | return ttb.tensor(np.squeeze(self.data)) |
1277 | 1280 |
|
1278 | 1281 | def symmetrize( # noqa: PLR0912,PLR0915 |
@@ -1518,7 +1521,7 @@ def ttm( |
1518 | 1521 | newshape = np.array( |
1519 | 1522 | [p, *list(shape[range(0, n)]), *list(shape[range(n + 1, self.ndims)])] |
1520 | 1523 | ) |
1521 | | - Y_data = np.reshape(newdata, newshape, order=self.order) |
| 1524 | + Y_data: np.ndarray = np.reshape(newdata, newshape, order=self.order) |
1522 | 1525 | Y_data = np.transpose(Y_data, np.argsort(order)) |
1523 | 1526 | return ttb.tensor(Y_data, copy=False) |
1524 | 1527 |
|
@@ -1774,7 +1777,7 @@ def ttsv( |
1774 | 1777 |
|
1775 | 1778 | # extract scalar if needed |
1776 | 1779 | if len(y) == 1: |
1777 | | - y = y.item() |
| 1780 | + return cast(float, y.item()) |
1778 | 1781 |
|
1779 | 1782 | return y |
1780 | 1783 | assert False, "Invalid value for version; should be None, 1, or 2" |
@@ -2600,7 +2603,10 @@ def tenones(shape: Shape, order: Union[Literal["F"], Literal["C"]] = "F") -> ten |
2600 | 2603 | [1. 1. 1.] |
2601 | 2604 | [1. 1. 1.]] |
2602 | 2605 | """ |
2603 | | - ones = partial(np.ones, order=order) |
| 2606 | + |
| 2607 | + def ones(shape: Tuple[int, ...]) -> np.ndarray: |
| 2608 | + return np.ones(shape, order=order) |
| 2609 | + |
2604 | 2610 | return tensor.from_function(ones, shape) |
2605 | 2611 |
|
2606 | 2612 |
|
@@ -2634,7 +2640,10 @@ def tenzeros(shape: Shape, order: Union[Literal["F"], Literal["C"]] = "F") -> te |
2634 | 2640 | [0. 0. 0.] |
2635 | 2641 | [0. 0. 0.]] |
2636 | 2642 | """ |
2637 | | - zeros = partial(np.zeros, order=order) |
| 2643 | + |
| 2644 | + def zeros(shape: Tuple[int, ...]) -> np.ndarray: |
| 2645 | + return np.zeros(shape, order=order) |
| 2646 | + |
2638 | 2647 | return tensor.from_function(zeros, shape) |
2639 | 2648 |
|
2640 | 2649 |
|
|
0 commit comments