{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "hide-input", "hide-output" ] }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "import matplotlib_inline\n", "matplotlib_inline.backend_inline.set_matplotlib_formats('svg')\n", "import seaborn as sns\n", "sns.set_context(\"paper\")\n", "sns.set_style(\"ticks\")\n", "\n", "import scipy\n", "import scipy.stats as st\n", "import urllib.parse\n", "import urllib.request\n", "import os\n", "from typing import Optional\n", "\n", "def download(\n", "    url : str,\n", "    local_filename : Optional[str] = None\n", "):\n", "    \"\"\"Download a file from a url.\n", "\n", "    Arguments\n", "    url            -- The url we want to download.\n", "    local_filename -- The filename to write on. If not\n", "                      specified, the last component of the\n", "                      url path is used, so the file is saved\n", "                      relative to the current working directory.\n", "\n", "    Raises\n", "    ValueError -- If no filename is given and none can be\n", "                  derived from the url.\n", "    \"\"\"\n", "    if local_filename is None:\n", "        # Use only the path part of the url so that query strings\n", "        # and fragments do not end up in the filename.\n", "        url_path = urllib.parse.urlparse(url).path\n", "        local_filename = os.path.basename(url_path)\n", "        if not local_filename:\n", "            # Fail before touching the network with a clear message.\n", "            raise ValueError(\n", "                f\"Cannot derive a filename from '{url}'; \"\n", "                \"please specify local_filename.\"\n", "            )\n", "    urllib.request.urlretrieve(url, local_filename)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(logistic_regression_with_one_variable)=\n", "# Logistic regression with one variable (High melting explosives)\n", "\n", "\n", "[High Melting Explosives](https://en.wikipedia.org/wiki/HMX) (HMX) have applications as detonators of nuclear weapons and as solid rocket propellants.\n", "We will use logistic regression to build the probability that a specific HMX block explodes when dropped from a given height.\n", "To this end, we will use data from a 1987 Los Alamos Report\n", "(L. Smith, “Los Alamos National Laboratory explosives orientation course: Sensitivity and sensitivity tests to impact, friction, spark and shock,” Los Alamos National Lab, NM (USA), Tech. 
Rep., 1987).\n", "Let's download the raw data and load them.\n", "We will use the [Python Data Analysis Library](https://pandas.pydata.org/):" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "url = \"https://github.com/PredictiveScienceLab/data-analytics-se/raw/master/lecturebook/data/hmx_data.csv\"\n", "# Name the local file once and pass it explicitly, so that the file\n", "# we write is guaranteed to be the file we read.\n", "data_file = 'hmx_data.csv'\n", "download(url, data_file)\n", "\n", "data = pd.read_csv(data_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each row of the data is a different experiment.\n", "There are two columns:\n", "\n", "+ The first column is **Height**: From what height (in cm) the specimen was dropped?\n", "+ The second column is **Result**: Did the specimen explode (E) or not (N)?\n", "\n", "Here is how to see the raw data:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [ "hide-output" ] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
HeightResult
040.5E
140.5E
240.5E
340.5E
440.5E
540.5E
640.5E
740.5E
840.5E
940.5E
1036.0E
1136.0N
1236.0E
1336.0E
1436.0E
1536.0E
1636.0N
1736.0E
1836.0E
1936.0E
2032.0E
2132.0E
2232.0N
2332.0E
2432.0E
2532.0E
2632.0N
2732.0E
2832.0N
2932.0E
3028.5N
3128.5E
3228.5N
3328.5N
3428.5E
3528.5N
3628.5N
3728.5N
3828.5E
3928.5N
4025.5N
4125.5N
4225.5N
4325.5N
4425.5N
4525.5N
4625.5E
4725.5N
4825.5N
4925.5N
5022.5N
5122.5N
5222.5N
5322.5N
5422.5N
5522.5N
5622.5N
5722.5N
5822.5N
5922.5N
\n", "
" ], "text/plain": [ " Height Result\n", "0 40.5 E\n", "1 40.5 E\n", "2 40.5 E\n", "3 40.5 E\n", "4 40.5 E\n", "5 40.5 E\n", "6 40.5 E\n", "7 40.5 E\n", "8 40.5 E\n", "9 40.5 E\n", "10 36.0 E\n", "11 36.0 N\n", "12 36.0 E\n", "13 36.0 E\n", "14 36.0 E\n", "15 36.0 E\n", "16 36.0 N\n", "17 36.0 E\n", "18 36.0 E\n", "19 36.0 E\n", "20 32.0 E\n", "21 32.0 E\n", "22 32.0 N\n", "23 32.0 E\n", "24 32.0 E\n", "25 32.0 E\n", "26 32.0 N\n", "27 32.0 E\n", "28 32.0 N\n", "29 32.0 E\n", "30 28.5 N\n", "31 28.5 E\n", "32 28.5 N\n", "33 28.5 N\n", "34 28.5 E\n", "35 28.5 N\n", "36 28.5 N\n", "37 28.5 N\n", "38 28.5 E\n", "39 28.5 N\n", "40 25.5 N\n", "41 25.5 N\n", "42 25.5 N\n", "43 25.5 N\n", "44 25.5 N\n", "45 25.5 N\n", "46 25.5 E\n", "47 25.5 N\n", "48 25.5 N\n", "49 25.5 N\n", "50 22.5 N\n", "51 22.5 N\n", "52 22.5 N\n", "53 22.5 N\n", "54 22.5 N\n", "55 22.5 N\n", "56 22.5 N\n", "57 22.5 N\n", "58 22.5 N\n", "59 22.5 N" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's encode the labels as $1$ and $0$ instead of E and N.\n", "Let's do this below:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [ "hide-output" ] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
HeightResulty
040.5E1
140.5E1
240.5E1
340.5E1
440.5E1
540.5E1
640.5E1
740.5E1
840.5E1
940.5E1
1036.0E1
1136.0N0
1236.0E1
1336.0E1
1436.0E1
1536.0E1
1636.0N0
1736.0E1
1836.0E1
1936.0E1
2032.0E1
2132.0E1
2232.0N0
2332.0E1
2432.0E1
2532.0E1
2632.0N0
2732.0E1
2832.0N0
2932.0E1
3028.5N0
3128.5E1
3228.5N0
3328.5N0
3428.5E1
3528.5N0
3628.5N0
3728.5N0
3828.5E1
3928.5N0
4025.5N0
4125.5N0
4225.5N0
4325.5N0
4425.5N0
4525.5N0
4625.5E1
4725.5N0
4825.5N0
4925.5N0
5022.5N0
5122.5N0
5222.5N0
5322.5N0
5422.5N0
5522.5N0
5622.5N0
5722.5N0
5822.5N0
5922.5N0
\n", "
" ], "text/plain": [ " Height Result y\n", "0 40.5 E 1\n", "1 40.5 E 1\n", "2 40.5 E 1\n", "3 40.5 E 1\n", "4 40.5 E 1\n", "5 40.5 E 1\n", "6 40.5 E 1\n", "7 40.5 E 1\n", "8 40.5 E 1\n", "9 40.5 E 1\n", "10 36.0 E 1\n", "11 36.0 N 0\n", "12 36.0 E 1\n", "13 36.0 E 1\n", "14 36.0 E 1\n", "15 36.0 E 1\n", "16 36.0 N 0\n", "17 36.0 E 1\n", "18 36.0 E 1\n", "19 36.0 E 1\n", "20 32.0 E 1\n", "21 32.0 E 1\n", "22 32.0 N 0\n", "23 32.0 E 1\n", "24 32.0 E 1\n", "25 32.0 E 1\n", "26 32.0 N 0\n", "27 32.0 E 1\n", "28 32.0 N 0\n", "29 32.0 E 1\n", "30 28.5 N 0\n", "31 28.5 E 1\n", "32 28.5 N 0\n", "33 28.5 N 0\n", "34 28.5 E 1\n", "35 28.5 N 0\n", "36 28.5 N 0\n", "37 28.5 N 0\n", "38 28.5 E 1\n", "39 28.5 N 0\n", "40 25.5 N 0\n", "41 25.5 N 0\n", "42 25.5 N 0\n", "43 25.5 N 0\n", "44 25.5 N 0\n", "45 25.5 N 0\n", "46 25.5 E 1\n", "47 25.5 N 0\n", "48 25.5 N 0\n", "49 25.5 N 0\n", "50 22.5 N 0\n", "51 22.5 N 0\n", "52 22.5 N 0\n", "53 22.5 N 0\n", "54 22.5 N 0\n", "55 22.5 N 0\n", "56 22.5 N 0\n", "57 22.5 N 0\n", "58 22.5 N 0\n", "59 22.5 N 0" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Extract data for classification in numpy arrays\n", "# Features: the drop height\n", "x = data['Height'].values\n", "# Labels (must be integer): 1 = explosion (E), 0 = no explosion (N)\n", "label_coding = {'E': 1, 'N': 0}\n", "# Vectorized lookup. A label missing from label_coding maps to NaN,\n", "# which makes the integer cast fail loudly instead of passing silently.\n", "y = data['Result'].map(label_coding).astype(int).values\n", "data['y'] = y\n", "data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize the data.\n", "Notice that lots of observations fall on top of each other." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-09-26T15:36:49.896817\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots()\n", "ax.plot(x, y, 'o')\n", "ax.set_xlabel('$x$ (Height in cm)')\n", "ax.set_ylabel('Result ($0=N; 1=E$)')\n", "sns.despine(trim=True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's train a logistic regression model with just a linear feature using scikit-learn:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression\n", "\n", "# The design matrix\n", "X = np.hstack(\n", " [\n", " np.ones((x.shape[0], 1)),\n", " x[:, None]]\n", ")\n", "\n", "# Train the model (penalty = 'none' means that we do not add a prior on the weights)\n", "# we are effectively just maximizing the likelihood of the data\n", "model = LogisticRegression(\n", " penalty=None,\n", " fit_intercept=False\n", ").fit(X, y);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is how you can get the trained weights of the model:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[-12.688, 0.411]])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.coef_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And here is how you can make predictions at some arbitrary heights:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[9.99810986e-01, 1.89014081e-04],\n", " [9.88560090e-01, 1.14399105e-02],\n", " [5.85351787e-01, 4.14648213e-01],\n", " [2.25419752e-02, 9.77458025e-01],\n", " [3.76605766e-04, 9.99623394e-01]])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_predict = np.array([10.0, 20.0, 30.0, 40.0, 50.0])\n", "X_predict = np.hstack(\n", " [\n", " np.ones((x_predict.shape[0], 1)),\n", " x_predict[:, None]\n", " ]\n", ")\n", "predictions = 
model.predict_proba(X_predict)\n", "predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the model gave us back the probability of each class.\n", "If you wanted, you could ask for the class of maximum probability for each prediction input:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 1, 1])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict(X_predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To visualize the predictions of the model as a function of the height, we can do this:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-09-26T15:41:11.340911\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots()\n", "xx = np.linspace(20.0, 45.0, 100)\n", "XX = np.hstack([np.ones((xx.shape[0], 1)), xx[:, None]])\n", "predictions_xx = model.predict_proba(XX)\n", "ax.plot(\n", " xx,\n", " predictions_xx[:, 0],\n", " label='Probability of N'\n", ")\n", "ax.plot(\n", " xx,\n", " predictions_xx[:, 1],\n", " label='Probability of E'\n", ")\n", "ax.set_xlabel('$x$ (cm)')\n", "ax.set_ylabel('Probability')\n", "plt.legend(loc='best', frameon=False)\n", "sns.despine(trim=True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Questions\n", "\n", "+ What is the probability of an explosion when the height becomes very small?\n", "+ What is the probability of an explosion when the height becomes very large?\n", "+ At what height is it particularly difficult to predict what will happen?" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 4 }