Skip to content

Commit ae7fb26

Browse files
authored
Merge pull request #736 from QuantEcon/npy2
FIX: Fix NumPy v2 compatibility issue
2 parents a83c2ae + f48c505 commit ae7fb26

File tree

5 files changed

+86
-1
lines changed

5 files changed

+86
-1
lines changed

.github/workflows/ci_np2.yml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
name: conda-build (NumPy v2)
2+
3+
on: [push]
4+
5+
jobs:
6+
build-linux:
7+
runs-on: ubuntu-latest
8+
strategy:
9+
max-parallel: 5
10+
11+
steps:
12+
- uses: actions/checkout@v4
13+
- name: Set up Python
14+
uses: actions/setup-python@v3
15+
with:
16+
python-version: '3.12'
17+
- name: Add conda to system path
18+
run: |
19+
# $CONDA is an environment variable pointing to the root of the miniconda directory
20+
echo $CONDA/bin >> $GITHUB_PATH
21+
- name: Install dependencies
22+
run: |
23+
conda env update --file environment_np2.yml --name base
24+
- name: Conda info
25+
shell: bash -l {0}
26+
run: |
27+
conda info
28+
conda list
29+
- name: Test with pytest
30+
run: |
31+
conda install pytest
32+
pytest

environment_np2.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
name: qe
2+
channels:
3+
- conda-forge
4+
- defaults
5+
dependencies:
6+
- coverage
7+
- numpy>=2
8+
- scipy
9+
- pandas
10+
- numba
11+
- sympy
12+
- ipython
13+
- flake8
14+
- requests
15+
- urllib3>=2
16+
- flit
17+
- chardet # python>3.9,osx
18+
- pytest

quantecon/markov/gth_solve.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
import numpy as np
77
from numba import jit
88

9+
from ..util.compat import copy_if_needed
10+
11+
912
def gth_solve(A, overwrite=False, use_jit=True):
1013
r"""
1114
This routine computes the stationary distribution of an irreducible
@@ -52,7 +55,9 @@ def gth_solve(A, overwrite=False, use_jit=True):
5255
Simulation, Princeton University Press, 2009.
5356
5457
"""
55-
A1 = np.array(A, dtype=float, copy=not overwrite, order='C')
58+
copy = copy_if_needed if overwrite else True
59+
60+
A1 = np.array(A, dtype=float, copy=copy, order='C')
5661
# `order='C'` is for use with Numba <= 0.18.2
5762
# See issue github.com/numba/numba/issues/1103
5863

quantecon/markov/tests/test_gth_solve.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ def test_matrices_with_C_F_orders():
196196
assert_array_equal(computed_F, stationary_dist)
197197

198198

199+
def test_unable_to_avoid_copy():
200+
A = np.array([[0, 1], [0, 1]]) # dtype=int
201+
stationary_dist = [0., 1.]
202+
x = gth_solve(A, overwrite=True)
203+
assert_array_equal(x, stationary_dist)
204+
205+
199206
def test_raises_value_error_non_2dim():
200207
"""Test with non 2dim input"""
201208
assert_raises(ValueError, gth_solve, np.array([0.4, 0.6]))

quantecon/util/compat.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""
2+
Utilities for compatibility
3+
4+
"""
5+
from typing import Optional
6+
import numpy as np
7+
8+
9+
# From scipy/_lib/_util.py
10+
11+
copy_if_needed: Optional[bool]
12+
13+
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
14+
copy_if_needed = None
15+
elif np.lib.NumpyVersion(np.__version__) < "1.28.0":
16+
copy_if_needed = False
17+
else:
18+
# 2.0.0 dev versions, handle cases where copy may or may not exist
19+
try:
20+
np.array([1]).__array__(copy=None) # type: ignore[call-overload]
21+
copy_if_needed = None
22+
except TypeError:
23+
copy_if_needed = False

0 commit comments

Comments
 (0)