import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('png')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");
Homework 1#
References#
Module 1: Introduction
Module 2: Modern Machine Learning Software
Instructions#
Type your name and email in the “Student details” section below.
Develop the code and generate the figures you need to solve the problems using this notebook.
For the answers that require a mathematical proof or derivation, you should type them using LaTeX. If you have never written LaTeX before and you find it exceedingly difficult, we will likely accept handwritten solutions.
The total homework points are 100. Please note that the problems are not weighted equally.
Student details#
First Name:
Last Name:
Email:
Used generative AI to complete this assignment (Yes/No):
Which generative AI tool did you use (if applicable)?:
Problem 1 - Recursion vs Iteration#
This problem is adapted from the Structure and Interpretation of Computer Programs book, in particular from this section.
Imagine you are working with a programming language that does not have loops.
This is how you have to think when writing code in Jax.
Let’s say we want to write a function that calculates the factorial of a number:
The standard recursive definition of the factorial function is:
def factorial(n):
    if n == 0:
        return 1
    else:
        return n * factorial(n-1)
Here is how it can be used:
factorial(5)
Let’s unroll what actually happens behind the scenes:
factorial(5)
5 * factorial(4)
5 * (4 * factorial(3))
5 * (4 * (3 * factorial(2)))
5 * (4 * (3 * (2 * factorial(1))))
5 * (4 * (3 * (2 * 1)))
5 * (4 * (3 * 2))
5 * (4 * 6)
5 * 24
120
You quickly notice that the number of intermediate results stored in memory grows linearly with the input: every call has to wait for the next one to return before it can do its multiplication. This won’t work for large inputs, because you will run out of memory. But there is another way to achieve the same result without growing memory usage. We could start by multiplying 1 by 2, then the result by 3, then the result by 4, and so on. So we keep track of a running product that we update. We don’t need a loop to do this kind of iteration. We can do it with recursion:
def fact_iter(product, counter, max_iter):
    if counter > max_iter:
        return product
    else:
        return fact_iter(counter * product, counter + 1, max_iter)

def good_factorial(n):
    return fact_iter(1, 1, n)
Check that this works as before:
good_factorial(5)
Here is how this unrolls:
good_factorial(5)
fact_iter(1, 1, 5)
fact_iter(1, 2, 5)
fact_iter(2, 3, 5)
fact_iter(6, 4, 5)
fact_iter(24, 5, 5)
fact_iter(120, 6, 5)
120
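We say that the second approach is iterative and the first approach is recursive. We want to write iterative code because it is more efficient: the entire state of the computation lives in the arguments (here, the running product and the counter), so no chain of deferred multiplications builds up.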
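The same state-passing pattern carries over to other computations. As a minimal sketch, here is integer exponentiation written in the iterative style; the names `expt_iter` and `good_expt` are illustrative and not part of the assignment.

```python
def expt_iter(base, counter, product):
    # `product` carries the partial result,
    # `counter` is the number of multiplications still to do.
    if counter == 0:
        return product
    else:
        return expt_iter(base, counter - 1, base * product)


def good_expt(base, n):
    # Start with the multiplicative identity and n multiplications to go.
    return expt_iter(base, n, 1)


result = good_expt(2, 10)
print(result)
```

For `good_expt(2, 3)` the calls unroll as `expt_iter(2, 3, 1)`, `expt_iter(2, 2, 2)`, `expt_iter(2, 1, 4)`, `expt_iter(2, 0, 8)`, and the answer is read directly from the last argument.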
Write iterative code that, given \(n\), computes the Fibonacci number:
\[
f_n = f_{n-1} + f_{n-2},
\]
where \(f_0 = 0\) and \(f_1 = 1\). You should not use a loop!
Answer:
# Your code here - Demonstrate that it works
Show here how your code unrolls for \(n=5\), as was done above for the factorial example.
Problem 2 - The foldl function#
The foldl function is a higher-order function that is used to implement iteration.
It is defined as follows:
\[
\operatorname{foldl}(f, z, [x_1, x_2, \dots, x_n]) = f(\dots f(f(z, x_1), x_2) \dots, x_n),
\]
where \(f\) is a function that takes two arguments and \(z\) is the initial value.
In words, foldl takes a function \(f\), an initial value \(z\), and a list \([x_1, x_2, \dots, x_n]\).
It then applies \(f\) to \(z\) and the first element of the list, then applies \(f\) to the result of the previous application and the second element of the list, and so on.
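For reference, Python's standard library already provides a left fold, `functools.reduce`, with argument order (function, iterable, initial value). The sketch below only illustrates the order in which the applications nest; it is not the implementation this problem asks for, and the helper name `show_step` is made up for the illustration.

```python
from functools import reduce


def show_step(acc, x):
    # Record each application as text so the nesting order is visible.
    return f"f({acc}, {x})"


trace = reduce(show_step, ["x1", "x2", "x3"], "z")
print(trace)  # f(f(f(z, x1), x2), x3)
```

The innermost call combines the initial value with the first element, and each later element wraps the result accumulated so far.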
Implement foldl in Python. Take care to make your implementation iterative.
Answer:
# Your code here - Demonstrate that it works
Use your foldl function to implement the sum function and the product function.
Answer:
# Your code here - Demonstrate that it works
Problem 3 - No Loops in Jax#
Use Jax’s jax.lax.scan to implement and jit a function that returns the Fibonacci sequence up to a given number.
Don’t bother using integer types, just use float32 for everything.
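As a reminder of how `jax.lax.scan` is shaped, here is a minimal sketch on a different sequence (the factorials from Problem 1). `jax.lax.scan(f, init, xs)` threads a carry through the elements of `xs`; the step function returns the new carry together with a per-step output, and the outputs are stacked. The names `factorials` and `step` are illustrative, and marking `n` as static is one possible way to let the sequence length be a plain Python integer under `jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="n")
def factorials(n):
    def step(product, counter):
        # The carry is the running product; it is also emitted at every step.
        new_product = product * counter
        return new_product, new_product

    counters = jnp.arange(1, n + 1, dtype=jnp.float32)
    _, outputs = jax.lax.scan(step, jnp.float32(1.0), counters)
    return outputs


result = factorials(5)
print(result)
```

The carry here is a single number. When a recurrence needs more than one previous value, the carry can be a tuple.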
Answer:
# Your code here
Problem 4 - Feigenbaum Map#
Consider the function:
\[
f(x; r) = r x (1 - x),
\]
where \(r\) is a parameter. One can define dynamics on the real line by iterating this function:
\[
x_{n+1} = f(x_n; r),
\]
where \(x_n\) is the state at time \(n\).
This map exhibits a period doubling cascade as \(r\) increases.
Write a function in jax, call it logistic_map, that takes many \(r\)’s and \(x_0\)’s as inputs and returns the first \(n\) states of the system.
You should independently vectorize for the \(r\)’s and the \(x_0\)’s.
And you should jit.
Use jax.lax.scan to implement the iteration.
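Vectorizing independently over two arguments can be done by nesting `jax.vmap`, one argument at a time. The sketch below shows the pattern on a toy function; `scaled_square` and the array names are hypothetical and stand in for whatever per-sample function you write.

```python
import jax
import jax.numpy as jnp


def scaled_square(a, x):
    # Written for scalar a and scalar x.
    return a * x ** 2


# Inner vmap: map over x while holding a fixed.
over_x = jax.vmap(scaled_square, in_axes=(None, 0))
# Outer vmap: map over a while passing the whole x array through.
over_a_and_x = jax.jit(jax.vmap(over_x, in_axes=(0, None)))

a_values = jnp.linspace(0.0, 1.0, 3)
x_values = jnp.linspace(0.0, 2.0, 5)
grid = over_a_and_x(a_values, x_values)
print(grid.shape)  # (3, 5)
```

The outer `vmap` determines the leading axis of the result, so the order of nesting controls whether the first axis of the output indexes the first or the second argument.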
Answer:
# Your code here - Demonstrate that it works
Test your code here:
import jax.numpy as jnp

x0s = jnp.linspace(0, 1, 100)
rs = jnp.linspace(0, 4, 1_000)
n = 10_000
data = logistic_map(rs, x0s, n)
Your shape should be (1000, 100, 10000):
data.shape
Discard all but the last iteration:
data = data[:, :, -1:]
Make the famous period doubling plot. The plot will take a while and it will take a lot of memory. I suggest you restart your kernel before moving to the next problem.
fig, ax = plt.subplots()
# Rows index r and columns index x0, so the first axis matches rs
ax.plot(
    rs,
    data.reshape(data.shape[0], data.shape[1] * data.shape[2]),
    '.k',
    ms=0.1,
    alpha=0.5
);
Problem 5 - Implement autoencoders in jax, equinox, and optax#
Implement an autoencoder in jax and train it on the MNIST dataset.
An autoencoder consists of two neural networks, an encoder and a decoder. The encoder maps the input to a latent space (typically of a much smaller dimension than the input), and the decoder maps the latent space back to the input space.
You can think of the encoder as a compression algorithm and the decoder as a decompression algorithm.
Alternatively, you can think of the encoder as the projection of the input data onto a lower-dimensional manifold, and the decoder as the reconstruction operator.
Part A#
Follow these directions:
Pick the dimension of the latent space to be 2. This means that the encoder will map the input to a 2-dimensional space, and the decoder will map the 2-dimensional space back to the input space.
Your encoder should work on a flattened version of the input image. This means that the input to the encoder is a vector of 784 elements (28x28).
Start by picking your encoder \(z = f(x;\theta_f)\) to be a neural network with 2 hidden layers, each with 128 units and ReLU activations. Increase the number of units and layers if you think it is necessary.
Start by picking your decoder \(x' = g(z;\theta_g)\) to be a neural network with 2 hidden layers, each with 128 units and ReLU activations. Increase the number of units and layers if you think it is necessary.
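Make all your neural networks in equinox.
The loss function is the mean squared error between the input and the output of the decoder:
\[
L(\theta_f, \theta_g) = \frac{1}{N} \sum_{i=1}^N \left\| x_i - g(f(x_i;\theta_f);\theta_g) \right\|^2,
\]
where \(N\) is the number of samples in the dataset.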
Split the MNIST dataset into a training and a test set.
Use optax for the optimization.
Train the autoencoder using the Adam optimizer with a learning rate of 0.001 for 1 epoch to debug. Use a batch size of 32. Feel free to play with the learning rate and batch size.
Monitor the loss function on the training and test set. Increase the number of epochs up to the point where the loss function on the test set stops decreasing.
Here is the dataset:
# Download the MNIST dataset
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, parser='auto')
# Split the dataset into training and test sets
from sklearn.model_selection import train_test_split
X_train_val, X_test, y_train_val, y_test = train_test_split(
    mnist.data, mnist.target, test_size=10000, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=10000, random_state=42)
Answer:
Put your answer here. Use as many markdown and code blocks as you want.
# your code
Part B#
Pick the first five digits in the test set and plot the original and reconstructed images.
Answer:
# your code here
Part C#
Plot the projections of the digits in the latent space (training and test).
Answer:
# your code here
Part D#
Use scikit-learn to fit a mixture of Gaussians to the latent space. Use 10 components.
Then sample five times from the fitted mixture of Gaussians, reconstruct the samples, and plot the reconstructed images.
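As a minimal sketch of the scikit-learn calls involved, the block below fits a mixture to synthetic two-dimensional points and draws five samples. The synthetic array `latent` stands in for your encoded digits, and the toy example uses 2 components instead of the 10 requested above.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Synthetic stand-in for 2-dimensional latent codes: two well-separated blobs.
rng = np.random.default_rng(0)
latent = np.concatenate([
    rng.normal(loc=(-2.0, 0.0), scale=0.5, size=(200, 2)),
    rng.normal(loc=(2.0, 1.0), scale=0.5, size=(200, 2)),
])

gmm = GaussianMixture(n_components=2, random_state=0).fit(latent)

# sample() returns the drawn points and the index of the component of each one.
samples, component_ids = gmm.sample(5)
print(samples.shape, component_ids.shape)
```

The sampled points live in the latent space, so they still have to be passed through the decoder before they can be plotted as images.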
Answer:
# your code here