2
2
3
3
import cupy as cp
4
4
5
- from ..common import _aliases
5
+ from ..common import _aliases , _helpers
6
6
from .._internal import get_xp
7
7
8
8
from ._info import __array_namespace_info__
46
46
unique_counts = get_xp (cp )(_aliases .unique_counts )
47
47
unique_inverse = get_xp (cp )(_aliases .unique_inverse )
48
48
unique_values = get_xp (cp )(_aliases .unique_values )
49
- astype = _aliases .astype
50
49
std = get_xp (cp )(_aliases .std )
51
50
var = get_xp (cp )(_aliases .var )
52
51
cumulative_sum = get_xp (cp )(_aliases .cumulative_sum )
@@ -110,6 +109,21 @@ def asarray(
110
109
111
110
return cp .array (obj , dtype = dtype , ** kwargs )
112
111
112
+
113
+ def astype (
114
+ x : ndarray ,
115
+ dtype : Dtype ,
116
+ / ,
117
+ * ,
118
+ copy : bool = True ,
119
+ device : Optional [Device ] = None ,
120
+ ) -> ndarray :
121
+ if device is None :
122
+ return x .astype (dtype = dtype , copy = copy )
123
+ out = _helpers .to_device (x .astype (dtype = dtype , copy = False ), device )
124
+ return out .copy () if copy and out is x else out
125
+
126
+
113
127
# These functions are completely new here. If the library already has them
114
128
# (i.e., numpy 2.0), use the library version instead of our wrapper.
115
129
if hasattr (cp , 'vecdot' ):
@@ -127,10 +141,10 @@ def asarray(
127
141
else :
128
142
unstack = get_xp (cp )(_aliases .unstack )
129
143
130
- __all__ = _aliases .__all__ + ['__array_namespace_info__' , 'asarray' , 'bool ' ,
144
+ __all__ = _aliases .__all__ + ['__array_namespace_info__' , 'asarray' , 'astype ' ,
131
145
'acos' , 'acosh' , 'asin' , 'asinh' , 'atan' ,
132
146
'atan2' , 'atanh' , 'bitwise_left_shift' ,
133
147
'bitwise_invert' , 'bitwise_right_shift' ,
134
- 'concat' , 'pow' , 'sign' ]
148
+ 'bool' , ' concat' , 'pow' , 'sign' ]
135
149
136
150
_all_ignore = ['cp' , 'get_xp' ]
0 commit comments