Skip to content

Commit beec748

Browse files
committed
FIX: random.draw: Add =None
Fix #704
1 parent 535a82f commit beec748

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

quantecon/random/tests/test_utilities.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,14 @@ def test_lln(self):
108108
pmf_computed = hist * np.diff(bin_edges)
109109
atol = 1e-2
110110
assert_allclose(pmf_computed, self.pmf, atol=atol)
111+
112+
113+
@njit
114+
def draw_jitted_w_o_size(n):
115+
cdf = np.linspace(1/n, 1, n)
116+
return draw(cdf)
117+
118+
119+
def test_draw_jitted_w_o_size():
120+
n = 3
121+
assert_(draw_jitted_w_o_size(n) in range(n))

quantecon/random/utilities.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,16 +211,16 @@ def draw(cdf, size=None):
211211

212212
# Overload for the `draw` function
213213
@overload(draw)
214-
def ol_draw(cdf, size):
214+
def ol_draw(cdf, size=None):
215215
if isinstance(size, types.Integer):
216-
def draw_impl(cdf, size):
216+
def draw_impl(cdf, size=None):
217217
rs = np.random.random(size)
218218
out = np.empty(size, dtype=np.int_)
219219
for i in range(size):
220220
out[i] = searchsorted(cdf, rs[i])
221221
return out
222222
else:
223-
def draw_impl(cdf, size):
223+
def draw_impl(cdf, size=None):
224224
r = np.random.random()
225225
return searchsorted(cdf, r)
226226
return draw_impl

0 commit comments

Comments
 (0)