A Primer on Functional Programming

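The programming you have been learning so far is called imperative (or procedural) programming. In imperative programming, you tell the computer what to do, step by step, usually by updating variables and data structures as you go. C, C++, and Java are imperative languages, and Python and JavaScript are mostly written in an imperative style even though they support other styles too.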

In contrast, functional programming is a style of programming where you (typically) avoid changing state and mutating data. Instead, you write functions that transform data structures. Functional programming languages include Haskell, Lisp, and Clojure. Some of them allow you to write imperative code as well, but they encourage you to write functional code.

Functional programming is a very different way of thinking about programming. It can be hard to wrap your head around at first, but it is worth learning. Because pure functions depend only on their inputs, code written in this style tends to be easier to test, reason about, and reuse.

You can write functional code in Python, but it is not the default style. In this notebook, we will learn some of the basics of functional programming in Python.

Why do we care about functional programming? Because we will use Jax, a Python library for differentiable programming, and Jax is designed around a functional programming style.

Side effects

In imperative programming, functions can have side effects. A side effect is anything that changes the state of the program or the outside world. For example, printing to the screen, writing to a file, and modifying a global variable are all side effects. Here is a function with side effects:

def add_to_list(x, lst):
    lst.append(x)  # side effect: modifies the caller's list in place

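In functional programming we avoid side effects, and in Jax we must: Jax's transformations (such as jit, grad, and vmap) assume that the functions they receive are pure. A side effect inside such a function, for example printing or updating a global variable, may run only once while Jax traces the function and not be repeated on later calls.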

Pure functions

A pure function is a function that has no side effects: it does not change any state, it does not mutate any data, and its output depends only on its input, so calling it twice with the same input gives the same output. It is like a mathematical function. Here is a pure function:

def f(x):
    return x**2

Alternatively, we can write this function as a lambda function:

f = lambda x: x**2

Lambda functions are a convenient way to write simple functions.
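A quick way to check for purity is to call a function twice with the same arguments. In this minimal sketch, impure_increment and pure_increment are hypothetical names made up for illustration: the first reads and updates a global counter, so identical calls give different results, while the second depends only on its arguments.

```python
counter = 0

def impure_increment(x):
    # Impure: reads and modifies the global variable `counter`
    global counter
    counter += 1
    return x + counter

def pure_increment(x, step):
    # Pure: the result depends only on the arguments; nothing else is touched
    return x + step

print(impure_increment(1), impure_increment(1))  # 2 3  (same input, different outputs)
print(pure_increment(1, 1), pure_increment(1, 1))  # 2 2  (same input, same output)
```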

Let’s make a non-trivial pure function. We will calculate the factorial of a number:

def factorial(n):
    # Base case: 0! = 1
    if n == 0:
        return 1
    # Recursive case: n! = n * (n-1)!
    else:
        return n * factorial(n-1)

Or the Fibonacci sequence:

def fibonacci(n):
    # Base cases: F(0) = 0, F(1) = 1
    if n == 0:
        return 0
    elif n == 1:
        return 1
    # Recursive case: F(n) = F(n-1) + F(n-2)
    else:
        return fibonacci(n-1) + fibonacci(n-2)

Notice that we are avoiding side effects: neither function uses a global variable, prints to the screen, or modifies its argument.
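One practical payoff of purity: since fibonacci always returns the same output for the same n, its results can be safely cached. Below is a sketch using the standard-library decorator functools.lru_cache (the @ syntax is explained at the end of this notebook). The recursive definition is unchanged; the cache ensures each fibonacci(k) is computed only once, whereas the plain version makes a number of calls that grows exponentially with n.

```python
import functools

@functools.lru_cache(maxsize=None)
def fibonacci(n):
    # Same pure recursive definition; the cache stores each fibonacci(n) once
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fibonacci(n - 1) + fibonacci(n - 2)

print(fibonacci(30))  # 832040
```

The cache is hidden state inside the wrapper, but because the wrapped function is pure, reusing a stored result can never change the answer.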

Let’s do some functional programming with lists. Remember, we are not allowed to modify the lists. First, how do we sum a list?

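def mysum(x):
    # The sum of an empty list is 0
    if x == []:
        return 0.0
    # Otherwise: first element + sum of the rest (x[1:] is a new list; x is not modified)
    else:
        return x[0] + mysum(x[1:])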


Notice that we have not used a loop. We have used recursion instead. This is a common pattern in functional programming. Some functional programming languages do not even have loops, e.g. Haskell. This sounds crazy, but it is possible to do everything with recursion.
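One caveat for Python specifically: CPython limits how deeply functions may recurse (sys.getrecursionlimit() reports the limit, 1000 by default), and it does not optimize tail calls. The sketch below repeats mysum so it runs on its own and shows it failing on a long list, while the built-in sum handles the same input without trouble.

```python
import sys

def mysum(x):
    # Recursive sum: an empty list sums to 0.0, otherwise head + sum of the tail
    if x == []:
        return 0.0
    else:
        return x[0] + mysum(x[1:])

print(sys.getrecursionlimit())  # 1000 in a default CPython; some environments raise it

overflowed = False
try:
    mysum(list(range(10_000)))  # would need about 10,000 nested calls
except RecursionError:
    overflowed = True
print(overflowed)  # True unless the limit was raised above ~10,000

print(sum(range(10_000)))  # the built-in sum has no such limit: 49995000
```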

Functions are first-class objects

In Python, functions are first-class objects. This means that you can assign functions to variables, pass them as arguments to other functions, and return them from functions. This is a key feature of functional programming. In particular, you can have pure functions that take other functions as arguments or return functions. Such functions are called higher-order functions.

Let’s see some useful examples.


First, the function map applies a function to every element of a list:

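def mymap(x, f):
    # Mapping over an empty list gives an empty list
    if x == []:
        return []
    # Otherwise: apply f to the first element, then map over the rest
    else:
        return [f(x[0])] + mymap(x[1:], f)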

Here is how it works:

mymap([1,2,3,4,5], lambda x: x**2)
[1, 4, 9, 16, 25]
mymap([1,2,3,4,5], lambda x: x**3)
[1, 8, 27, 64, 125]

And so on.

Python already has a built-in map. Note that it takes the function first, map(f, iterable), and returns a lazy iterator rather than a list.
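For comparison, a short sketch using the built-in map on the same input as mymap. Wrapping it in list(...) materializes the result; the list comprehension on the last line is the more common Python idiom for the same transformation.

```python
squares = map(lambda x: x**2, [1, 2, 3, 4, 5])
print(squares)        # a lazy map object, not a list
print(list(squares))  # [1, 4, 9, 16, 25]

# The same transformation written as a list comprehension
print([x**2 for x in [1, 2, 3, 4, 5]])  # [1, 4, 9, 16, 25]
```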

Vectorization (vmap)

With map we can make a function that vectorizes another function, say vmap:

myvmap = lambda f: lambda x: mymap(x, f)

Before we take the time to unwrap this, here is how it works:

vectorized_sqr = myvmap(lambda x: x**2)
vectorized_sqr([1, 2, 3, 4, 5])
[1, 4, 9, 16, 25]

What is happening here? myvmap is a function that takes another function f as an argument. So far so good. Then it returns a function that takes a list as an argument and calls mymap on the list using f as the function to apply to each element.
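To unwrap the nested lambdas, here is the same myvmap written with def statements (the recursive mymap is repeated so the snippet runs on its own). The inner function vectorized_f, a name chosen for illustration, is a closure: it keeps access to f after myvmap has returned.

```python
def mymap(x, f):
    # Recursive map, as defined earlier in this notebook
    if x == []:
        return []
    else:
        return [f(x[0])] + mymap(x[1:], f)

def myvmap(f):
    # The outer function receives f ...
    def vectorized_f(x):
        # ... and the inner function receives the list, still "remembering" f
        return mymap(x, f)
    return vectorized_f

vectorized_sqr = myvmap(lambda x: x**2)
print(vectorized_sqr([1, 2, 3, 4, 5]))  # [1, 4, 9, 16, 25]
```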

When we work with Jax, we will use vmap to vectorize functions a lot.


Let’s rethink the sum function. What does it do?

  • It takes a list.

  • If the list is empty, it returns 0.

  • Otherwise, it adds the first element of the list to the sum of the rest of the list.

This pattern can be generalized:

  • Take a list.

  • If the list is empty, return some default value (initializer).

  • If the list has a single element and no initializer was given, return that element.

  • Otherwise, combine the first element of the list with the result of reducing the rest of the list, using a function of two arguments.

Here is the code:

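def myreduce(f, x, init=None):
    # Empty list: return the initializer
    if x == []:
        return init
    # Single element and no initializer: the element itself is the result
    elif len(x) == 1 and init is None:
        return x[0]
    # Otherwise: combine the first element with the reduction of the rest
    else:
        return f(x[0], myreduce(f, x[1:], init))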

Here is how we can express sum using reduce:

mysum2 = lambda x: myreduce(lambda x,y: x+y, x, 0)


Neat, right?

Not just sum, but also prod (the product) can be expressed using reduce:

myprod = lambda x: myreduce(lambda x,y: x*y, x, 1)


And max:

mymax = lambda x: myreduce(lambda x,y: x if x > y else y, x)
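mysum2, myprod, and mymax work because +, *, and max are associative, so it does not matter how the elements are grouped. For a non-associative function such as subtraction the grouping becomes visible: myreduce combines elements from the right (a right fold), while Python's standard functools.reduce combines them from the left. A small sketch comparing the two, with myreduce repeated so it runs on its own:

```python
import functools as ft

def myreduce(f, x, init=None):
    # Right fold: f(x[0], f(x[1], ... f(x[-1], init)))
    if x == []:
        return init
    elif len(x) == 1 and init is None:
        return x[0]
    else:
        return f(x[0], myreduce(f, x[1:], init))

sub = lambda x, y: x - y

print(myreduce(sub, [1, 2, 3], 0))   # 1 - (2 - (3 - 0)) = 2
print(ft.reduce(sub, [1, 2, 3], 0))  # ((0 - 1) - 2) - 3 = -6
```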



The function filter takes a predicate (a function that returns True or False) and a list, and returns a new list containing only the elements for which the predicate returns True:

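def myfilter(f, x):
    # Filtering an empty list gives an empty list
    if x == []:
        return []
    # Keep the first element if it satisfies f, then filter the rest
    elif f(x[0]):
        return [x[0]] + myfilter(f, x[1:])
    # Otherwise drop it and filter the rest
    else:
        return myfilter(f, x[1:])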

Let’s use it to extract the even numbers from a list:

myfilter(lambda x: x % 2 == 0, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
[2, 4, 6, 8, 10]
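As with map, Python has a built-in filter that takes the predicate first and returns a lazy iterator; a list comprehension with an if clause is the more common idiom. A short comparison on the same input:

```python
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
is_even = lambda x: x % 2 == 0

print(list(filter(is_even, numbers)))      # [2, 4, 6, 8, 10]
print([x for x in numbers if is_even(x)])  # [2, 4, 6, 8, 10]
```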

Partial application or currying

Definitions such as mysum2 above are a bit clunky: the outer lambda only exists to fix some of the arguments of myreduce. We can make them nicer using partial application.

Suppose you have a function that takes two arguments:

def f(x,y):
    return x+y

Now we want to make a function that fixes the first argument to some value, say 2. We can do it like this:

f2 = lambda y: f(2, y)


If we wanted to fix the second argument instead, we could write lambda x: f(x, 2).

Python's standard library provides partial application through functools.partial. Here is how we can use it:

import functools as ft

f2 = ft.partial(f, 2)
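Because ft.partial fills positional arguments from the left, fixing a later argument is usually done by keyword. A sketch with a hypothetical power function, chosen because its argument order matters (unlike f, where x + y == y + x):

```python
import functools as ft

def power(base, exponent):
    return base ** exponent

# Fix the first positional argument: two_to_the(e) == power(2, e)
two_to_the = ft.partial(power, 2)

# Fix the second argument by keyword: square(b) == power(b, exponent=2)
square = ft.partial(power, exponent=2)

print(two_to_the(3))  # 8
print(square(3))      # 9
```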


The functools module also provides reduce. Combining the two gives a nicer version of sum:

sum2 = ft.partial(ft.reduce, lambda x,y: x+y)


Here is min:

min2 = ft.partial(ft.reduce, lambda x,y: x if x < y else y)

min2([2, 4, 6, 1, 3, 5])
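On that list, min2 returns 1. One caveat: without an initializer, functools.reduce raises a TypeError on an empty list, whereas our myreduce returns init in that case. A sketch showing the failure and the fix of passing an initializer as the third argument of reduce (sum3 is a hypothetical name):

```python
import functools as ft

add = lambda x, y: x + y
sum2 = ft.partial(ft.reduce, add)

print(sum2([1, 2, 3]))  # 6

empty_failed = False
try:
    sum2([])  # no elements and no initializer
except TypeError:
    empty_failed = True
print(empty_failed)  # True

# Passing an initializer (the third argument of reduce) covers the empty case
sum3 = lambda x: ft.reduce(add, x, 0)
print(sum3([]))  # 0
```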

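In functional programming languages, currying is used a lot, and the syntax is nicer. In Haskell, for example, every function is curried, so supplying fewer arguments than a function expects gives back a new function. We can write: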

add :: Int -> Int -> Int
add x y = x + y

add2 :: Int -> Int
add2 = add 2

add2 3 -- 5

The function add2 is the same as defining add2 x = 2 + x.
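Strictly speaking, currying and partial application are different: currying turns a function of two arguments into a function of one argument that returns another function of one argument, and partial application is what happens when you supply only the first. Python functions are not curried by default, but a curried add can be sketched with nested lambdas to mirror the Haskell example:

```python
# A curried add: each call takes one argument and returns a function for the rest
add = lambda x: lambda y: x + y

add2 = add(2)      # like Haskell's `add2 = add 2`
print(add2(3))     # 5
print(add(2)(3))   # 5, the Python spelling of Haskell's `add 2 3`
```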

Function composition

Suppose you have two functions f and g such that the output of g is of the right type to be an input to f. Then you can compose them to make a new function h that is the same as applying f to the output of g. Mathematically, we write \(h = f \circ g\) for the function composition. You can read this as “f after g”.

Now the composition operator \(\circ\) is also a function. It takes two functions as arguments and returns a new function. Here is a simple implementation:

compose = lambda f,g: lambda x: f(g(x))

And here is how it works:

g = lambda x: x**2
f = lambda x: x+1

h = compose(f, g)

print(f"h(x)\t= {h(2)}")
print(f"f(g(x))\t= {f(g(2))}")
h(x)	= 5
f(g(x))	= 5
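Since compose is just a two-argument function, we can reduce over a sequence of functions to compose any number of them. In this sketch, compose_all is a hypothetical helper; the identity function serves as the initializer, so composing zero functions gives back the input unchanged.

```python
import functools as ft

compose = lambda f, g: lambda x: f(g(x))

# Fold a sequence of functions into one: compose_all(f, g, h)(x) == f(g(h(x)))
compose_all = lambda *fs: ft.reduce(compose, fs, lambda x: x)

f = lambda x: x + 1
g = lambda x: x**2

h = compose_all(str, f, g)  # first square, then add 1, then convert to str
print(h(2))  # 5
```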

Again, composition is nicer in Haskell:

f :: Int -> Int
f x = x + 1

g :: Int -> Int
g x = x * 2

h :: Int -> Int
h = f . g

h 3 --  7


Decorators are a very useful feature of Python that let you modify functions with a concise syntax. A decorator is a function that takes a function as an argument and returns a new function, so our vectorization function myvmap is a decorator by this definition. The syntactic sugar for decorators is the @ symbol, written on the line just above def. Here is how you can vectorize a function using the @ symbol:

@myvmap
def g(x):
    return x**3

This is equivalent to first defining g and then doing g = myvmap(g). Here is how it works:

g([1, 2, 3, 4, 5])
[1, 8, 27, 64, 125]
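Decorators can be stacked, and they are applied from the bottom up, so stacking is function composition. As a sketch (with mymap and myvmap repeated so the snippet runs on its own), applying myvmap twice produces a function that maps over a list of lists, similar in spirit to nesting vmap in Jax:

```python
def mymap(x, f):
    # Recursive map from earlier in this notebook
    if x == []:
        return []
    else:
        return [f(x[0])] + mymap(x[1:], f)

myvmap = lambda f: lambda x: mymap(x, f)

# Decorators apply bottom-up, so this is g = myvmap(myvmap(g))
@myvmap
@myvmap
def g(x):
    return x**3

print(g([[1, 2], [3, 4, 5]]))  # [[1, 8], [27, 64, 125]]
```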