Pytrees to represent model parameters#

Pytrees are JAX’s solution to the problem of working with nested data structures. This is immensely useful when working with the parameters of complex neural networks. You can read about pytrees here. The definition is:

In JAX, we use the term pytree to refer to a tree-like structure built out of container-like Python objects. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. That is: any object whose type is not in the pytree container registry is considered a leaf pytree; any object whose type is in the pytree container registry, and which contains pytrees, is considered a pytree.

Let’s see some examples.

The most trivial pytrees are single leaves: primitives such as int, float, and bool. JAX arrays are also single-leaf pytrees:

import jax
import jax.numpy as jnp
import jax.tree_util as tree_util

x = jnp.array([1, 2, 3])
x
Array([1, 2, 3], dtype=int32)
tree_util.tree_structure(x)
PyTreeDef(*)
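
In contrast, an instance of a class that is not registered as a container is treated as a single leaf, no matter what it holds inside. Here is a small sketch using a throwaway Point class (not part of the original examples):

class Point:
    """A plain class; its type is not in the pytree container registry."""
    def __init__(self, x, y):
        self.x = x
        self.y = y

tree_util.tree_structure(Point(1.0, 2.0))
# PyTreeDef(*)  -- the whole object is a single leaf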

Tuples of arbitrary objects are also pytrees:

tree = (1, x, 'hello')
tree_util.tree_structure(tree)
PyTreeDef((*, *, *))

Same for lists:

tree = [1, x, 'hello']
tree_util.tree_structure(tree)
PyTreeDef([*, *, *])

And dictionaries:

tree = {'a': 1, 'b': x, 'c': 'hello'}

tree_util.tree_structure(tree)
PyTreeDef({'a': *, 'b': *, 'c': *})

And now it gets interesting. You can nest pytrees:

tree = {'tree1': {'a': 1, 'b': x, 'c': 'hello'},
        'tree2': {'a': 1, 'b': (x, x, x), 'c': 'hello'}}
tree
{'tree1': {'a': 1, 'b': Array([1, 2, 3], dtype=int32), 'c': 'hello'},
 'tree2': {'a': 1,
  'b': (Array([1, 2, 3], dtype=int32),
   Array([1, 2, 3], dtype=int32),
   Array([1, 2, 3], dtype=int32)),
  'c': 'hello'}}
tree_util.tree_structure(tree)
PyTreeDef({'tree1': {'a': *, 'b': *, 'c': *}, 'tree2': {'a': *, 'b': (*, *, *), 'c': *}})

The leaves of the pytree are the primitives and the arrays; the containers (the dicts and the tuple) provide the structure. The leaves are shown as * above. You can get the leaves of a pytree as a flattened list with jax.tree_util.tree_leaves:

tree_util.tree_leaves(tree)
[1,
 Array([1, 2, 3], dtype=int32),
 'hello',
 1,
 Array([1, 2, 3], dtype=int32),
 Array([1, 2, 3], dtype=int32),
 Array([1, 2, 3], dtype=int32),
 'hello']

You can also flatten the tree with jax.tree_util.tree_flatten. It returns a tuple containing the list of leaves and a PyTreeDef object that records the tree structure and can be used to reconstruct the tree from the leaves:

flat_values, tree_type = tree_util.tree_flatten(tree)
flat_values
[1,
 Array([1, 2, 3], dtype=int32),
 'hello',
 1,
 Array([1, 2, 3], dtype=int32),
 Array([1, 2, 3], dtype=int32),
 Array([1, 2, 3], dtype=int32),
 'hello']
tree_type
PyTreeDef({'tree1': {'a': *, 'b': *, 'c': *}, 'tree2': {'a': *, 'b': (*, *, *), 'c': *}})

If you have a flattened tree, you can put it back together with jax.tree_util.tree_unflatten:

tree_util.tree_unflatten(tree_type, flat_values)
{'tree1': {'a': 1, 'b': Array([1, 2, 3], dtype=int32), 'c': 'hello'},
 'tree2': {'a': 1,
  'b': (Array([1, 2, 3], dtype=int32),
   Array([1, 2, 3], dtype=int32),
   Array([1, 2, 3], dtype=int32)),
  'c': 'hello'}}
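
The treedef is not tied to the particular leaf values: you can rebuild the same structure around any list of leaves of the right length. A minimal sketch, reusing flat_values and tree_type from above:

# Rebuild the same nesting, but with the leaves replaced by 0, 1, ..., 7.
tree_util.tree_unflatten(tree_type, list(range(len(flat_values))))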

Example: Neural network parameters#

The most useful type of pytree for us is one whose leaves are JAX arrays. This is the structure that we will use to represent the parameters of our neural networks. Let’s make a simple neural network by hand, using a dictionary to represent the parameters of the network.

from jax import vmap
from functools import partial

@partial(vmap, in_axes=(0, None))
def simple_nn(x, params):
    W1 = params["layer1"]["W"]
    b1 = params["layer1"]["b"]
    W2 = params["layer2"]["W"]
    b2 = params["layer2"]["b"]
    return W2 @ jnp.tanh(W1 @ x + b1) + b2

First, let’s build some random parameters for it:

import jax.random as random

key = random.PRNGKey(0)
keys = random.split(key, 4)
params = {
    "layer1": {
        "W": random.normal(keys[0], (2, 3)),
        "b": random.normal(keys[1], (2,)),
    },
    "layer2": {
        "W": random.normal(keys[2], (1, 2)),
        "b": random.normal(keys[3], (1,)),
    },
}
params
{'layer1': {'W': Array([[-0.11168969,  0.58439565,  1.437887  ],
         [ 0.533231  , -1.0117726 , -2.316002  ]], dtype=float32),
  'b': Array([-1.5917008, -0.9385306], dtype=float32)},
 'layer2': {'W': Array([[ 0.43686673, -0.5115205 ]], dtype=float32),
  'b': Array([0.6714109], dtype=float32)}}
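
A handy trick at this point: because all the leaves of params are arrays, you can count the total number of trainable parameters by flattening the tree. A small sketch:

# Total parameter count: 2*3 + 2 + 1*2 + 1 = 11
sum(leaf.size for leaf in tree_util.tree_leaves(params))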

Here is how it works on a bunch of inputs:

key, subkey = random.split(keys[0])
xs = random.normal(subkey, (10, 3))

simple_nn(xs, params)
Array([[ 0.6350196 ],
       [ 0.3767303 ],
       [ 0.7132992 ],
       [ 0.88301647],
       [-0.27681673],
       [-0.17706287],
       [-0.22115844],
       [ 0.31149212],
       [ 1.293994  ],
       [ 0.97846043]], dtype=float32)

Let’s now add some fake data and a loss function:

key, subkey = random.split(key)
ys = random.normal(subkey, (10,))

def loss(params, xs, ys):
    pred = simple_nn(xs, params)
    return jnp.mean((pred - ys)**2)

The loss function works like this:

loss(params, xs, ys)
Array(1.0290194, dtype=float32)

We can take the gradient of the loss function with respect to the parameters:

from jax import grad, jit

grad_loss = jit(grad(loss))

Magic:

g = grad_loss(params, xs, ys)
g
{'layer1': {'W': Array([[-0.11328071, -0.06153299,  0.23609458],
         [ 0.04657721, -0.01834877,  0.02984624]], dtype=float32),
  'b': Array([ 0.2492916 , -0.12900007], dtype=float32)},
 'layer2': {'W': Array([[-0.47350553, -1.016762  ]], dtype=float32),
  'b': Array([1.0130714], dtype=float32)}}

Let’s unpack this. The parameters are a pytree, so the gradient is a pytree too, with exactly the same structure; its leaves are the gradients of the loss function with respect to the corresponding parameters. Great! This generalizes to any pytree, not just dictionaries.

What do we do with this? Well, we can do gradient descent. We have to subtract a small multiple of the gradient from the parameters. Here is how:

new_params = tree_util.tree_map(
    lambda x, g: x - 0.1 * g,
    params, g
)
new_params
{'layer1': {'W': Array([[-0.10036162,  0.59054893,  1.4142776 ],
         [ 0.5285733 , -1.0099378 , -2.3189864 ]], dtype=float32),
  'b': Array([-1.61663  , -0.9256306], dtype=float32)},
 'layer2': {'W': Array([[ 0.4842173, -0.4098443]], dtype=float32),
  'b': Array([0.57010376], dtype=float32)}}

What is going on here? The function tree_map applies a function to every leaf of a pytree. When you pass it several pytrees with the same structure, the function is applied to the corresponding leaves of each tree; here, we subtract a small multiple of each gradient leaf from the matching parameter leaf.
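
To see the multi-tree behavior in isolation, here is a tiny sketch with two hand-made pytrees (not part of the original example); note that the two trees must have exactly the same structure:

a = {'w': jnp.ones(2), 'b': 0.0}
b = {'w': jnp.full(2, 2.0), 'b': 3.0}
# Adds corresponding leaves: {'b': 3.0, 'w': Array([3., 3.], dtype=float32)}
tree_util.tree_map(lambda u, v: u + v, a, b)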

Now, suppose that we wanted to add an L2 regularization term to the loss function. This means adding the sum of the squares of all the parameters to the loss. How do we square every parameter? We can use tree_map again:

params_2 = tree_util.tree_map(
    lambda x: x ** 2,
    params
)
params_2
{'layer1': {'W': Array([[0.01247459, 0.34151828, 2.067519  ],
         [0.28433532, 1.0236839 , 5.363865  ]], dtype=float32),
  'b': Array([2.5335114, 0.8808397], dtype=float32)},
 'layer2': {'W': Array([[0.19085254, 0.2616532 ]], dtype=float32),
  'b': Array([0.4507926], dtype=float32)}}

And now we can just sum the squares using tree_reduce:

tree_util.tree_reduce(
    lambda x, y: jnp.sum(x) + jnp.sum(y),
    params_2,
    0.0
)
Array(13.411046, dtype=float32)

Let’s rewrite our loss:

def loss(params, xs, ys):
    pred = simple_nn(xs, params)
    squared_error = jnp.mean((pred - ys)**2)
    l2_norm = tree_util.tree_reduce(
        lambda x, y: jnp.sum(x) + jnp.sum(y),
        tree_util.tree_map(lambda x: x ** 2, params),
        0.0
    )
    return squared_error + 0.1 * l2_norm
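
An equivalent way to compute the penalty, which some people find easier to read, is to sum over the flattened leaves directly. A minimal sketch (the helper name l2_penalty is mine, not part of the original code):

def l2_penalty(params):
    # Sum of squared entries over every leaf of the parameter pytree.
    return sum(jnp.sum(leaf ** 2) for leaf in tree_util.tree_leaves(params))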

Let me introduce another useful function, value_and_grad. It transforms a function into one that returns both its value and its gradient.

from jax import value_and_grad

loss_and_grad = jit(value_and_grad(loss))

Here it is:

v, g = loss_and_grad(params, xs, ys)
v
Array(2.3701239, dtype=float32)
g
{'layer1': {'W': Array([[-0.13561864,  0.05534614,  0.523672  ],
         [ 0.15322341, -0.2207033 , -0.43335414]], dtype=float32),
  'b': Array([-0.06904855, -0.31670618], dtype=float32)},
 'layer2': {'W': Array([[-0.38613218, -1.1190661 ]], dtype=float32),
  'b': Array([1.1473536], dtype=float32)}}
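
Putting the pieces together, a full gradient descent loop is just the update we wrote earlier, repeated. A minimal sketch; the step size and the number of steps are arbitrary choices, not tuned values:

learning_rate = 0.1  # arbitrary step size
for step in range(100):
    v, g = loss_and_grad(params, xs, ys)
    # One gradient descent step on every leaf of the parameter pytree.
    params = tree_util.tree_map(lambda p, dp: p - learning_rate * dp, params, g)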

Named tuples#

Named tuples are a useful way to represent data. They allow you to access the elements of a tuple by name, like a dictionary, but with dot syntax. This is convenient when you have a bunch of data that you want to pass around as a single object. Named tuples are also pytrees.

You can make a named tuple like this:

from collections import namedtuple

NNParameters = namedtuple("NNParameters", ["layer1", "layer2"])
LayerParameters = namedtuple("LayerParameters", ["W", "b"])

params = NNParameters(
    LayerParameters(
        W=random.normal(keys[0], (2, 3)),
        b=random.normal(keys[1], (2,)),
    ),
    LayerParameters(
        W=random.normal(keys[2], (1, 2)),
        b=random.normal(keys[3], (1,)),
    ),
)

params
NNParameters(layer1=LayerParameters(W=Array([[-0.11168969,  0.58439565,  1.437887  ],
       [ 0.533231  , -1.0117726 , -2.316002  ]], dtype=float32), b=Array([-1.5917008, -0.9385306], dtype=float32)), layer2=LayerParameters(W=Array([[ 0.43686673, -0.5115205 ]], dtype=float32), b=Array([0.6714109], dtype=float32)))

You can access individual elements of the tuple by name:

params.layer1
LayerParameters(W=Array([[-0.11168969,  0.58439565,  1.437887  ],
       [ 0.533231  , -1.0117726 , -2.316002  ]], dtype=float32), b=Array([-1.5917008, -0.9385306], dtype=float32))

and:

params.layer1.W
Array([[-0.11168969,  0.58439565,  1.437887  ],
       [ 0.533231  , -1.0117726 , -2.316002  ]], dtype=float32)

Look at the tree structure:

tree_util.tree_structure(params)
PyTreeDef(CustomNode(namedtuple[NNParameters], [CustomNode(namedtuple[LayerParameters], [*, *]), CustomNode(namedtuple[LayerParameters], [*, *])]))

And you can apply all sorts of tree functions to them:

tree_util.tree_map(lambda x: x ** 2, params)
NNParameters(layer1=LayerParameters(W=Array([[0.01247459, 0.34151828, 2.067519  ],
       [0.28433532, 1.0236839 , 5.363865  ]], dtype=float32), b=Array([2.5335114, 0.8808397], dtype=float32)), layer2=LayerParameters(W=Array([[0.19085254, 0.2616532 ]], dtype=float32), b=Array([0.4507926], dtype=float32)))

Of course, to use this with our neural network, we need to rewrite it to access the parameters through attributes instead of dictionary keys. Let’s do this:

@partial(vmap, in_axes=(0, None))
def simple_nn(x, params):
    W1 = params.layer1.W
    b1 = params.layer1.b
    W2 = params.layer2.W
    b2 = params.layer2.b
    return W2 @ jnp.tanh(W1 @ x + b1) + b2

@jit
@value_and_grad
def loss(params, xs, ys):
    pred = simple_nn(xs, params)
    squared_error = jnp.mean((pred - ys)**2)
    l2_norm = tree_util.tree_reduce(
        lambda x, y: jnp.sum(x) + jnp.sum(y),
        tree_util.tree_map(lambda x: x ** 2, params),
        0.0
    )
    return squared_error + 0.1 * l2_norm

loss(params, xs, ys)
(Array(2.3701239, dtype=float32),
 NNParameters(layer1=LayerParameters(W=Array([[-0.13561864,  0.05534614,  0.523672  ],
        [ 0.15322341, -0.2207033 , -0.43335414]], dtype=float32), b=Array([-0.06904855, -0.31670618], dtype=float32)), layer2=LayerParameters(W=Array([[-0.38613218, -1.1190661 ]], dtype=float32), b=Array([1.1473536], dtype=float32))))
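
Because the gradient is again an NNParameters namedtuple with the same structure as the parameters, a gradient descent step looks exactly like the dictionary version; a minimal sketch:

v, g = loss(params, xs, ys)
# new_params is again an NNParameters of LayerParameters
new_params = tree_util.tree_map(lambda p, dp: p - 0.1 * dp, params, g)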

Equinox - How to actually do this in practice#

We don’t want to build neural networks by hand. There are three main libraries for building neural networks in JAX: Flax, Haiku, and Equinox.

Equinox is perhaps the simplest one, as it relies only on pytrees. It also forces us to inspect the details of the neural network, which is essential for this course. So, we will use Equinox.

You should go through All of Equinox to learn how to use it, as well as some of the examples, like MNIST. Note that we haven’t talked about optimization yet; we will do that in another lecture.

Here is how the network we built above looks in Equinox:

import equinox as eqx


class SimpleNN(eqx.Module):
    layers: list

    def __init__(self, n_inputs, n_hidden, n_outputs, key):
        key1, key2 = random.split(key)
        self.layers = [
            eqx.nn.Linear(n_inputs, n_hidden, key=key1),
            eqx.nn.Linear(n_hidden, n_outputs, key=key2),
        ]

    # Notice how neatly we can vectorize the forward pass.
    # Here we need to use in_axes=(None, 0) because the first argument
    # to __call__ is self, which refers to the model itself.
    # We don't want to vectorize over this argument.
    @partial(vmap, in_axes=(None, 0))
    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jnp.tanh(layer(x))
        return self.layers[-1](x)

The difference is that the parameters of the network are now in a Module object, nicely organized. Here is how we can make such a network:

key = random.PRNGKey(314)
model = SimpleNN(3, 2, 1, key)

model
SimpleNN(
  layers=[
    Linear(
      weight=f32[2,3],
      bias=f32[2],
      in_features=3,
      out_features=2,
      use_bias=True
    ),
    Linear(
      weight=f32[1,2],
      bias=f32[1],
      in_features=2,
      out_features=1,
      use_bias=True
    )
  ]
)

Here is a forward pass:

model(xs)
Array([[-0.26482067],
       [-0.17487162],
       [-0.28958413],
       [-0.34640497],
       [-0.10561763],
       [-0.12498382],
       [-0.26911202],
       [-0.24482018],
       [-0.36088067],
       [-0.21855468]], dtype=float32)

It is a pytree, see:

tree_util.tree_structure(model)
PyTreeDef(CustomNode(SimpleNN[('layers',), (), ()], [[CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (3, 2, True)], [*, *]), CustomNode(Linear[('weight', 'bias'), ('in_features', 'out_features', 'use_bias'), (2, 1, True)], [*, *])]]))

If you want to get the parameters, you can do it like this:

model.layers[0].weight
Array([[-0.48662117,  0.08805605,  0.25260752],
       [ 0.55680007, -0.21773158, -0.5048137 ]], dtype=float32)

Or all together (but without names):

tree_util.tree_leaves(model)
[Array([[-0.48662117,  0.08805605,  0.25260752],
        [ 0.55680007, -0.21773158, -0.5048137 ]], dtype=float32),
 Array([-0.04204372, -0.52129227], dtype=float32),
 Array([[-0.42024657, -0.18588884]], dtype=float32),
 Array([-0.30617577], dtype=float32)]

Or you can flatten the model by just one level, which returns its immediate children (here, the list of layers) together with the top-level treedef:

eqx.tree_flatten_one_level(model)
([[Linear(
     weight=f32[2,3],
     bias=f32[2],
     in_features=3,
     out_features=2,
     use_bias=True
   ),
   Linear(
     weight=f32[1,2],
     bias=f32[1],
     in_features=2,
     out_features=1,
     use_bias=True
   )]],
 PyTreeDef(CustomNode(SimpleNN[('layers',), (), ()], [*])))

But you should really think of the model and the parameters as a single object. For example, here is how you can compute the sum of the squares of the parameters (the squared L2 norm):

tree_util.tree_reduce(
    lambda x, y: jnp.sum(x) + jnp.sum(y),
    tree_util.tree_map(lambda x: x ** 2, model),
    0.
)
Array(1.4990535, dtype=float32)

Let’s make our loss function:

@jit
@value_and_grad
def loss(model, xs, ys):
    pred = model(xs)
    squared_error = jnp.mean((pred - ys)**2)
    l2_norm = tree_util.tree_reduce(
        lambda x, y: jnp.sum(x) + jnp.sum(y),
        tree_util.tree_map(lambda x: x ** 2, model),
        0.0
    )
    return squared_error + 0.1 * l2_norm

Note that the gradient is taken with respect to the model, which is identified with its parameters. Here is how it looks:

v, g = loss(model, xs, ys)

The value:

v
Array(0.69529444, dtype=float32)

The gradient:

g
SimpleNN(
  layers=[
    Linear(
      weight=f32[2,3],
      bias=f32[2],
      in_features=3,
      out_features=2,
      use_bias=True
    ),
    Linear(
      weight=f32[1,2],
      bias=f32[1],
      in_features=2,
      out_features=1,
      use_bias=True
    )
  ]
)

Notice that the gradient is a pytree with the same structure as the model. Again, if you want to see the actual values, you can do it like this:

tree_util.tree_leaves(g)
[Array([[-0.16479225,  0.03449459,  0.04574944],
        [ 0.0967427 , -0.03741584, -0.11203536]], dtype=float32),
 Array([ 0.1264575 , -0.06837945], dtype=float32),
 Array([[-0.15345962,  0.18075086]], dtype=float32),
 Array([-0.4314887], dtype=float32)]
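
With the gradient in hand, a gradient descent step on the Equinox model is again just a tree_map over the model and its gradient; a minimal sketch (Equinox also provides eqx.apply_updates for this pattern):

# new_model is a SimpleNN with updated weights and biases
new_model = tree_util.tree_map(lambda p, dp: p - 0.1 * dp, model, g)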

I admit that it is not trivial to understand, but once you get it, it is very powerful. You can make whatever neural network you want!