@@ -2272,17 +2272,42 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
2272
2272
In particular, the details of float-to-int and int-to-float casts are
2273
2273
implementation dependent.
2274
2274
""" )
2275
- def astype (x : ArrayLike , dtype : DTypeLike | None , / , * , copy : bool = True ) -> Array :
2275
+ def astype (x : ArrayLike , dtype : DTypeLike | None ,
2276
+ / , * , copy : bool = False ,
2277
+ device : xc .Device | Sharding | None = None ) -> Array :
2276
2278
util .check_arraylike ("astype" , x )
2277
2279
x_arr = asarray (x )
2278
- del copy # unused in JAX
2280
+
2279
2281
if dtype is None :
2280
2282
dtype = dtypes .canonicalize_dtype (float_ )
2281
2283
dtypes .check_user_dtype_supported (dtype , "astype" )
2282
- # convert_element_type(complex, bool) has the wrong semantics.
2283
- if np .dtype (dtype ) == bool and issubdtype (x_arr .dtype , complexfloating ):
2284
- return (x_arr != _lax_const (x_arr , 0 ))
2285
- return lax .convert_element_type (x_arr , dtype )
2284
+ if issubdtype (x_arr .dtype , complexfloating ):
2285
+ if dtypes .isdtype (dtype , ("integral" , "real floating" )):
2286
+ warnings .warn (
2287
+ "Casting from complex to real dtypes will soon raise a ValueError. "
2288
+ "Please first use jnp.real or jnp.imag to take the real/imaginary "
2289
+ "component of your input." ,
2290
+ DeprecationWarning , stacklevel = 2
2291
+ )
2292
+ elif np .dtype (dtype ) == bool :
2293
+ # convert_element_type(complex, bool) has the wrong semantics.
2294
+ x_arr = (x_arr != _lax_const (x_arr , 0 ))
2295
+
2296
+ # We offer a more specific warning than the usual ComplexWarning so we prefer
2297
+ # to issue our warning.
2298
+ with warnings .catch_warnings ():
2299
+ warnings .simplefilter ("ignore" , ComplexWarning )
2300
+ return _place_array (
2301
+ lax .convert_element_type (x_arr , dtype ),
2302
+ device = device , copy = copy ,
2303
+ )
2304
+
2305
+ def _place_array (x , device = None , copy = None ):
2306
+ # TODO(micky774): Implement in future PRs as we formalize device placement
2307
+ # semantics
2308
+ if copy :
2309
+ return _array_copy (x )
2310
+ return x
2286
2311
2287
2312
2288
2313
@util .implements (np .asarray , lax_description = _ARRAY_DOC )
0 commit comments