{ "cells": [ { "cell_type": "code", "execution_count": 8, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "import matplotlib_inline\n", "matplotlib_inline.backend_inline.set_matplotlib_formats('svg')\n", "import seaborn as sns\n", "sns.set_context(\"paper\")\n", "sns.set_style(\"ticks\")\n", "\n", "import scipy\n", "import scipy.stats as st\n", "import urllib.request\n", "import os\n", "\n", "def download(\n", "    url : str,\n", "    local_filename : str = None\n", "):\n", "    \"\"\"Download a file from a url.\n", "    \n", "    Arguments\n", "    url            -- The url we want to download.\n", "    local_filename -- The filename to write on. If not\n", "                      specified, the basename of the url is used.\n", "    \"\"\"\n", "    if local_filename is None:\n", "        local_filename = os.path.basename(url)\n", "    urllib.request.urlretrieve(url, local_filename)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Hierarchical Bayesian Models\n", "\n", "We build some relatively simple hierarchical Bayesian models in `pyro`.\n", "Note that `pyro` does not always do Gibbs sampling.\n", "Instead, it uses the following approach:\n", "- It transforms all the variables so that they have a real support (e.g., if they are positive we can work with their logarithm).\n", "- It uses the computational graph to compute the joint distribution of all the transformed variables.\n", "- It uses NUTS to sample from the joint distribution." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Do this in Google Colab (numpyro is used later in this notebook)\n", "!pip install pyro-ppl numpyro" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Example 1 - Coal Mining Disaster" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We are going to work with the coal mining disaster dataset, a time series of recorded coal mining disasters in the UK from 1851 to 1961. Let us first import this dataset and visualize it."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "url = \"https://github.com/PredictiveScienceLab/data-analytics-se/raw/master/lecturebook/data/coal_mining_disasters.csv\"\n", "download(url)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table>\n", "  <thead>\n", "    <tr><th></th><th>year</th><th>disasters</th></tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr><th>0</th><td>1851</td><td>4</td></tr>\n",
"    <tr><th>1</th><td>1852</td><td>5</td></tr>\n",
"    <tr><th>2</th><td>1853</td><td>4</td></tr>\n",
"    <tr><th>3</th><td>1854</td><td>0</td></tr>\n",
"    <tr><th>4</th><td>1855</td><td>1</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td></tr>\n",
"    <tr><th>106</th><td>1957</td><td>0</td></tr>\n",
"    <tr><th>107</th><td>1958</td><td>0</td></tr>\n",
"    <tr><th>108</th><td>1959</td><td>1</td></tr>\n",
"    <tr><th>109</th><td>1960</td><td>0</td></tr>\n",
"    <tr><th>110</th><td>1961</td><td>1</td></tr>\n",
"  </tbody>\n", "</table>\n", "<p>111 rows × 2 columns</p>\n", "</div>
\n" ], "text/plain": [ "     year  disasters\n", "0    1851          4\n", "1    1852          5\n", "2    1853          4\n", "3    1854          0\n", "4    1855          1\n", "..    ...        ...\n", "106  1957          0\n", "107  1958          0\n", "108  1959          1\n", "109  1960          0\n", "110  1961          1\n", "\n", "[111 rows x 2 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "disaster_data = pd.read_csv('coal_mining_disasters.csv')\n", "# dropna() returns a new DataFrame, so keep the result\n", "disaster_data = disaster_data.dropna()\n", "disasters = disaster_data.disasters.values\n", "years = disaster_data.year.values\n", "disaster_data" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: scatter plot of 'Number of disasters' against 'Year', titled 'Recorded coal mining disasters in the UK.']" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "plt.plot(disaster_data.year, disaster_data.disasters, 'o')\n", "plt.xlabel('Year')\n", "plt.ylabel('Number of disasters')\n", "plt.title('Recorded coal mining disasters in the UK.')\n", "sns.despine(trim=True);" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "
\n" ], "text/plain": [ "[Figure: bar chart of 'Number of disasters' per 'Year', titled 'Recorded coal mining disasters in the UK.']" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "bp = plt.bar(disaster_data.year, disaster_data.disasters, width=1.)\n", "plt.xlabel(\"Year\")\n", "plt.ylabel(\"Number of disasters\")\n", "plt.title(\"Recorded coal mining disasters in the UK.\")\n", "sns.despine(trim=True);" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Questions\n", "\n", "1. How can we represent this disaster time series data? What are the quantities of interest?\n", "2. Is the 'disasters' variable categorical or continuous? Can it be negative? What are some other constraints?" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Information about the dataset\n", "\n", "Occurrences of disasters in the time series are thought to follow a Poisson process with a large rate parameter in the early part of the time series (more disasters in earlier years) and one with a smaller rate in the later part (fewer disasters in later years). We are interested in locating the change point in the series, which might be due to changes in mining safety regulations in later years." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Modeling Approach\n", "\n", "How are we going to develop a model for these data? Thinking about how the data might have been generated is a good starting point. Try to imagine how you would recreate the dataset.\n", "\n", "1. We start by thinking, \"What is the best random variable to describe this count data?\" A Poisson random variable is a good candidate because it can represent count data. So, we model the number of coal mining disasters in each year as a sample from a Poisson distribution.\n", "\n", "2. Next, we think, \"Okay, assuming the number of disasters is Poisson-distributed, what do I need for the Poisson distribution?\" Well, the Poisson distribution has a rate parameter $\\lambda$.\n", "\n", "3. 
Do we know $\\lambda$? No. We suspect there are *two* $\\lambda$ values, one for earlier and one for later years. We don't know when the rate changes, so we treat the switchpoint as unknown and call it $\\tau$.\n", "\n", "4. What is a good distribution for the two $\\lambda$s? The exponential is good, as it is supported on the positive real numbers. The exponential distribution has a parameter too. Call it $\\alpha$.\n", "\n", "5. Do we know what the parameter $\\alpha$ might be? No. At this point, we could continue and assign a distribution to $\\alpha$. Still, it's better to stop once we reach a set level of ignorance: whereas we have a prior belief about $\\lambda$ (\"it probably changes over time,\" \"it's likely between 1 and 3\", etc.), we don't have any strong beliefs about $\\alpha$.\n", "\n", "6. We know that $\\alpha$ is positive, but we don't know its exact value. So, instead of giving it a prior of its own, we fix it to a constant and pick a rate parameter of 1.0 for the exponential distribution.\n", "\n", "7. We have no expert opinion of when $\\tau$ might have occurred. 
So, we will suppose $\\tau$ is uniformly distributed over the entire timespan.\n", "A discrete uniform distribution is a bit tricky to use in Pyro with NUTS, which works with continuous variables, so we will use a trick: we will use a `pyro.sample` statement to sample $\\tau$ from a continuous uniform distribution, and then we will use a `pyro.deterministic` statement that compares $\\tau$ with each year to pick the rate that applies to that year.\n", "\n", "Here's a graphical model describing the relationships between the variables in our model:" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": { "tags": [ "hide-input" ] }, "outputs": [ { "data": { "text/plain": [ "[Figure: directed graph 'coal_mining_disasters_model' with edges alpha -> lambda_1, alpha -> lambda_2, tau -> lambda, lambda_1 -> lambda, lambda_2 -> lambda, lambda -> obs]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from graphviz import Digraph\n", "\n", "gcp = Digraph('coal_mining_disasters_model')\n", "\n", "# define the nodes\n", "gcp.node('alpha', label='<α>')\n", "gcp.node('tau', label='<τ>')\n", "gcp.node('lambda_1', label='<λ<sub>1</sub>>')\n", "gcp.node('lambda_2', label='<λ<sub>2</sub>>')\n", "gcp.node('lambda', label='<λ>')\n", "gcp.node('obs', label='obs', style='filled')\n", "\n", "# define the edges\n", "gcp.edge('alpha', 'lambda_1')\n", "gcp.edge('alpha', 'lambda_2')\n", "gcp.edge('tau', 'lambda')\n", "gcp.edge('lambda_1', 'lambda')\n", "gcp.edge('lambda_2', 'lambda')\n", "gcp.edge('lambda', 'obs')\n", "gcp" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "More formally, the generative model is expressed as:\n", "\n", "$$\\alpha = 1.0,$$\n", "\n", "$$\\lambda_1 \\sim \\mathrm{Exp}(\\alpha),$$\n", "\n", "$$\\lambda_2 \\sim \\mathrm{Exp}(\\alpha),$$\n", "\n", "$$\\tau \\sim \\mathrm{Uniform}(1851, 1962),$$\n", "\n", "$$\\lambda_i = \\begin{cases}\n", "\\lambda_1 & \\text{if } t_i \\lt \\tau \\cr\n", "\\lambda_2 & \\text{if } t_i \\ge \\tau\n", "\\end{cases}$$\n", "\n", "$$\\mathrm{obs}_i \\sim \\mathrm{Poisson}(\\lambda_i).$$\n", "\n", "Here $t_i$ is the year of the $i$-th observation, and $\\mathrm{obs}_i$ is the number of disasters in year $t_i$.\n", "\n", "We fix the rate parameter $\\alpha$ of the exponential priors on $\\lambda_1$ and $\\lambda_2$ to a constant.\n", "The latent variables to be inferred are $\\lambda_1, \\lambda_2, \\tau$.\n", "\n", "Let's make this model in `pyro`." ] },
{ "cell_type": "code", "execution_count": 135, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import pyro\n", "import pyro.distributions as dist\n", "\n", "\n", "def model(years, disasters):\n", "    # Fixed rate of the exponential priors\n", "    alpha = 1.0\n", "    years = torch.tensor(years)\n", "    disasters = torch.tensor(disasters)\n", "    # Continuous switchpoint over the observed timespan\n", "    tau = pyro.sample(\"tau\", dist.Uniform(1851.0, 1962.0))\n", "    lambda_1 = pyro.sample(\"lambda_1\", dist.Exponential(alpha))\n", "    lambda_2 = pyro.sample(\"lambda_2\", dist.Exponential(alpha))\n", "    with pyro.plate(\"data\", size=disasters.shape[0]):\n", "        # Years before tau get lambda_1, the rest get lambda_2\n", "        lambda_ = pyro.deterministic(\n", "            \"lambda\",\n", "            torch.where(torch.gt(tau, years), lambda_1, lambda_2)\n", "        )\n", "        pyro.sample(\"obs\", dist.Poisson(lambda_), obs=disasters)\n", "    return locals()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "This was not that bad. But there is a problem. Let's try to sample from the model."
] }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Sample: 100%|██████████| 200/200 [01:29, 2.24it/s, step size=6.11e-03, acc. prob=0.633]\n", "\n" ] } ], "source": [ "from pyro.infer import MCMC, NUTS\n", "\n", "nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)\n", "mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=100)\n", "# Pass the NumPy arrays loaded above; the model converts them to tensors\n", "mcmc.run(years, disasters)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "It takes a long time: about a minute and a half for only 200 iterations. `pyro` can be slow at sampling. It is good at variational inference (next lecture), but sampling can take a long time, especially if your model is complicated.\n", "Fortunately, there is an alternative.\n", "`numpyro` is a library built on top of `jax` that offers an interface similar to `pyro`. It is much faster than `pyro` when it comes to sampling. Let's try it out.\n", "The main difference is that we will use `jax` instead of `torch` for the backend. I will highlight the points that differ." ] },
{ "cell_type": "code", "execution_count": 137, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|██████████| 2000/2000 [00:02<00:00, 766.78it/s, 1023 steps of size 3.00e-04. acc. 
prob=0.88] \n" ] } ], "source": [ "# The imports are all different\n", "import numpyro\n", "import numpyro.distributions as dist\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "\n", "def model(years, disasters):\n", "    alpha = 1.0\n", "    # This is different\n", "    years = jnp.array(years, dtype=jnp.float32)\n", "    # And this\n", "    disasters = jnp.array(disasters)\n", "    # Here we only need to change the name of the library\n", "    tau = numpyro.sample(\"tau\", dist.Uniform(1851.0, 1962.0))\n", "    lambda_1 = numpyro.sample(\"lambda_1\", dist.Exponential(alpha))\n", "    lambda_2 = numpyro.sample(\"lambda_2\", dist.Exponential(alpha))\n", "    with numpyro.plate(\"data\", size=disasters.shape[0]):\n", "        # Here we are using jax instead of torch\n", "        lambda_ = numpyro.deterministic(\n", "            \"lambda\",\n", "            jnp.where(jax.lax.gt(tau, years), lambda_1, lambda_2)\n", "        )\n", "        numpyro.sample(\"obs\", dist.Poisson(lambda_), obs=disasters)\n", "    return locals()\n", "\n", "\n", "from numpyro.infer import MCMC, NUTS\n", "nuts_kernel = NUTS(model)\n", "mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=1000)\n", "\n", "# This is also different\n", "rng_key = jax.random.PRNGKey(0)\n", "# And you need to pass it to the run method.\n", "# We pass the NumPy arrays loaded above; the model converts them with jnp.array\n", "mcmc.run(rng_key, years, disasters)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Okay, that was fast. Let's look at the results."
] }, { "cell_type": "code", "execution_count": 122, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n",
"                mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
"  lambda_1      1.93      0.10      1.92      1.79      2.10      7.80      1.02\n",
"  lambda_2      0.38      0.15      0.34      0.11      0.59     10.77      1.01\n",
"       tau   1946.65      0.46   1946.60   1946.00   1947.33     46.78      1.00\n",
"\n", "Number of divergences: 0\n" ] } ], "source": [ "# This is different too: pyro uses mcmc.summary(), numpyro uses mcmc.print_summary()\n", "mcmc.print_summary()" ] },
{ "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-26T20:32:49.659939\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-26T20:32:49.691788\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "samples = mcmc.get_samples()\n", "for param in samples.keys():\n", " fig, ax = plt.subplots()\n", " ax.plot(samples[param])\n", " ax.set(xlabel='Sample number', ylabel=param)\n", " sns.despine(trim=True);\n", "\n", " fig, ax = plt.subplots()\n", " ax.hist(samples[param], label=param, density=True)\n", " ax.set(xlabel=param, ylabel='Frequency')\n", " sns.despine(trim=True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's finish by plotting the rate parameter $\\lambda$ over time.\n", "Recall that $\\lambda$ is different depending on whether the year is before or after the switchpoint $\\tau$.\n", "Also, $\\lambda$ is a deterministic function of $\\lambda_1$, $\\lambda_2$ and $\\tau$.\n", "It is not stored in the samples, but we can compute it from them.\n", "This is done through the `numpyro.infer.Predictive`.\n", "Here is how:" ] }, { "cell_type": "code", "execution_count": 123, "metadata": {}, "outputs": [], "source": [ "predict = numpyro.infer.Predictive(model, samples)(rng_key, jax_years, jax_disasters)" ] }, { "cell_type": "code", "execution_count": 124, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_items([('lambda', Array([[1.9828763 , 1.9828763 , 1.9828763 , ..., 0.40891236, 0.40891236,\n", " 0.40891236],\n", " [1.9894089 , 1.9894089 , 1.9894089 , ..., 0.40647042, 0.40647042,\n", " 0.40647042],\n", " [1.9763205 , 1.9763205 , 1.9763205 , ..., 0.39520258, 0.39520258,\n", " 0.39520258],\n", " ...,\n", " [1.9366734 , 1.9366734 , 1.9366734 , ..., 0.35683027, 0.35683027,\n", " 0.35683027],\n", " [2.000598 , 2.000598 , 2.000598 , ..., 0.3435446 , 0.3435446 ,\n", " 0.3435446 ],\n", " [1.9987974 , 1.9987974 , 1.9987974 , ..., 0.37061924, 0.37061924,\n", " 0.37061924]], dtype=float32)), ('obs', Array([[4, 5, 4, ..., 1, 0, 1],\n", " [4, 5, 4, ..., 1, 0, 1],\n", " [4, 5, 4, ..., 1, 0, 1],\n", " ...,\n", " [4, 5, 4, ..., 1, 0, 1],\n", " [4, 5, 4, ..., 1, 0, 1],\n", 
" [4, 5, 4, ..., 1, 0, 1]], dtype=int32))])" ] }, "execution_count": 124, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict.items()" ] }, { "cell_type": "code", "execution_count": 129, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-26T21:02:48.683958\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots()\n", "lambda_050 = jnp.quantile(predict[\"lambda\"], 0.05, axis=0)\n", "lambda_500 = jnp.quantile(predict[\"lambda\"], 0.5, axis=0)\n", "lambda_950 = jnp.quantile(predict[\"lambda\"], 0.95, axis=0)\n", "ax.plot(jax_years, lambda_500)\n", "ax.fill_between(jax_years, lambda_050, lambda_950, alpha=0.5)\n", "ax.set(xlabel=\"Year\", ylabel=\"Rate of coal mining disasters\")\n", "ax.set_title(\"Posterior distribution of the rate of coal mining disasters\")\n", "sns.despine(trim=True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example 2 - Challenger Space Shuttle Disaster\n", "\n", "For this tutorial, we will work on the Challenger space shuttle disaster dataset.\n", "We have revisited the problem in [HW 5](https://colab.research.google.com/github/PredictiveScienceLab/data-analytics-se/blob/master/homework/homework_05.ipynb). There, we used logistic regression, and we trained it with maximum likelihood.\n", "We can now follow a Bayesian approach to the fullest extent.\n", "\n", "On January 28, 1986, the twenty-fifth flight of the U.S. space shuttle program ended in disaster when one of the rocket boosters of the Shuttle Challenger exploded shortly after lift-off, killing all seven crew members. The presidential commission on the accident concluded that it was caused by the failure of an O-ring in a field joint on the rocket booster. This failure was due to a faulty design that made the O-ring unacceptably sensitive to several factors, including outside temperature. Of the previous 24 flights, data were available on failures of O-rings on 23 (one was lost at sea), and these data were discussed on the evening preceding the Challenger launch. Unfortunately, only the data corresponding to the seven flights on which there was a damage incident were considered important, and these were thought to show no obvious trend. 
The data are shown below:" ] }, { "cell_type": "code", "execution_count": 138, "metadata": {}, "outputs": [], "source": [ "url = \"https://github.com/PredictiveScienceLab/data-analytics-se/raw/master/lecturebook/data/challenger_data.csv\"\n", "download(url)" ] }, { "cell_type": "code", "execution_count": 139, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Temp (F), O-Ring failure?\n", "[[66. 0.]\n", " [70. 1.]\n", " [69. 0.]\n", " [68. 0.]\n", " [67. 0.]\n", " [72. 0.]\n", " [73. 0.]\n", " [70. 0.]\n", " [57. 1.]\n", " [63. 1.]\n", " [70. 1.]\n", " [78. 0.]\n", " [67. 0.]\n", " [53. 1.]\n", " [67. 0.]\n", " [75. 0.]\n", " [70. 0.]\n", " [81. 0.]\n", " [76. 0.]\n", " [79. 0.]\n", " [75. 1.]\n", " [76. 0.]\n", " [58. 1.]]\n" ] } ], "source": [ "challenger_data = np.genfromtxt(\"challenger_data.csv\", skip_header=1,\n", " usecols=[1, 2], missing_values=\"NA\",\n", " delimiter=\",\")\n", "challenger_data = challenger_data[~np.isnan(challenger_data[:, 1])]\n", "print(\"Temp (F), O-Ring failure?\")\n", "print(challenger_data)" ] }, { "cell_type": "code", "execution_count": 142, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-26T21:12:32.927276\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# plot it, as a function of temperature (the first column)\n", "plt.figure()\n", "plt.plot(challenger_data[:, 0], challenger_data[:, 1], 'ro')\n", "plt.ylabel(\"Damage Incident?\")\n", "plt.xlabel(\"Outside temperature (Fahrenheit)\")\n", "plt.title(\"Defects of the Space Shuttle O-Rings vs temperature\")\n", "plt.yticks([0, 1])\n", "sns.despine(trim=True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Information about the dataset \n", "It looks clear that *the probability* of damage incidents occurring increases as the outside temperature decreases. We are interested in modeling the probability here because it does not look like there is a strict cutoff point between temperature and a damage incident occurring. The best we can do is ask, \"At temperature $t$, what is the probability of a damage incident?\". The goal of this example is to answer that question." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Questions\n", "\n", "+ How can we represent this disaster binary 1/0 data? What are the quantities of interests?\n", "+ Is the \"damage incident\" variable categorical or continuous? What are some other constraints?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Modeling Approach\n", "\n", "How are we going to develop a model for this data? Thinking about how this data might be generated is a good starting point.* Try to imagine how you would recreate the dataset. We begin by asking how our observations might have been generated.\n", "\n", "- We start by thinking, \"what is the best random variable to describe this binary categorical data?\" A Bernoulli random variable is a good candidate because it can represent binary data. So, we model the `defect incident` variable as sampled from a Bernoulli distribution. 
A *Bernoulli* random variable with parameter $p$, denoted $\\text{Ber}(p)$, is a random variable that takes the value 1 with probability $p$ and the value 0 otherwise. Thus, our model can look like:\n", "\n", "$$ \\text{Defect Incident, $D_i$} \\sim \\text{Ber}( \\;p_i\\; ), \\;\\; i = 1, \\dots, N$$\n", "\n", "- Next, we think, \"Okay, assuming the damage incident variable is Bernoulli-distributed, what do I need for the Bernoulli distribution?\" Well, the Bernoulli distribution has a probability parameter $p_i$. \n", "\n", "- Do we know the $p_i$ parameter? No. But we suspect (intuition) that this parameter depends on the outside temperature. The lower the outside temperature, the greater the probability of a damage incident. The higher the outside temperature, the lower the probability of a damage incident. With a slight abuse of notation, we can write this as:\n", "\n", "$$p_i = p(t_i) = \\sigma(t_i),$$\n", "\n", "where $\\sigma: \\mathbb{R} \\rightarrow (0, 1)$ is a suitable function that maps arbitrary temperature values to the interval $(0, 1)$ (so that we can then interpret the output of $\\sigma$ as a probability). 
For this problem, we are going to use the logistic link function, which is given as:\n", "\n", "$$\\sigma(t) = \\frac{1}{ 1 + e^{ \\;\\beta t + \\alpha } } $$\n", "\n", "Some plots are shown below, with differing $\\alpha$ and $\\beta$ parameter values:" ] }, { "cell_type": "code", "execution_count": 146, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-26T21:21:13.312187\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def logistic(x, beta, alpha=0):\n", " return 1.0 / (1.0 + np.exp(np.dot(beta, x) + alpha))\n", "\n", "x = np.linspace(-4, 4, 100)\n", "plt.figure()\n", "alphas = [0, 0, 0, 1, 3, 5]\n", "betas = [1, 3, -5, 1, 3, -5]\n", "params = zip(alphas, betas)\n", "for param in params:\n", " alpha, beta = param\n", " label=\"$\\\\alpha$ = %d, $\\\\beta$ = %d\"%(alpha, beta)\n", " plt.plot(x, logistic(x, beta, alpha),\n", " label=label\n", " )\n", "plt.title(\"Logistic functon with bias\")\n", "plt.legend(loc=\"best\", frameon=False)\n", "sns.despine(trim=True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Question:\n", "+ How can we represent $\\alpha$ and $\\beta$ parameter values? Are they categorical or continuous? Do they need to be positive?__\n", "\n", "The $\\beta, \\alpha$ parameters have no reason to be positive, bounded, or relatively large, so they are best modeled by a *Normal random variable*. Since we do not have any prior beliefs about the value of parameters beta or alpha, we place a vague prior Normal distribution (small precision, large variance) over their values.\n", "\n", "The graphical model for the data generation process is shown below. 
" ] }, { "cell_type": "code", "execution_count": 147, "metadata": { "tags": [ "hide-input" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "space_shuttle_disaster\n", "\n", "\n", "cluster_0\n", "\n", "i=1,2...\n", "\n", "\n", "\n", "alpha\n", "\n", "α\n", "\n", "\n", "\n", "pi\n", "\n", "p\n", "i\n", "\n", "\n", "\n", "alpha->pi\n", "\n", "\n", "\n", "\n", "\n", "beta\n", "\n", "β\n", "\n", "\n", "\n", "beta->pi\n", "\n", "\n", "\n", "\n", "\n", "xi\n", "\n", "x\n", "i\n", "\n", "\n", "\n", "pi->xi\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 147, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gcp = Digraph('space_shuttle_disaster')\n", "\n", "# setup the nodes \n", "gcp.node('alpha', label='<α>')\n", "gcp.node('beta', label='<β>')\n", "with gcp.subgraph(name='cluster_0') as sg:\n", " sg.node('pi', label='i>')\n", " sg.node('xi', label='i>', style='filled')\n", " sg.attr(color='blue')\n", " sg.attr(label='i=1,2...')\n", " sg.attr(labelloc='b')\n", "\n", "# setup the edges \n", "gcp.edge('alpha', 'pi')\n", "gcp.edge('beta', 'pi')\n", "gcp.edge('pi', 'xi')\n", "gcp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `Pyro` Implementation" ] }, { "cell_type": "code", "execution_count": 148, "metadata": {}, "outputs": [], "source": [ "def challenger(temperature, failure):\n", " alpha = numpyro.sample(\"alpha\", dist.Normal(0, 10))\n", " beta = numpyro.sample(\"beta\", dist.Normal(0, 10))\n", " p = numpyro.deterministic(\n", " \"p\", \n", " 1.0 / (1.0 + jnp.exp(beta * temperature + alpha))\n", " )\n", " x = numpyro.sample(\"x\", dist.Bernoulli(p), obs=failure)\n", " return locals()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is how the model looks like graphically:" ] }, { "cell_type": "code", "execution_count": 150, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "alpha\n", "\n", 
"alpha\n", "\n", "\n", "\n", "p\n", "\n", "p\n", "\n", "\n", "\n", "alpha->p\n", "\n", "\n", "\n", "\n", "\n", "beta\n", "\n", "beta\n", "\n", "\n", "\n", "beta->p\n", "\n", "\n", "\n", "\n", "\n", "x\n", "\n", "x\n", "\n", "\n", "\n", "p->x\n", "\n", "\n", "\n", "\n", "\n", "distribution_description_node\n", "alpha ~ Normal\n", "beta ~ Normal\n", "p ~ Deterministic\n", "x ~ BernoulliProbs\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 150, "metadata": {}, "output_type": "execute_result" } ], "source": [ "numpyro.render_model(challenger, model_args=(challenger_data[:, 0], challenger_data[:, 1]), render_distributions=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's sample from it:" ] }, { "cell_type": "code", "execution_count": 151, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|██████████| 2000/2000 [00:01<00:00, 1677.74it/s, 39 steps of size 7.18e-02. acc. prob=0.96]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " alpha -11.21 5.33 -10.75 -19.99 -2.93 137.53 1.00\n", " beta 0.18 0.08 0.17 0.05 0.30 132.25 1.00\n", "\n", "Number of divergences: 0\n" ] } ], "source": [ "nuts_kernel = NUTS(challenger)\n", "mcmc = MCMC(nuts_kernel, num_samples=1000, num_warmup=1000)\n", "rng_key = jax.random.PRNGKey(0)\n", "mcmc.run(rng_key, challenger_data[:, 0], challenger_data[:, 1])\n", "mcmc.print_summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are all the posterior distributions:" ] }, { "cell_type": "code", "execution_count": 153, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-26T21:27:52.506551\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-26T21:27:52.532203\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-26T21:27:52.554212\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-26T21:27:52.576850\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "samples = mcmc.get_samples()\n", "for param in samples.keys():\n", " if param == \"p\":\n", " continue\n", " fig, ax = plt.subplots()\n", " ax.plot(samples[param])\n", " ax.set(xlabel='Sample number', ylabel=param)\n", " sns.despine(trim=True);\n", "\n", " fig, ax = plt.subplots()\n", " ax.hist(samples[param], label=param, density=True)\n", " ax.set(xlabel=param, ylabel='Frequency')\n", " sns.despine(trim=True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now end with our prediction of the probability of a damage incident occurring at a given temperature." ] }, { "cell_type": "code", "execution_count": 175, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-26T21:39:24.592532\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#probs = 1.0 / (1.0 + np.exp(-(samples[\"beta\"] * temps.reshape(1, -1) + samples[\"alpha\"]).T))\n", "temps = np.linspace(30, challenger_data[:, 0].max() + 5, 100)[:, None]\n", "probs = 1.0 / (1.0 + np.exp(samples[\"beta\"] * temps + samples[\"alpha\"]))\n", "\n", "fig, ax = plt.subplots()\n", "probs_025 = np.quantile(probs, 0.025, axis=1)\n", "probs_500 = np.quantile(probs, 0.5, axis=1)\n", "probs_975 = np.quantile(probs, 0.975, axis=1)\n", "ax.plot(temps, probs_500)\n", "ax.fill_between(temps.flatten(), probs_025, probs_975, alpha=0.5)\n", "ax.set(xlabel=\"Temperature\", ylabel=\"Probability of defect\")\n", "ax.set_title(\"Posterior probability of defect vs temperature\")\n", "ax.plot(challenger_data[:, 0], challenger_data[:, 1], 'ro', label=\"observed\")\n", "ax.axvline(31, ymax=0.75, color=\"k\", linestyle=\"--\", label=\"Challenger disaster\")\n", "ax.text(31, 0.8, \"Challenger disaster\", fontsize=10)\n", "sns.despine(trim=True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The *95% credible interval*, or 95% CI, painted in blue, represents the interval for each temperature that contains 95% of the distribution. For example, at 65 degrees, we can be 95% sure that the defect probability lies between 0.25 and 0.75.\n", "\n", "More generally, as the temperature nears 60 degrees, the CIs spread out quickly. As we pass 70 degrees, the CIs tighten again. This can give us insight into how to proceed: we should test more O-rings around 60-65 temperature to estimate probabilities in that range better. Similarly, when reporting to scientists your estimates, you should be very cautious about simply telling them the expected probability. As we can see, this does not reflect how *wide* the posterior distribution is.\n", "\n", "On the day of the Challenger disaster, the outside temperature was 31 degrees Fahrenheit. 
What is the posterior distribution of the defect probability at this temperature? The dashed vertical line in the plot above marks it. It looks almost guaranteed that the Challenger would be subject to defective O-rings." ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 2 }