Just-in-Time Compilation#
It is no secret that Python is slow. It is an interpreted, dynamically typed language, which means the interpreter has to do a lot of work at runtime to figure out what your code is doing before it can actually do it. This is in contrast to a language like C, which is statically typed and compiled: the compiler knows exactly what your code is doing, and it can generate machine code that does exactly that. This is why C is so much faster than Python.
In the past, the solution to this problem was to write the slow parts of your code in C and call them from Python. This is what libraries like NumPy and SciPy do. However, this is a lot of work, and it is easy to make mistakes. It would be nice if we could just write our code in Python and have it run as fast as C. This is where just-in-time (JIT) compilers come in.
A JIT compiler is a compiler that runs at runtime. It takes your Python code and compiles it to machine code, which is then executed by the CPU or GPU. JAX is a library that provides a JIT compiler for Python. It is built on top of XLA, a compiler for linear algebra operations developed by Google. JAX is designed to be used with NumPy, and it provides a NumPy-like API. This means that you can use JAX to speed up your NumPy code without having to rewrite it.
Let’s look at some examples.
import jax
import jax.numpy as jnp
Here is a simple mathematical function:
f = lambda x: jnp.sin(x) + jnp.cos(x)
To compile it with JAX, we just need to do this:
f_jit = jax.jit(f)
You can also `jit` using decorators:
@jax.jit
def f_jit2(x):
    return jnp.sin(x) + jnp.cos(x)
There are no real benefits in this case, because the function is so simple. However, when we `jit` bigger chunks of code, e.g., the update step in the training loop of a neural network, we can see a significant speedup. We will have the chance to observe the speedup in subsequent lectures.
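To get a rough feel for it right away, here is a minimal timing sketch (the array size and timing approach are illustrative; `block_until_ready()` forces JAX's asynchronous dispatch to finish so the measurement is meaningful):
import time

x = jnp.linspace(0, 10, 1_000_000)

# Warm up: the first call triggers tracing and compilation.
f_jit(x).block_until_ready()

start = time.perf_counter()
f(x).block_until_ready()      # plain version
print("no jit:", time.perf_counter() - start)

start = time.perf_counter()
f_jit(x).block_until_ready()  # compiled version
print("jit:   ", time.perf_counter() - start)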
You can only `jit` pure functions#
No side effects are allowed. Here is an example of what may happen when you have side effects:
@jax.jit
def bad_f(x):
    print("I have side effects!")
    return jnp.sin(x) + jnp.cos(x) + x
The first time we call the function, it works as expected.
bad_f(2)
I have side effects!
Array(2.4931505, dtype=float32, weak_type=True)
But the second time we call it, nothing is printed.
bad_f(2)
Array(2.4931505, dtype=float32, weak_type=True)
Here is what is happening. The first time we run the function, JAX traces it, records all the mathematical operations, and compiles them to machine code. As a byproduct of tracing, the string gets printed. The second time we run the function, JAX does not trace it again, because it has already done that; it just runs the compiled machine code. The string is not printed, because the print statement is not part of the recorded mathematical operations.
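One way to see the caching in action (a quick sketch, relying on the fact that `jit` compiles one version per argument shape and dtype): calling with a new dtype forces a retrace, and the side effect fires once more.
bad_f(2)    # already compiled for this input type: nothing is printed
bad_f(2.0)  # new dtype, so JAX retraces the function and the print runs again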
Be careful with loops#
`jit` works with loops, but only if they have a fixed number of iterations that is known at compile time. This is because `jit` needs to know how many times to unroll the loop. Unrolling a loop means replacing it with a sequence of instructions that perform the same operations as the loop.
Here is an example:
@jax.jit
def f_loop(x):
    y = 0.
    for i in range(10):
        y = y + x
    return y
This works fine:
f_loop(2)
Array(20., dtype=float32, weak_type=True)
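You can actually see the unrolling by tracing the undecorated version of the function with `jax.make_jaxpr` (a sketch; `f_loop_plain` is just a hypothetical name for the function without the `@jax.jit` decorator):
def f_loop_plain(x):
    y = 0.
    for i in range(10):
        y = y + x
    return y

# The printed jaxpr contains ten separate add equations, one per unrolled iteration.
print(jax.make_jaxpr(f_loop_plain)(2.))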
This does not work:
@jax.jit
def f_loop(x, n):
    y = 0.
    for i in range(n):
        y = y + x
    return y

f_loop(2, 10)
---------------------------------------------------------------------------
TracerIntegerConversionError              Traceback (most recent call last)
----> 1 f_loop(2, 10)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[].
The error occurred while tracing the function f_loop for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
To make it work, we must tell JAX that the second argument to `f_loop` is static, i.e., it should not be traced. Here is how we can make it work:
import functools as ft

@ft.partial(jax.jit, static_argnums=(1,))
def f_loop(x, n):
    y = 0.
    for i in range(n):
        y = y + x
    return y
Recall that `partial` fixes some of the arguments of `jax.jit`. The argument fixed here is `static_argnums`, which is a tuple of integers corresponding to the indices of the arguments that are static (i.e., not traced).
So it works:
f_loop(2, 10)
Array(20., dtype=float32, weak_type=True)
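Equivalently (a sketch using the hypothetical name `f_loop_impl`), you can pass `static_argnums` to `jax.jit` directly, without `functools.partial`:
def f_loop_impl(x, n):
    y = 0.
    for i in range(n):
        y = y + x
    return y

# jax.jit accepts static_argnums directly as a keyword argument.
f_loop = jax.jit(f_loop_impl, static_argnums=(1,))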
However, there is a catch. The code is recompiled every time we call `f_loop` with a different value of `n`.
This, for example, triggers a recompilation:
f_loop(2, 11)
Array(22., dtype=float32, weak_type=True)
This is okay if we are going to call `f_loop` with only a few different values of `n`. But it is not okay if we are going to call `f_loop` with many different values of `n`.
If we want to avoid the recompilation, we need to use LAX control flow primitives. LAX is the low-level API of JAX. In particular, we need to use `lax.fori_loop`.
from jax import lax

@jax.jit
def f_loop(x, n):
    # lax.fori_loop(lower, upper, body, init_val) runs body(i, carry) for i in [lower, upper)
    return lax.fori_loop(0, n, lambda i, y: y + x, 0.)
But you cannot call this function directly with Python scalars, because `lax` sits at a lower level than `jax.numpy`. You need to first convert the scalars to JAX arrays using `jax.numpy.array`. Here is the correct call:
f_loop(jnp.array(2), 15)
Array(30., dtype=float32, weak_type=True)
Other useful LAX loop functions are `lax.while_loop` and `lax.scan`. Let's see examples for both.
@jax.jit
def sum_up_to(x):
    s = 0.0
    n = 1
    while n <= x:
        s = s + n
        n = n + 1
    return s
sum_up_to(jnp.array(5))
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
---> 10 sum_up_to(jnp.array(5))

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function sum_up_to for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
But we can write this:
@jax.jit
def sum_up_to(x):
    # The carry c = (running sum, counter); loop while counter <= x.
    return lax.while_loop(
        lambda c: c[1] <= x,                # condition on the carry
        lambda c: (c[0] + c[1], c[1] + 1),  # body: update the carry
        (0, 1)                              # initial carry
    )[0]
sum_up_to(jnp.array(10))
Array(55, dtype=int32, weak_type=True)
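And here is `lax.scan` (a minimal sketch; `cumulative_sum` and `step` are illustrative names): it threads a carry through the elements of an array and also stacks the per-step outputs.
@jax.jit
def cumulative_sum(xs):
    # step(carry, x) returns (new_carry, output_for_this_step).
    def step(carry, x):
        carry = carry + x
        return carry, carry
    total, partials = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return total, partials

cumulative_sum(jnp.arange(5.0))
# Expect total 10. and partial sums [0., 1., 3., 6., 10.].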
Be careful with conditionals#
Python `if` statements on traced values fail under `jit` for the same reason: the branch taken would depend on the runtime value.
@jax.jit
def myabs(x):
    if x > 0:
        return x
    else:
        return -x
myabs(jnp.array(-2))
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
----> 8 myabs(jnp.array(-2))

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function myabs for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
The way out is to use `lax.cond`:
@jax.jit
def myabs(x):
    return lax.cond(
        x > 0,        # condition
        lambda _: x,  # if true
        lambda _: -x, # if false
        None          # operand passed to the branch functions (nothing here)
    )
myabs(jnp.array(-2))
Array(2, dtype=int32, weak_type=True)
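For simple element-wise selections like this one, `jnp.where` is a common alternative (a sketch; note that under `jit` it evaluates both branches and then selects):
@jax.jit
def myabs_where(x):
    # Both x and -x are computed; jnp.where picks between them per element.
    return jnp.where(x > 0, x, -x)

myabs_where(jnp.array(-2))  # expect Array(2, dtype=int32, weak_type=True)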
Another useful LAX flow control function is `lax.switch`.
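Here is a minimal sketch of it (`apply_op` is an illustrative name; the branches are ordinary `jax.numpy` functions): `lax.switch` selects the branch with the given index, clamping the index to the valid range.
@jax.jit
def apply_op(i, x):
    # Calls branches[i](x); i is clamped to [0, len(branches) - 1].
    return lax.switch(i, [jnp.sin, jnp.cos, jnp.tanh], x)

apply_op(1, jnp.array(0.0))  # selects jnp.cos: expect Array(1., dtype=float32)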