diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md
index 55fc1a30..fd119c6d 100644
--- a/lectures/jax_intro.md
+++ b/lectures/jax_intro.md
@@ -53,7 +53,6 @@ We'll use the following imports
 import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
-import matplotlib.patches as mpatches
 import numpy as np
 import quantecon as qe
 ```
@@ -113,7 +112,8 @@ jnp.linalg.inv(B) # Inverse of identity is identity
 ```
 
 ```{code-cell} ipython3
-jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors
+eigvals, eigvecs = jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors
+eigvals
 ```
@@ -121,9 +121,120 @@ jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors
 
 Let's now look at some differences between JAX and NumPy array operations.
 
+(jax_speed)=
+#### Speed!
+
+Let's say we want to evaluate the cosine function at many points.
+
+```{code-cell}
+n = 50_000_000
+x = np.linspace(0, 10, n)
+```
+
+##### With NumPy
+
+Let's try with NumPy.
+
+```{code-cell}
+with qe.Timer():
+    y = np.cos(x)
+```
+
+And one more time.
+
+```{code-cell}
+with qe.Timer():
+    y = np.cos(x)
+```
+
+Here
+
+* NumPy uses a pre-built binary for applying cosine to an array of floats
+* The binary runs on the local machine's CPU
+
+
+##### With JAX
+
+Now let's try with JAX.
+
+```{code-cell}
+x = jnp.linspace(0, 10, n)
+```
+
+Let's time the same procedure.
+
+```{code-cell}
+with qe.Timer():
+    y = jnp.cos(x)
+    jax.block_until_ready(y);
+```
+
+```{note}
+Here, in order to measure actual speed, we use the `block_until_ready` method
+to hold the interpreter until the results of the computation are returned.
+
+This is necessary because JAX uses asynchronous dispatch, which
+allows the Python interpreter to run ahead of numerical computations.
+
+For non-timed code, you can drop the line containing `block_until_ready`.
+```
+
+And let's time it again.
+
+
+```{code-cell}
+with qe.Timer():
+    y = jnp.cos(x)
+    jax.block_until_ready(y);
+```
+
+On a GPU, this code runs much faster than its NumPy equivalent.
+
+Also, typically, the second run is faster than the first due to JIT compilation.
+
+This is because even built-in functions like `jnp.cos` are JIT-compiled --- and the
+first run includes compile time.
+
+Why would JAX want to JIT-compile built-in functions like `jnp.cos` instead of
+just providing pre-compiled versions, like NumPy?
+
+The reason is that the JIT compiler wants to specialize on the *size* of the array
+being used (as well as the data type).
+
+The size matters for generating optimized code because efficient parallelization
+requires matching the size of the task to the available hardware.
+
+We can verify the claim that JAX specializes on array size by changing the input size and watching the runtimes.
+
+```{code-cell}
+x = jnp.linspace(0, 10, n + 1)
+```
+
+```{code-cell}
+with qe.Timer():
+    y = jnp.cos(x)
+    jax.block_until_ready(y);
+```
+
+
+```{code-cell}
+with qe.Timer():
+    y = jnp.cos(x)
+    jax.block_until_ready(y);
+```
+
+The run time increases and then falls again (this will be more obvious on the GPU).
+
+This is in line with the discussion above --- the first run after changing array
+size shows compilation overhead.
+
+Further discussion of JIT compilation is provided below.
+
+
+
 #### Precision
 
-One difference between NumPy and JAX is that JAX uses 32 bit floats by default.
+Another difference between NumPy and JAX is that JAX uses 32 bit floats by default.
 
 This is because JAX is often used for GPU computing, and most GPU computations use 32 bit floats.
 
@@ -131,7 +242,7 @@ Using 32 bit floats can lead to significant speed gains with small loss of preci
 
 However, for some calculations precision matters.
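As a quick check of the 32 bit default described in this hunk, a minimal sketch (assuming a fresh interpreter where the `jax_enable_x64` flag has not yet been set):

```python
import jax.numpy as jnp

# Under JAX's default configuration, freshly created arrays are float32
x = jnp.ones(3)
print(x.dtype)  # float32
```

Compare this with NumPy, where `np.ones(3).dtype` is `float64` by default.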
-In these cases 64 bit floats can be enforced via the command 
+In these cases 64 bit floats can be enforced via the command
 
 ```{code-cell} ipython3
 jax.config.update("jax_enable_x64", True)
@@ -143,6 +254,7 @@ Let's check this works:
 
 jnp.ones(3)
 ```
 
+
 #### Immutability
 
 As a NumPy replacement, a more significant difference is that arrays are treated as **immutable**.
 
@@ -161,7 +273,7 @@ a[0] = 1
 
 a
 ```
 
-In JAX this fails:
+In JAX this fails!
 
 ```{code-cell} ipython3
 a = jnp.linspace(0, 1, 3)
@@ -169,29 +281,16 @@ a
 ```
 
 ```{code-cell} ipython3
-:tags: [raises-exception]
-
-a[0] = 1
+try:
+    a[0] = 1
+except Exception as e:
+    print(e)
 ```
-
-In line with immutability, JAX does not support inplace operations:
-
-```{code-cell} ipython3
-a = np.array((2, 1))
-a.sort() # Unlike NumPy, does not mutate a
-a
-```
-
-```{code-cell} ipython3
-a = jnp.array((2, 1))
-a_new = a.sort() # Instead, the sort method returns a new sorted array
-a, a_new
-```
 
 The designers of JAX chose to make arrays immutable because JAX uses a
-[functional programming](https://en.wikipedia.org/wiki/Functional_programming) style.
+functional programming style, which we discuss below.
 
-This design choice has important implications, which we explore next!
 
 #### A workaround
 
@@ -225,7 +324,9 @@ From JAX's documentation:
 
 *When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has "una anima di pura programmazione funzionale".*
 
-In other words, JAX assumes a functional programming style.
+In other words, JAX assumes a
+[functional programming](https://en.wikipedia.org/wiki/Functional_programming)
+style.
 
 ### Pure functions
 
@@ -284,21 +385,28 @@ def add_tax_pure(prices, tax_rate):
 
 This pure version makes all dependencies explicit through function arguments, and doesn't modify any external state.
 
-Now that we understand what pure functions are, let's explore how JAX's approach to random numbers maintains this purity.
+### Why functional programming?
+
+JAX represents functions as computational graphs, which are then compiled or transformed (e.g., differentiated).
+
+These computational graphs describe how a given set of inputs is transformed into an output.
+
+They are pure by construction.
+
+JAX uses a functional programming style so that user-built functions map
+directly into the graph-theoretic representations supported by JAX.
 
 
 ## Random numbers
 
-Random numbers are rather different in JAX, compared to what you find in NumPy
-or Matlab.
+Random number generation in JAX differs significantly from the patterns found in NumPy or MATLAB.
 
 At first you might find the syntax rather verbose.
 
-But you will soon realize that the syntax and semantics are necessary in order
-to maintain the functional programming style we just discussed.
+But the syntax and semantics are necessary to maintain the functional programming style we just discussed.
 
-Moreover, full control of random state is
-essential for parallel programming, such as when we want to run independent experiments along multiple threads.
+Moreover, full control of random state is essential for parallel programming,
+such as when we want to run independent experiments along multiple threads.
 
 ### Random number generation
 
@@ -463,7 +571,8 @@ Let's see how random number generation relates to pure functions by comparing Nu
 
 #### NumPy's approach
 
-In NumPy, random number generation works by maintaining hidden global state.
+In NumPy's legacy random number generation API (which mimics MATLAB), generation
+works by maintaining hidden global state.
 
 Each time we call a random function, this state is updated:
 
@@ -523,128 +632,20 @@ The explicitness of JAX brings significant benefits:
 
 The last point is expanded on in the next section.
 
-## JIT compilation
+## JIT Compilation
 
 The JAX just-in-time (JIT) compiler accelerates execution by generating
 efficient machine code that varies with both task size and hardware.
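One way to see the computational graphs mentioned above is `jax.make_jaxpr`, which returns the traced sequence of primitive operations (a "jaxpr") for inputs of a given shape and dtype. A minimal sketch, using an illustrative function `f`:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.cos(x) ** 2)

# Trace f on a length-3 input and print the resulting computational graph
print(jax.make_jaxpr(f)(jnp.ones(3)))
```

The printed jaxpr lists primitives such as `cos` and `reduce_sum`, which form the graph handed to the compiler.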
-### A simple example
-
-Let's say we want to evaluate the cosine function at many points.
-
-```{code-cell}
-n = 50_000_000
-x = np.linspace(0, 10, n)
-```
-
-#### With NumPy
-
-Let's try with NumPy
-
-```{code-cell}
-with qe.Timer():
-    y = np.cos(x)
-```
-
-And one more time.
-
-```{code-cell}
-with qe.Timer():
-    y = np.cos(x)
-```
-
-Here NumPy uses a pre-built binary file, compiled from carefully written
-low-level code, for applying cosine to an array of floats.
-
-This binary file ships with NumPy.
-
-#### With JAX
-
-Now let's try with JAX.
-
-```{code-cell}
-x = jnp.linspace(0, 10, n)
-```
-
-Let's time the same procedure.
-
-```{code-cell}
-with qe.Timer():
-    y = jnp.cos(x)
-    jax.block_until_ready(y);
-```
-
-```{note}
-Here, in order to measure actual speed, we use the `block_until_ready` method
-to hold the interpreter until the results of the computation are returned.
-
-This is necessary because JAX uses asynchronous dispatch, which
-allows the Python interpreter to run ahead of numerical computations.
-
-For non-timed code, you can drop the line containing `block_until_ready`.
-```
-
-
-And let's time it again.
-
-
-```{code-cell}
-with qe.Timer():
-    y = jnp.cos(x)
-    jax.block_until_ready(y);
-```
-
-On a GPU, this code runs much faster than its NumPy equivalent.
-
-Also, typically, the second run is faster than the first due to JIT compilation.
-
-This is because even built in functions like `jnp.cos` are JIT-compiled --- and the
-first run includes compile time.
-
-Why would JAX want to JIT-compile built in functions like `jnp.cos` instead of
-just providing pre-compiled versions, like NumPy?
-
-The reason is that the JIT compiler wants to specialize on the *size* of the array
-being used (as well as the data type).
-
-The size matters for generating optimized code because efficient parallelization
-requires matching the size of the task to the available hardware.
-
-That's why JAX waits to see the size of the array before compiling --- which
-requires a JIT-compiled approach instead of supplying precompiled binaries.
+We saw the power of JAX's JIT compiler combined with parallel hardware
+{ref}`above <jax_speed>`, when we applied `cos` to a large array.
-
-#### Changing array sizes
-
-Here we change the input size and watch the runtimes.
-
-```{code-cell}
-x = jnp.linspace(0, 10, n + 1)
-```
-
-```{code-cell}
-with qe.Timer():
-    y = jnp.cos(x)
-    jax.block_until_ready(y);
-```
-
-
-```{code-cell}
-with qe.Timer():
-    y = jnp.cos(x)
-    jax.block_until_ready(y);
-```
-
-Typically, the run time increases and then falls again (this will be more obvious on the GPU).
-
-This is because the JIT compiler specializes on array size to exploit
-parallelization --- and hence generates fresh compiled code when the array size
-changes.
+
+Let's try the same thing with a more complex function.
 
 ### Evaluating a more complicated function
 
-Let's try the same thing with a more complex function.
+Consider the function
 
 ```{code-cell}
 def f(x):
@@ -700,73 +701,14 @@ with qe.Timer():
 
 The outcome is similar to the `cos` example --- JAX is faster, especially on the
 second run after JIT compilation.
 
-Moreover, with JAX, we have another trick up our sleeve --- we can JIT-compile
+However, with JAX, we have another trick up our sleeve --- we can JIT-compile
 the *entire* function, not just individual operations.
 
-
-### How JIT compilation works
-
-When we apply `jax.jit` to a function, JAX *traces* it: instead of executing
-the operations immediately, it records the sequence of operations as a
-computational graph and hands that graph to the
-[XLA](https://openxla.org/xla) compiler.
-
-XLA then fuses and optimizes the operations into a single compiled kernel
-tailored to the available hardware (CPU, GPU, or TPU).
-
-The following diagram shows this pipeline for a simple function:
-
-```{code-cell} ipython3
-:tags: [hide-input]
-
-fig, ax = plt.subplots(figsize=(7, 2))
-ax.set_xlim(-0.2, 7.2)
-ax.set_ylim(0.2, 2.2)
-ax.axis('off')
-
-# Boxes for pipeline stages
-stages = [
-    (0.7, 1.2, "Python\nfunction"),
-    (2.6, 1.2, "computational\ngraph"),
-    (4.5, 1.2, "optimized\nkernel"),
-    (6.4, 1.2, "fast\nexecution"),
-]
-
-colors = ["#e3f2fd", "#fff9c4", "#f3e5f5", "#d4edda"]
-
-for (x, y, label), color in zip(stages, colors):
-    box = mpatches.FancyBboxPatch(
-        (x - 0.7, y - 0.5), 1.4, 1.0,
-        boxstyle="round,pad=0.15",
-        facecolor=color, edgecolor="black", linewidth=1.5)
-    ax.add_patch(box)
-    ax.text(x, y, label, ha='center', va='center', fontsize=9)
-
-# Arrows with labels
-arrows = [
-    (1.4, 1.9, "trace"),
-    (3.3, 3.8, "XLA"),
-    (5.2, 5.7, "run"),
-]
-
-for x_start, x_end, label in arrows:
-    ax.annotate("", xy=(x_end, 1.2), xytext=(x_start, 1.2),
-                arrowprops=dict(arrowstyle="->", lw=1.5, color="gray"))
-    ax.text((x_start + x_end) / 2, 1.55, label,
-            ha='center', fontsize=8, color='gray')
-
-plt.tight_layout()
-plt.show()
-```
-
-The first call to a JIT-compiled function incurs compilation overhead, but
-subsequent calls with the same input shapes and types reuse the cached
-compiled code and run at full speed.
-
 ### Compiling the whole function
 
-The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing linear
-algebra operations into a single optimized kernel.
+The JAX just-in-time (JIT) compiler can accelerate execution within functions by fusing array
Let's try this with the function `f`: @@ -802,11 +744,29 @@ def f(x): pass # put function body here ``` +### How JIT compilation works + +When we apply `jax.jit` to a function, JAX *traces* it: instead of executing +the operations immediately, it records the sequence of operations as a +computational graph and hands that graph to the +[XLA](https://openxla.org/xla) compiler. + +XLA then fuses and optimizes the operations into a single compiled kernel +tailored to the available hardware (CPU, GPU, or TPU). + +The first call to a JIT-compiled function incurs compilation overhead, but +subsequent calls with the same input shapes and types reuse the cached +compiled code and run at full speed. + + + ### Compiling non-pure functions -Now that we've seen how powerful JIT compilation can be, it's important to understand its relationship with pure functions. +Now that we've seen how powerful JIT compilation can be, it's important to +understand its relationship with pure functions. -While JAX will not usually throw errors when compiling impure functions, execution becomes unpredictable. +While JAX will not usually throw errors when compiling impure functions, +execution becomes unpredictable. Here's an illustration of this fact, using global variables: