diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ffb031a5f2bc..5a5511b785c3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3364,13 +3364,6 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1, dtypes.check_user_dtype_supported(dtype, "trace") a_shape = shape(a) - if dtype is None: - dtype = _dtype(a) - if issubdtype(dtype, integer): - default_int = dtypes.canonicalize_dtype(int) - if iinfo(dtype).bits < iinfo(default_int).bits: - dtype = default_int - a = moveaxis(a, (axis1, axis2), (-2, -1)) # Mask out the diagonal and reduce. diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 8c5594cb4541..a6f4e1f7b8e8 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -26,7 +26,7 @@ from jax import lax from jax._src import api -from jax._src import core, config +from jax._src import core from jax._src import dtypes from jax._src.numpy import ufuncs from jax._src.numpy.util import ( @@ -65,6 +65,20 @@ def _upcast_f16(dtype: DTypeLike) -> DType: return np.dtype('float32') return np.dtype(dtype) +def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike: + # Note: NumPy always promotes to 64-bit; jax instead promotes to the + # default dtype as defined by dtypes.int_ or dtypes.uint. + if dtypes.issubdtype(dtype, np.bool_): + return dtypes.int_ + elif dtypes.issubdtype(dtype, np.unsignedinteger): + if np.iinfo(dtype).bits < np.iinfo(dtypes.uint).bits: + return dtypes.uint + elif dtypes.issubdtype(dtype, np.integer): + if np.iinfo(dtype).bits < np.iinfo(dtypes.int_).bits: + return dtypes.int_ + return dtype + + ReductionOp = Callable[[Any, Any], Any] def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike, @@ -103,16 +117,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: result_dtype = dtype or dtypes.dtype(a) if dtype is None and promote_integers: - # Note: NumPy always promotes to 64-bit; jax instead promotes to the - # default dtype as defined by dtypes.int_ or dtypes.uint. - if dtypes.issubdtype(result_dtype, np.bool_): - result_dtype = dtypes.int_ - elif dtypes.issubdtype(result_dtype, np.unsignedinteger): - if np.iinfo(result_dtype).bits < np.iinfo(dtypes.uint).bits: - result_dtype = dtypes.uint - elif dtypes.issubdtype(result_dtype, np.integer): - if np.iinfo(result_dtype).bits < np.iinfo(dtypes.int_).bits: - result_dtype = dtypes.int_ + result_dtype = _promote_integer_dtype(result_dtype) result_dtype = dtypes.canonicalize_dtype(result_dtype) @@ -663,7 +668,8 @@ def __call__(self, a: ArrayLike, axis: Axis = None, """ def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array], - fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction: + fill_nan: bool = False, fill_value: ArrayLike = 0, + promote_integers: bool = False) -> CumulativeReduction: @implements(np_reduction, skip_params=['out'], lax_description=CUML_REDUCTION_LAX_DESCRIPTION) def cumulative_reduction(a: ArrayLike, axis: Axis = None, @@ -691,12 +697,18 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None, if fill_nan: a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a) - if not dtype and dtypes.dtype(a) == np.bool_: - dtype = dtypes.canonicalize_dtype(dtypes.int_) - if dtype: - a = lax.convert_element_type(a, dtype) + result_type: DTypeLike = dtypes.dtype(dtype or a) + if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_): + result_type = _promote_integer_dtype(result_type) + result_type = dtypes.canonicalize_dtype(result_type) + + a = lax.convert_element_type(a, result_type) + result = reduction(a, axis) - return reduction(a, axis) + # We downcast to boolean because we accumulate in integer types + if dtypes.issubdtype(dtype, np.bool_): + result = lax.convert_element_type(result, np.bool_) + return result return cumulative_reduction @@ -707,6 +719,9 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None, fill_nan=True, fill_value=0) nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod, fill_nan=True, fill_value=1) +_cumsum_with_promotion = _make_cumulative_reduction( + np.cumsum, lax.cumsum, fill_nan=False, promote_integers=True +) @implements(getattr(np, 'cumulative_sum', None)) def cumulative_sum( @@ -730,12 +745,7 @@ def cumulative_sum( axis = _canonicalize_axis(axis, x.ndim) dtypes.check_user_dtype_supported(dtype) - kind = x.dtype.kind - if (dtype is None and kind in {'i', 'u'} - and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)): - dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind]) - x = x.astype(dtype=dtype or x.dtype) - out = cumsum(x, axis=axis) + out = _cumsum_with_promotion(x, axis=axis, dtype=dtype) if include_initial: zeros_shape = list(x.shape) zeros_shape[axis] = 1 diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index 988840c31381..248c1c6dd0fe 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -76,18 +76,3 @@ def finfo(type, /) -> FInfo: smallest_normal=float(info.smallest_normal), dtype=jnp.dtype(type) ) - -# TODO(micky774): Update utility to only promote integral types -def _promote_to_default_dtype(x): - if x.dtype.kind == 'b': - return x - elif x.dtype.kind == 'i': - return x.astype(jnp.int_) - elif x.dtype.kind == 'u': - return x.astype(jnp.uint) - elif x.dtype.kind == 'f': - return x.astype(jnp.float_) - elif x.dtype.kind == 'c': - return x.astype(jnp.complex_) - else: - raise ValueError(f"Unrecognized {x.dtype=}") diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 6c7ea59ef2d0..861b9014c589 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -791,13 +791,8 @@ def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial): rng = jtu.rand_some_zero(self.rng()) def np_mock_op(x, axis=None, dtype=None, include_initial=False): - kind = x.dtype.kind - if (dtype is None and kind in {'i', 'u'} - and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)): - dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind]) axis = axis or 0 - x = x.astype(dtype=dtype or x.dtype) - out = jnp.cumsum(x, axis=axis) + out = np.cumsum(x, axis=axis, dtype=dtype or x.dtype) if include_initial: zeros_shape = list(x.shape) zeros_shape[axis] = 1