66
77from __future__ import annotations
88
9- from typing import Dict , List , Literal , Optional , Tuple , Union
9+ from typing import Dict , Literal , Optional , Tuple , Union
1010
1111import numpy as np
1212
1313import pyttb as ttb
14+ from pyttb .pyttb_utils import OneDArray , parse_one_d
1415
1516
1617def cp_als ( # noqa: PLR0912,PLR0913,PLR0915
1718 input_tensor : Union [ttb .tensor , ttb .sptensor , ttb .ttensor , ttb .sumtensor ],
1819 rank : int ,
1920 stoptol : float = 1e-4 ,
2021 maxiters : int = 1000 ,
21- dimorder : Optional [List [ int ] ] = None ,
22- optdims : Optional [List [ int ] ] = None ,
22+ dimorder : Optional [OneDArray ] = None ,
23+ optdims : Optional [OneDArray ] = None ,
2324 init : Union [Literal ["random" ], Literal ["nvecs" ], ttb .ktensor ] = "random" ,
2425 printitn : int = 1 ,
2526 fixsigns : bool = True ,
@@ -109,8 +110,8 @@ def cp_als( # noqa: PLR0912,PLR0913,PLR0915
109110 [[0.1467... 0.0923...]
110111 [0.1862... 0.3455...]]
111112 >>> print(output["params"]) # doctest: +NORMALIZE_WHITESPACE
112- {'stoptol': 0.0001, 'maxiters': 1000, 'dimorder': [0, 1],\
113- 'optdims': [0, 1], 'printitn': 1, 'fixsigns': True}
113+ {'stoptol': 0.0001, 'maxiters': 1000, 'dimorder': array( [0, 1]) ,\
114+ 'optdims': array( [0, 1]) , 'printitn': 1, 'fixsigns': True}
114115
115116 Example using "nvecs" initialization:
116117
@@ -135,15 +136,17 @@ def cp_als( # noqa: PLR0912,PLR0913,PLR0915
135136
136137 # Set up dimorder if not specified
137138 if dimorder is None :
138- dimorder = list ( range ( N ) )
139- elif not isinstance ( dimorder , list ) :
140- assert False , "Dimorder must be a list"
141- elif tuple (range (N )) != tuple (sorted (dimorder )):
139+ dimorder = np . arange ( N )
140+ else :
141+ dimorder = parse_one_d ( dimorder )
142+ if tuple (range (N )) != tuple (sorted (dimorder )):
142143 assert False , "Dimorder must be a list or permutation of range(tensor.ndims)"
143144
144145 # Set up optdims if not specified
145146 if optdims is None :
146- optdims = list (range (N ))
147+ optdims = np .arange (N )
148+ else :
149+ optdims = parse_one_d (optdims )
147150
148151 # Error checking
149152 assert rank > 0 , "Number of components requested must be positive"
0 commit comments