Skip to content

Commit ebe67ee

Browse files
committed
update JAX shape
1 parent 41eb913 commit ebe67ee

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

lectures/jax_intro.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,12 +264,18 @@ for A in matrices:
264264
print(A)
265265
```
266266

267-
One point to remember is that JAX expects tuples to describe array shapes, even for flat arrays. Hence, to get a one-dimensional array of normal random draws we use `(len, )` for the shape, as in
267+
To get a one-dimensional array of normal random draws, we can either use `(len, )` for the shape, as in
268268

269269
```{code-cell} ipython3
270270
random.normal(key, (5, ))
271271
```
272272

273+
or simply use `5` as the shape argument:
274+
275+
```{code-cell} ipython3
276+
random.normal(key, 5)
277+
```
278+
273279
## JIT compilation
274280

275281
The JAX just-in-time (JIT) compiler accelerates logic within functions by fusing linear

0 commit comments

Comments
 (0)