import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
import seaborn as sns

Homework 1


  • Reference material: Lectures 1 through 3 (inclusive).


  • Type your name and email in the “Student details” section below.

  • Develop the code and generate the figures you need to solve the problems using this notebook.

  • For answers that require a mathematical proof or derivation, type them using LaTeX. If you have never written LaTeX before and find it exceedingly difficult, we will likely accept handwritten solutions.

  • The homework is worth 100 points in total. Note that the problems are not weighted equally.

Student details

  • First Name:

  • Last Name:

  • Email:

  • Used generative AI to complete this assignment (Yes/No):

  • Which generative AI tool did you use (if applicable)?:

Problem 1 - Recursion vs Iteration

This problem is adapted from the book Structure and Interpretation of Computer Programs, in particular from this section.

Imagine you are working with a programming language that does not have loops. This is how you have to think when writing code in Jax. Let’s say we want to write a function that calculates the factorial of a number:

\[ n! = n \times (n-1) \times (n-2) \times \dots \times 1 \]

The standard recursive definition of the factorial function is:

def factorial(n):
    if n == 0:
        return 1
    return n * factorial(n - 1)

For example, factorial(5) evaluates to 120.


Let’s unroll what actually happens behind the scenes:

factorial(5)
5 * factorial(4)
5 * (4 * factorial(3))
5 * (4 * (3 * factorial(2)))
5 * (4 * (3 * (2 * factorial(1))))
5 * (4 * (3 * (2 * 1)))
5 * (4 * (3 * 2))
5 * (4 * 6)
5 * 24
120
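To make the memory argument concrete, here is a hypothetical instrumented variant, factorial_traced. It is not part of the assignment. It returns the factorial together with the largest number of multiplications that were pending at once while the recursion unwound. This is a sketch for illustration only:

```python
def factorial_traced(n, depth=0, stats=None):
    """Recursive factorial that also records the longest chain of
    multiplications left pending while the recursion unwinds."""
    if stats is None:
        stats = {"max_pending": 0}
    # At this call, `depth` multiplications are waiting for our result.
    stats["max_pending"] = max(stats["max_pending"], depth)
    if n == 0:
        return 1, stats
    result, _ = factorial_traced(n - 1, depth + 1, stats)
    return n * result, stats


for n in [5, 10, 20, 40]:
    value, stats = factorial_traced(n)
    print(n, stats["max_pending"])
```

For each n, max_pending comes out equal to n, because each level of recursion defers one multiplication.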

Notice that the number of intermediate results stored in memory (the multiplications still waiting to be performed) grows linearly with the input. For large inputs this breaks down: you eventually run out of memory, and in Python you hit the recursion limit well before that. There is another way to get the same result without this growing memory usage. We could start by multiplying 1 by 2, then multiply the result by 3, then by 4, and so on, keeping track of a running product that we update. We don't need a loop to do this kind of iteration. We can do it with recursion:

def fact_iter(product, counter, max_iter):
    if counter > max_iter:
        return product
    return fact_iter(counter * product, counter + 1, max_iter)


def good_factorial(n):
    return fact_iter(1, 1, n)

As a check, good_factorial(5) also evaluates to 120, just like factorial(5).


Here is how this unrolls:

fact_iter(1, 1, 5)
fact_iter(1, 2, 5)
fact_iter(2, 3, 5)
fact_iter(6, 4, 5)
fact_iter(24, 5, 5)
fact_iter(120, 6, 5)

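We say that the second approach describes an iterative process and the first describes a recursive process, even though both are written as recursive functions. We want to write iterative code because it is more efficient. Its entire state lives in a fixed number of variables (product, counter, and max_iter), so no chain of pending operations builds up.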
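There is a caveat specific to Python. "Iterative" describes the process, not how CPython executes it. CPython does not perform tail-call elimination, so fact_iter still creates one stack frame per call and fails on large inputs just as factorial does. The sketch below repeats the two functions from above so that it runs on its own, and it uses sys.getrecursionlimit() to pick an input that is guaranteed to exceed the limit:

```python
import sys


def fact_iter(product, counter, max_iter):
    if counter > max_iter:
        return product
    return fact_iter(counter * product, counter + 1, max_iter)


def good_factorial(n):
    return fact_iter(1, 1, n)


print(good_factorial(10))  # 3628800: small inputs are fine

# CPython keeps every fact_iter frame alive until the final return,
# so an input deeper than the recursion limit raises RecursionError.
limit = sys.getrecursionlimit()
try:
    good_factorial(2 * limit)
    hit_limit = False
except RecursionError:
    hit_limit = True
print("hit recursion limit:", hit_limit)
```

Scheme, the language SICP uses, requires proper tail calls, so there fact_iter runs in constant space. In Jax, the same kind of iteration is expressed with primitives such as jax.lax.scan (see Problem 3).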

Write iterative code that, given \(n\), computes the \(n\)-th Fibonacci number \(f_n\), defined by the recurrence:

\[ f_n = f_{n-1} + f_{n-2} \]

where \(f_0 = 0\) and \(f_1 = 1\). You should not use a loop!
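When you test your answer, a slow but obviously correct reference is useful. The hypothetical fib_reference below is the direct recursive translation of the recurrence. It is not an answer to the problem, because it generates a recursive process in which each call spawns two more:

```python
def fib_reference(n):
    """Tree-recursive Fibonacci: a direct transcription of f_n = f_{n-1} + f_{n-2}.

    Only suitable as a ground truth for small n: the number of calls grows
    like f_n itself, i.e. exponentially in n.
    """
    if n < 2:  # base cases f_0 = 0 and f_1 = 1
        return n
    return fib_reference(n - 1) + fib_reference(n - 2)


print([fib_reference(n) for n in range(10)])
```

Your iterative version should agree with it on small inputs, for example on \(n = 0, \dots, 20\).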

# Your code here - Demonstrate that it works

Show how your code unrolls for \(n=5\), as was done above for the factorial example.

Problem 2 - The foldl function

The foldl function is a higher-order function that is used to implement iteration. It is defined as follows:

\[ \text{foldl}(f, z, [x_1, x_2, \dots, x_n]) = f(f(\dots f(f(z, x_1), x_2), \dots), x_n) \]

where \(f\) is a function that takes two arguments and \(z\) is the initial value. In words, foldl takes a function \(f\), an initial value \(z\), and a list \([x_1, x_2, \dots, x_n]\). It then applies \(f\) to \(z\) and the first element of the list, then applies \(f\) to the result of the previous application and the second element of the list, and so on.
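Python's standard library already provides a left fold, functools.reduce(function, iterable, initializer), which is handy for testing your own foldl against. A function that builds a string, or a non-commutative operation such as subtraction, makes the left-to-right nesting in the definition visible:

```python
import functools
import operator

xs = [1, 2, 3, 4]

# functools.reduce(f, xs, z) evaluates f(...f(f(z, x1), x2)..., xn).
nested = functools.reduce(lambda acc, x: f"f({acc}, {x})", xs, "z")
print(nested)  # f(f(f(f(z, 1), 2), 3), 4)

# Subtraction is not associative, so the grouping matters:
# ((((0 - 1) - 2) - 3) - 4) = -10.
left = functools.reduce(operator.sub, xs, 0)
print(left)  # -10
```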

Implement foldl in Python. Make sure your implementation is iterative, in the sense of Problem 1.

# Your code here - Demonstrate that it works

Use your foldl function to implement the sum function and the product function.

# Your code here - Demonstrate that it works

Problem 3 - No Loops in Jax

Use Jax’s jax.lax.scan to implement and jit a function that returns the Fibonacci sequence up to a given number. Don’t bother using integer types, just use float32 for everything.
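Here is a warm-up for the scan API that is not the Fibonacci solution. It reuses the running-product idea from Problem 1. jax.lax.scan(f, init, xs) threads a carry through f and stacks the per-step outputs. The helper name running_products is hypothetical:

```python
import jax
import jax.numpy as jnp


def running_products(n):
    """Return [1!, 2!, ..., n!] as float32 using jax.lax.scan (no Python loop)."""

    def step(carry, k):
        # carry: running product so far; k: the next factor.
        new_carry = carry * k
        return new_carry, new_carry  # (carry for next step, value to stack)

    ks = jnp.arange(1, n + 1, dtype=jnp.float32)
    _, outputs = jax.lax.scan(step, jnp.float32(1.0), ks)
    return outputs


# n fixes the output length, so it must be a static argument under jit.
running_products_jit = jax.jit(running_products, static_argnums=0)
print(running_products_jit(5))
```

For n = 5 this should give 1, 2, 6, 24, 120. For Fibonacci, the carry would need to hold the last two values instead of one.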

# Your code here

Problem 4 - Feigenbaum Map

Consider the function:

\[ f(x; r) = r x (1 - x) \]

where \(r\) is a parameter. One can define dynamics on the real line by iterating this function:

\[ x_{n+1} = f(x_n; r) \]

where \(x_n\) is the state at time \(n\).

This map exhibits a period doubling cascade as \(r\) increases.
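Before vectorizing anything, it can help to see period doubling in a single trajectory. The plain-Python sketch below uses the hypothetical helpers logistic and iterate and is not the requested Jax solution. It reuses the left fold from Problem 2 to apply the map \(n\) times, discards a transient, and then prints four consecutive states:

```python
import functools


def logistic(x, r):
    """One step of the map f(x; r) = r x (1 - x)."""
    return r * x * (1.0 - x)


def iterate(r, x0, n):
    """Return x_n by applying f(.; r) n times, via a left fold over range(n)."""
    return functools.reduce(lambda x, _: logistic(x, r), range(n), x0)


for r in [2.8, 3.2, 3.5]:
    x = iterate(r, 0.5, 1_000)  # discard the transient
    orbit = [round(iterate(r, x, k), 4) for k in range(4)]
    print(r, orbit)
```

For \(r = 2.8\) the four values coincide (a stable fixed point at \(1 - 1/r\)). For \(r = 3.2\) they alternate between two values, and for \(r = 3.5\) they cycle through four.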

Write a function in Jax, called logistic_map, that takes an array of \(r\) values and an array of \(x_0\) values as inputs and returns the first \(n\) states of the system for every combination of \(r\) and \(x_0\). Vectorize independently over the \(r\)'s and the \(x_0\)'s, jit the function, and use jax.lax.scan to implement the iteration.
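One way to vectorize independently over two arguments is to nest two jax.vmap calls. Each call uses in_axes to batch over one argument and pass the other through unchanged. The sketch applies this pattern to a single step of the map. The name step_grid is hypothetical, and your scan-based function would be wrapped the same way:

```python
import jax
import jax.numpy as jnp


def step(r, x):
    """One step of the logistic map for scalar r and scalar x."""
    return r * x * (1.0 - x)


# Inner vmap: r fixed, batch over x. Outer vmap: batch over r, pass all x's.
step_grid = jax.jit(
    jax.vmap(jax.vmap(step, in_axes=(None, 0)), in_axes=(0, None))
)

rs = jnp.linspace(0.0, 4.0, 3)
x0s = jnp.linspace(0.0, 1.0, 5)
out = step_grid(rs, x0s)
print(out.shape)  # (3, 5): one row per r, one column per x0
```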

# Your code here - Demonstrate that it works

Test your code here:

import jax.numpy as jnp

x0s = jnp.linspace(0, 1, 100)
rs = jnp.linspace(0, 4, 1_000)
n = 10_000
data = logistic_map(rs, x0s, n)

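Your output should have shape (1000, 100, 10000). That is \(10^9\) float32 values, roughly 4 GB, so make sure your machine has enough memory.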


Discard all but the last iteration:

data = data[:, :, -1:]

Make the famous period doubling plot. The plot will take a while and it will take a lot of memory. I suggest you restart your kernel before moving to the next problem.

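fig, ax = plt.subplots()
# One column per initial condition: plot the final states against r.
ax.plot(rs, data.reshape(data.shape[0], data.shape[1] * data.shape[2]), "k.", markersize=0.1)
ax.set(xlabel="$r$", ylabel="$x$")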

Problem 5 - Analysis of Nonlinear Dynamical System

Consider the dynamical system:

\[ \dot{x_1} = \mu x_1 + x_2 - x_1^2, \]


\[ \dot{x_2} = -x_1 + \mu x_2 + 2 x_1^2. \]

Use the random initial conditions:

\[\begin{split} x(0) \sim N\left( \begin{pmatrix} 0 \\ 0 \end{pmatrix}, \begin{pmatrix} \sigma^2 & 0 \\ 0 & \sigma^2 \end{pmatrix} \right) \end{split}\]
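Because the covariance is \(\sigma^2 I\), drawing Monte Carlo samples of the initial condition only requires scaling standard normal draws by \(\sigma\). Here is a NumPy sketch. The sample size and seed are arbitrary choices, and jax.random.normal would work the same way inside jitted code:

```python
import numpy as np

sigma = 0.01
n_samples = 10_000
rng = np.random.default_rng(seed=0)

# x(0) ~ N(0, sigma^2 I): independent standard normals scaled by sigma.
x0_samples = sigma * rng.standard_normal(size=(n_samples, 2))
print(x0_samples.shape)        # (10000, 2)
print(x0_samples.std(axis=0))  # both entries close to 0.01
```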

First, write code that solves the differential equation given the initial condition and the parameter \(\mu\). Make sure your code is vectorized with respect to the initial conditions and that it can be jitted.

# Your code and evidence that it works here
  • Use first-order sensitivity analysis to compute the mean and covariance matrix of the solution for \(t \in [0, 10]\).

  • Implement a simple Monte Carlo procedure to compare the results of the sensitivity analysis.

  • Do this for three values of \(\mu\): \(\mu = 0\), \(0.01\), and \(0.066\).

  • Use \(\sigma=0.01\).

  • Plot the mean in the \(x_1\)-\(x_2\) plane for each value of \(\mu\). Compare local sensitivity analysis to Monte Carlo.

  • Plot the standard deviation of \(x_1\) and \(x_2\) as a function of time for each value of \(\mu\). Compare local sensitivity analysis to Monte Carlo.
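For reference, here is a sketch of the linearization behind first-order sensitivity analysis, written in the notation above. \(F\) denotes the right-hand side of the system. Let \(x(t; x_0)\) be the solution started at \(x_0\), and let \(J(t) = \partial x(t; x_0) / \partial x_0\) be evaluated at the mean initial condition \(x_0 = 0\). To first order,

\[ x(t; x_0) \approx x(t; 0) + J(t)\, x_0, \]

and since \(x(0)\) has mean \(0\) and covariance \(\sigma^2 I\),

\[ \mathbb{E}[x(t)] \approx x(t; 0), \qquad \operatorname{Cov}[x(t)] \approx \sigma^2 J(t) J(t)^\top. \]

The sensitivity matrix satisfies the variational equation \(\dot{J} = \nabla F(x(t; 0))\, J\) with \(J(0) = I\), where for this system

\[ \nabla F(x) = \begin{pmatrix} \mu - 2 x_1 & 1 \\ -1 + 4 x_1 & \mu \end{pmatrix}. \]

You can either integrate this equation alongside the state or obtain \(J(t)\) by differentiating a jitted solver with respect to its initial condition using jax.jacfwd.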

# Your code here - You will need several blocks and discussion